From 3715fd128eaa5dc6d027789470a5966b8d880253 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 15 Jun 2018 16:24:00 -0400 Subject: [PATCH 0001/1550] Relax pytest constraint in appveyor tests (#2060) Previously appveyor was failing because pytest was pinned to a version that made the pytest.timeout package unhappy. Lets relax this constraint for now. --- continuous_integration/setup_conda_environment.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index 786868fca9f..4f2674dd65b 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -31,7 +31,7 @@ call deactivate jupyter_client ^ mock ^ psutil ^ - pytest=3.1 ^ + pytest ^ python=%PYTHON% ^ requests ^ toolz ^ From 1eb486dae117fc93ef12f80d240117eb4a4f5fd6 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 20 Jun 2018 16:45:16 -0400 Subject: [PATCH 0002/1550] Pull data outside of while loop in gather (#2059) See https://github.com/dask/distributed/issues/2025 --- distributed/client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 87255dd3499..525ccfe205d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1391,6 +1391,7 @@ def _gather(self, futures, errors='raise', direct=None, local_worker=None): futures2, keys = unpack_remotedata(futures, byte_keys=True) keys = [tokey(key) for key in keys] bad_data = dict() + data = {} if direct is None: try: @@ -1445,8 +1446,6 @@ def wait(k): keys = [k for k in keys if k not in bad_keys] - data = {} - if local_worker: # look inside local worker data.update({k: local_worker.data[k] for k in keys From 45bff01259c822986f4ff612b165b3fcdbb954cf Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 15 Jun 2018 16:22:20 -0400 Subject: [PATCH 0003/1550] Avoid reference cycle in str_graph This caused an intermittent failure in distributed/tests/test_batched.py::test_dont_hold_on_to_large_messages --- distributed/tests/test_client.py | 2 +- distributed/utils.py | 33 ++++++++++++++++---------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index e015db9f99c..554bd451e04 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3706,7 +3706,7 @@ def start_worker(sleep, duration, repeat=1): sleep(1) for i in range(count): - done.acquire() + done.acquire(timeout=20) gc.collect() if not running: break diff --git a/distributed/utils.py b/distributed/utils.py index f9026976f05..4c02860becc 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -34,7 +34,7 @@ import dask from dask import istask -from toolz import memoize, valmap +from toolz import memoize import tornado from tornado import gen from tornado.ioloop import IOLoop, PollIOLoop @@ -798,22 +798,23 @@ def _maybe_complex(task): type(task) is dict and any(map(_maybe_complex, task.values()))) -def str_graph(dsk, extra_values=()): - def convert(task): - if type(task) is list: - return [convert(v) for v in task] - if type(task) is dict: - return valmap(convert, task) - if istask(task): - return (task[0],) + tuple(map(convert, task[1:])) - try: - if task in dsk or task in extra_values: - return tokey(task) - except TypeError: - pass - return task +def convert(task, dsk, extra_values): + if type(task) is list: + return [convert(v, dsk, 
extra_values) for v in task] + if type(task) is dict: + return {k: convert(v, dsk, extra_values) for k, v in task.items()} + if istask(task): + return (task[0],) + tuple(convert(x, dsk, extra_values) for x in task[1:]) + try: + if task in dsk or task in extra_values: + return tokey(task) + except TypeError: + pass + return task - return {tokey(k): convert(v) for k, v in dsk.items()} + +def str_graph(dsk, extra_values=()): + return {tokey(k): convert(v, dsk, extra_values) for k, v in dsk.items()} def seek_delimiter(file, delimiter, blocksize): From a40fc080b909f3db32220030882bf081129a84da Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 18 Jun 2018 09:12:09 -0400 Subject: [PATCH 0004/1550] Use ConnectionPool for Worker.scheduler --- distributed/client.py | 4 ++-- distributed/core.py | 4 ++++ distributed/tests/test_client.py | 2 +- distributed/worker.py | 11 ++++++----- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 525ccfe205d..7f474171835 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -46,7 +46,7 @@ scatter_to_workers, gather_from_workers) from .cfexecutor import ClientExecutor from .compatibility import Queue as pyQueue, Empty, isqueue, html_escape -from .core import connect, rpc, clean_exception, CommClosedError +from .core import connect, rpc, clean_exception, CommClosedError, PooledRPCCall from .metrics import time from .node import Node from .protocol import to_serialize @@ -575,7 +575,7 @@ def __init__(self, address=None, loop=None, timeout=no_default, logger.info("Config value `scheduler-address` found: %s", address) - if isinstance(address, rpc): + if isinstance(address, (rpc, PooledRPCCall)): self.scheduler = address elif hasattr(address, "scheduler_address"): # It's a LocalCluster or LocalCluster-compatible object diff --git a/distributed/core.py b/distributed/core.py index b152a58cf6d..7bd6c12960c 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -598,6 +598,10 @@ def __init__(self, addr, pool, serializers=None, deserializers=None): self.serializers = serializers self.deserializers = deserializers if deserializers is not None else serializers + @property + def address(self): + return self.addr + def __getattr__(self, key): @gen.coroutine def send_recv_from_rpc(**kwargs): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 554bd451e04..e015db9f99c 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3706,7 +3706,7 @@ def start_worker(sleep, duration, repeat=1): sleep(1) for i in range(count): - done.acquire(timeout=20) + done.acquire() gc.collect() if not running: break diff --git a/distributed/worker.py b/distributed/worker.py index 74dbc0949db..52bf3d9af78 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -31,7 +31,7 @@ from .comm.utils import offload from .compatibility import unicode, get_thread_identity, finalize from .core import (error_message, CommClosedError, - rpc, pingpong, coerce_to_address) + pingpong, coerce_to_address) from .diskutils import WorkSpace from .metrics import time from .node import ServerNode @@ -165,14 +165,10 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self._closed = Event() self.reconnect = reconnect self.executor = executor or ThreadPoolExecutor(self.ncores) - self.scheduler = rpc(scheduler_addr, connection_args=self.connection_args) self.name = name self.scheduler_delay = 0 self.stream_comms = dict() self.heartbeat_active = False - 
self.execution_state = {'scheduler': self.scheduler.address, - 'ioloop': self.loop, - 'worker': self} self._ipython_kernel = None if self.local_dir not in sys.path: @@ -216,6 +212,11 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, connection_args=self.connection_args, **kwargs) + self.scheduler = self.rpc(scheduler_addr) + self.execution_state = {'scheduler': self.scheduler.address, + 'ioloop': self.loop, + 'worker': self} + pc = PeriodicCallback(self.heartbeat, 1000, io_loop=self.io_loop) self.periodic_callbacks['heartbeat'] = pc self._address = contact_address From 83682686a52a6b29cba9688ec4ad3c77b52fe671 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 21 Jun 2018 08:00:13 -0500 Subject: [PATCH 0005/1550] BUG: Normalize address before comparison (#2066) Fixes https://github.com/dask/distributed/issues/2058 --- distributed/tests/test_worker_client.py | 16 ++++++++++++++++ distributed/worker.py | 22 +++++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index c4e3c775a3a..2b96ae59c4f 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -297,3 +297,19 @@ def f(): result = yield c.submit(f) assert result + + +@gen_cluster() +def test_submit_different_names(s, a, b): + # https://github.com/dask/distributed/issues/2058 + da = pytest.importorskip('dask.array') + c = yield Client('localhost:' + s.address.split(":")[-1], loop=s.loop, + asynchronous=True) + try: + X = c.persist(da.random.uniform(size=(100, 10), chunks=50)) + yield wait(X) + + fut = yield c.submit(lambda x: x.sum().compute(), X) + assert fut > 0 + finally: + yield c.close() diff --git a/distributed/worker.py b/distributed/worker.py index 52bf3d9af78..7bf212baa6f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -25,7 +25,7 @@ from tornado.ioloop import IOLoop from tornado.locks import Event -from . import profile +from . import profile, comm from .batched import BatchedSend from .comm import get_address_host, get_local_address_for, connect from .comm.utils import offload @@ -2597,11 +2597,25 @@ def get_worker(): raise ValueError("No workers found") -def get_client(address=None, timeout=3): - """ Get a client while within a task +def get_client(address=None, timeout=3, resolve_address=True): + """Get a client while within a task. This client connects to the same scheduler to which the worker is connected + Parameters + ---------- + address : str, optional + The address of the scheduler to connect to. Defaults to the scheduler + the worker is connected to. + timeout : int, default 3 + Timeout (in seconds) for getting the Client + resolve_address : bool, default True + Whether to resolve `address` to its canonical form. 
+ + Returns + ------- + Client + Examples -------- >>> def f(): @@ -2620,6 +2634,8 @@ def get_client(address=None, timeout=3): worker_client secede """ + if address and resolve_address: + address = comm.resolve_address(address) try: worker = get_worker() except ValueError: # could not find worker From cdec12c5da4b201a37b9aeca9411930f82df1f77 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 21 Jun 2018 18:47:11 -0400 Subject: [PATCH 0006/1550] Add asynchronous parameter to docstring of LocalCluster --- distributed/deploy/local.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 399a4746b51..b88966a2a02 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -47,6 +47,9 @@ class LocalCluster(Cluster): to choose a random port, ``None`` to disable it, or an :samp:`({ip}:{port})` tuple to listen on a different IP address than the scheduler. + asynchronous: bool (False by default) + Set to True if using this cluster within async/await functions or within + Tornado gen.coroutines. This should remain False for normal use. kwargs: dict Extra worker arguments, will be passed to the Worker constructor. service_kwargs: Dict[str, Dict] From deaa0b3bfeb0076f458eabe748a10e86b40e1155 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 23 Jun 2018 14:36:08 -0400 Subject: [PATCH 0007/1550] Support async def functions in Client.sync (#2070) Remove support for using sync for synchronous functions --- distributed/tests/py3_test_client.py | 14 +++++++++++++- distributed/tests/test_utils.py | 6 ------ distributed/utils.py | 12 ++++-------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/distributed/tests/py3_test_client.py b/distributed/tests/py3_test_client.py index 8abd568d385..9c7a1fdf3ef 100644 --- a/distributed/tests/py3_test_client.py +++ b/distributed/tests/py3_test_client.py @@ -3,8 +3,9 @@ import pytest from tornado import gen -from distributed.utils_test import div, gen_cluster, inc, loop +from distributed.utils_test import div, gen_cluster, inc, loop, cluster from distributed import as_completed, Client, Lock +from distributed.utils import sync @gen_cluster(client=True) @@ -111,3 +112,14 @@ async def f(): assert result is False loop.run_sync(f) + + +def test_client_sync_with_async_def(loop): + async def ff(): + await gen.sleep(0.01) + return 1 + + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + assert sync(loop, ff) == 1 + assert c.sync(ff) == 1 diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index e53cbdbc7a1..60b5105a78d 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -59,12 +59,6 @@ def f(): loop.run_sync(f) -def test_sync(loop_in_thread): - loop = loop_in_thread - result = sync(loop, inc, 1) - assert result == 2 - - def test_sync_error(loop_in_thread): loop = loop_in_thread try: diff --git a/distributed/utils.py b/distributed/utils.py index 4c02860becc..e10c64e6344 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -238,13 +238,6 @@ def sync(loop, func, *args, **kwargs): timeout = kwargs.pop('callback_timeout', None) - def make_coro(): - coro = gen.maybe_future(func(*args, **kwargs)) - if timeout is None: - return coro - else: - return gen.with_timeout(timedelta(seconds=timeout), coro) - e = threading.Event() main_tid = get_thread_identity() result = [None] @@ -257,7 +250,10 @@ def f(): raise RuntimeError("sync() called from thread of running loop") yield gen.moment 
thread_state.asynchronous = True - result[0] = yield make_coro() + future = func(*args, **kwargs) + if timeout is not None: + future = gen.with_timeout(timedelta(seconds=timeout), future) + result[0] = yield future except Exception as exc: error[0] = sys.exc_info() finally: From 84246ff3943513a58241e39f44d5f02241d6b809 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 23 Jun 2018 20:28:08 -0400 Subject: [PATCH 0008/1550] Have worker data transfer wait until recipient acknowledges (#2052) Previously when a worker sent data it would think it was finished as soon as the data was dumped to the socket. Now we wait until we hear an acknowledgement from the recipient that the transfer is complete. This helps with our diagnostics a bit may help avoid backing up a bunch of memory on the OS level, and also assists in future GPU work, where the sending side wants to wait until deserialization on the recipient side has finished. --- distributed/tests/test_utils_comm.py | 2 ++ distributed/tests/test_worker.py | 18 +++++++++++++ distributed/utils_comm.py | 6 ++--- distributed/worker.py | 39 ++++++++++++++++++++++++---- 4 files changed, 56 insertions(+), 9 deletions(-) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 6d633a09db5..1e69eef6a03 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import pytest from distributed.core import rpc from distributed.utils_test import gen_cluster @@ -13,6 +14,7 @@ def test_pack_data(): assert pack_data({'a': ['x'], 'b': 'y'}, data) == {'a': [1], 'b': 'y'} +@pytest.mark.xfail(reason='rpc now needs to be a connection pool') @gen_cluster(client=True) def test_gather_from_workers_permissive(c, s, a, b): x = yield c.scatter({'x': 1}, workers=a.address) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index afe95c0a69c..e8878029e8e 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1162,3 +1162,21 @@ def test_scheduler_address_config(c, s): yield worker._start() assert worker.scheduler.address == s.address yield worker._close() + + +@slow +@gen_cluster(client=True) +def test_wait_for_outgoing(c, s, a, b): + np = pytest.importorskip('numpy') + x = np.random.random(10000000) + future = yield c.scatter(x, workers=a.address) + + y = c.submit(inc, future, workers=b.address) + yield wait(y) + + assert len(b.incoming_transfer_log) == len(a.outgoing_transfer_log) == 1 + bb = b.incoming_transfer_log[0]['duration'] + aa = a.outgoing_transfer_log[0]['duration'] + ratio = aa / bb + + assert 1 / 3 < ratio < 3 diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 0f66def37ff..8e4d2ac1300 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -30,6 +30,7 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None): gather _gather """ + from .worker import get_data_from_worker bad_addresses = set() missing_workers = set() original_who_has = who_has @@ -55,10 +56,7 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None): rpcs = {addr: rpc(addr) for addr in d} try: - coroutines = {address: rpcs[address].get_data( - keys=keys, - close=close, - serializers=serializers) + coroutines = {address: get_data_from_worker(rpc, keys, address) for address, keys in d.items()} response = {} for worker, c in coroutines.items(): diff --git a/distributed/worker.py b/distributed/worker.py index 
7bf212baa6f..ff7628b7346 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -30,7 +30,7 @@ from .comm import get_address_host, get_local_address_for, connect from .comm.utils import offload from .compatibility import unicode, get_thread_identity, finalize -from .core import (error_message, CommClosedError, +from .core import (error_message, CommClosedError, send_recv, pingpong, coerce_to_address) from .diskutils import WorkSpace from .metrics import time @@ -604,14 +604,17 @@ def delete_data(self, comm=None, keys=None, report=True): def get_data(self, comm, keys=None, who=None, serializers=None): start = time() - msg = {k: to_serialize(self.data[k]) for k in keys if k in self.data} - nbytes = {k: self.nbytes.get(k) for k in keys if k in self.data} + data = {k: self.data[k] for k in keys if k in self.data} + msg = {k: to_serialize(v) for k, v in data.items()} + nbytes = {k: self.nbytes.get(k) for k in data} stop = time() if self.digests is not None: self.digests['get-data-load-duration'].add(stop - start) start = time() try: compressed = yield comm.write(msg, serializers=serializers) + response = yield comm.read(deserializers=serializers) + assert response == 'OK', response except EnvironmentError: logger.exception('failed during get data with %s -> %s', self.address, who, exc_info=True) @@ -1771,8 +1774,7 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): logger.debug("Request %d keys", len(deps)) start = time() + self.scheduler_delay - response = yield self.rpc(worker).get_data(keys=deps, - who=self.address) + response = yield get_data_from_worker(self.rpc, deps, worker, self.address) stop = time() + self.scheduler_delay if cause: @@ -2714,3 +2716,30 @@ def parse_memory_limit(memory_limit, ncores): return parse_bytes(memory_limit) else: return int(memory_limit) + + +@gen.coroutine +def get_data_from_worker(rpc, keys, worker, who=None): + """ Get keys from worker + + The worker has a two step handshake to acknowledge when data has been fully + delivered. This function implements that handshake. + + See Also + -------- + Worker.get_data + Worker.gather_deps + utils_comm.gather_data_from_workers + """ + comm = yield rpc.connect(worker) + try: + response = yield send_recv(comm, + serializers=rpc.serializers, + deserializers=rpc.deserializers, + deserialize=rpc.deserialize, + op='get_data', keys=keys, who=who) + yield comm.write('OK') + finally: + rpc.reuse(worker, comm) + + raise gen.Return(response) From db758d0f8609dd0fa041b212ecc89a088b57291e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 23 Jun 2018 20:30:13 -0400 Subject: [PATCH 0009/1550] Allow adaptive to exist without a cluster (#2064) This allows the Adaptive object to exist on the scheduler and make recommendations with an external route without being attached to an explicit cluster object. For motivation see conversation in https://github.com/dask/dask-yarn/issues/1 --- distributed/deploy/adaptive.py | 82 ++++++++++++++--------- distributed/deploy/tests/test_adaptive.py | 16 ++++- 2 files changed, 66 insertions(+), 32 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index d7ae6ea5997..014373ac7a8 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -4,6 +4,7 @@ import logging import math +import toolz from tornado import gen from ..metrics import time @@ -85,7 +86,7 @@ class Adaptive(object): the cluster's ``scale_up`` method. 
''' - def __init__(self, scheduler, cluster, interval='1s', startup_cost='1s', + def __init__(self, scheduler, cluster=None, interval='1s', startup_cost='1s', scale_factor=2, minimum=0, maximum=None, wait_count=3, target_duration='5s', worker_key=lambda x: x, **kwargs): interval = parse_timedelta(interval, default='ms') @@ -94,9 +95,10 @@ def __init__(self, scheduler, cluster, interval='1s', startup_cost='1s', self.cluster = cluster self.startup_cost = parse_timedelta(startup_cost, default='s') self.scale_factor = scale_factor - self._adapt_callback = PeriodicCallback(self._adapt, interval * 1000, - io_loop=scheduler.loop) - self.scheduler.loop.add_callback(self._adapt_callback.start) + if self.cluster: + self._adapt_callback = PeriodicCallback(self._adapt, interval * 1000, + io_loop=scheduler.loop) + self.scheduler.loop.add_callback(self._adapt_callback.start) self._adapting = False self._workers_to_close_kwargs = kwargs self.minimum = minimum @@ -106,10 +108,13 @@ def __init__(self, scheduler, cluster, interval='1s', startup_cost='1s', self.wait_count = wait_count self.target_duration = parse_timedelta(target_duration) + self.scheduler.handlers['adaptive_recommendations'] = self.recommendations + def stop(self): - self._adapt_callback.stop() - self._adapt_callback = None - del self._adapt_callback + if self.cluster: + self._adapt_callback.stop() + self._adapt_callback = None + del self._adapt_callback def needs_cpu(self): """ @@ -272,27 +277,21 @@ def get_scale_up_kwargs(self): logger.info("Scaling up to %d workers", instances) return {'n': instances} - @gen.coroutine - def _adapt(self): - if self._adapting: # Semaphore to avoid overlapping adapt calls - return - - self._adapting = True - try: - should_scale_up = self.should_scale_up() - workers = set(self.workers_to_close(key=self.worker_key, - minimum=self.minimum)) - if should_scale_up and workers: - logger.info("Attempting to scale up and scale down simultaneously.") - return - - if should_scale_up: - kwargs = self.get_scale_up_kwargs() - f = self.cluster.scale_up(**kwargs) - self.log.append((time(), 'up', kwargs)) - if gen.is_future(f): - yield f - + def recommendations(self, comm=None): + should_scale_up = self.should_scale_up() + workers = set(self.workers_to_close(key=self.worker_key, + minimum=self.minimum)) + if should_scale_up and workers: + logger.info("Attempting to scale up and scale down simultaneously.") + self.close_counts.clear() + return {'status': 'error', + 'msg': 'Trying to scale up and down simultaneously'} + + elif should_scale_up: + self.close_counts.clear() + return toolz.merge({'status': 'up'}, self.get_scale_up_kwargs()) + + elif workers: d = {} to_close = [] for w, c in self.close_counts.items(): @@ -308,8 +307,31 @@ def _adapt(self): self.close_counts = d if to_close: - self.log.append((time(), 'down', workers)) - workers = yield self._retire_workers(workers=to_close) + return {'status': 'down', 'workers': to_close} + else: + self.close_counts.clear() + return None + + @gen.coroutine + def _adapt(self): + if self._adapting: # Semaphore to avoid overlapping adapt calls + return + + self._adapting = True + try: + recommendations = self.recommendations() + if not recommendations: + return + status = recommendations.pop('status') + if status == 'up': + f = self.cluster.scale_up(**recommendations) + self.log.append((time(), 'up', recommendations)) + if gen.is_future(f): + yield f + + elif status == 'down': + self.log.append((time(), 'down', recommendations['workers'])) + workers = yield 
self._retire_workers(workers=recommendations['workers']) finally: self._adapting = False diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 9b73756bcfb..3014defa74d 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -8,7 +8,7 @@ from tornado.ioloop import IOLoop from distributed import Client, wait, Adaptive, LocalCluster -from distributed.utils_test import gen_cluster, gen_test, slowinc +from distributed.utils_test import gen_cluster, gen_test, slowinc, inc from distributed.utils_test import loop, nodebug # noqa: F401 from distributed.metrics import time @@ -215,7 +215,7 @@ def test_avoid_churn(): diagnostics_port=None) client = yield Client(cluster, asynchronous=True) try: - adapt = Adaptive(cluster.scheduler, cluster, interval=20, wait_count=5) + adapt = Adaptive(cluster.scheduler, cluster, interval='20 ms', wait_count=5) for i in range(10): yield client.submit(slowinc, i, delay=0.040) @@ -392,3 +392,15 @@ def key(ws): assert names == {'a-1', 'a-2'} or names == {'b-1', 'b-2'} finally: yield cluster._close() + + +@gen_cluster(client=True, ncores=[]) +def test_without_cluster(c, s): + adapt = Adaptive(scheduler=s) + + future = c.submit(inc, 1) + while not s.tasks: + yield gen.sleep(0.01) + + response = yield c.scheduler.adaptive_recommendations() + assert response['status'] == 'up' From 53e3770e01f98ce5ae08d191a35fc2b06d8b8269 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 25 Jun 2018 09:12:18 -0400 Subject: [PATCH 0010/1550] Add test for as_completed for loops in Python 2 (#2071) --- distributed/client.py | 3 ++- distributed/compatibility.py | 3 +++ distributed/tests/test_as_completed.py | 20 ++++++++++++++++++-- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 7f474171835..72c08978ee5 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -45,7 +45,8 @@ from .utils_comm import (WrappedKey, unpack_remotedata, pack_data, scatter_to_workers, gather_from_workers) from .cfexecutor import ClientExecutor -from .compatibility import Queue as pyQueue, Empty, isqueue, html_escape +from .compatibility import (Queue as pyQueue, Empty, isqueue, html_escape, + StopAsyncIteration) from .core import connect, rpc, clean_exception, CommClosedError, PooledRPCCall from .metrics import time from .node import Node diff --git a/distributed/compatibility.py b/distributed/compatibility.py index 2bf4ceb0f4f..ef5e7040586 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -18,6 +18,8 @@ PY3 = False ConnectionRefusedError = OSError FileExistsError = OSError + class StopAsyncIteration(Exception): + pass import gzip @@ -71,6 +73,7 @@ def iscoroutinefunction(func): from gzip import compress as gzip_compress ConnectionRefusedError = ConnectionRefusedError FileExistsError = FileExistsError + StopAsyncIteration = StopAsyncIteration def isqueue(o): return isinstance(o, Queue) diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index 42e906f68f6..d9c2636a178 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -8,10 +8,9 @@ from distributed import Client from distributed.client import _as_completed, as_completed, _first_completed -from distributed.compatibility import Empty +from distributed.compatibility import Empty, StopAsyncIteration, Queue from distributed.utils_test import cluster, gen_cluster, inc from 
distributed.utils_test import loop # noqa: F401 -from distributed.compatibility import Queue @gen_cluster(client=True) @@ -152,3 +151,20 @@ def _(): result = list(ac) assert result == [x] + + +@gen_cluster(client=True) +def test_async_for_py2_equivalent(c, s, a, b): + futures = c.map(sleep, [0.01] * 3, pure=False) + seq = as_completed(futures) + x = yield seq.__anext__() + y = yield seq.__anext__() + z = yield seq.__anext__() + + assert x.done() + assert y.done() + assert z.done() + assert x.key != y.key + + with pytest.raises(StopAsyncIteration): + yield seq.__anext__() From 40e27ea577b25dc643c1166556b3ba273477b053 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 25 Jun 2018 12:29:49 -0400 Subject: [PATCH 0011/1550] support TB and PB in format bytes (#2072) --- distributed/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/distributed/utils.py b/distributed/utils.py index e10c64e6344..c3a3ac9cb62 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1081,7 +1081,15 @@ def format_bytes(n): '12.35 MB' >>> format_bytes(1234567890) '1.23 GB' + >>> format_bytes(1234567890000) + '1.23 TB' + >>> format_bytes(1234567890000000) + '1.23 PB' """ + if n > 1e15: + return '%0.2f PB' % (n / 1e15) + if n > 1e12: + return '%0.2f TB' % (n / 1e12) if n > 1e9: return '%0.2f GB' % (n / 1e9) if n > 1e6: From 0dedc514caa08d7a177a13c8cc745fd22aaabb20 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 26 Jun 2018 11:39:14 -0400 Subject: [PATCH 0012/1550] Avoid accessing Worker.scheduler_delay around yield point (#2074) If this changed during the transfer then this could cause negative durations --- distributed/worker.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index ff7628b7346..702fc7d9eb7 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1773,19 +1773,23 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): self.log.append(('request-dep', dep, worker, deps)) logger.debug("Request %d keys", len(deps)) - start = time() + self.scheduler_delay + start = time() response = yield get_data_from_worker(self.rpc, deps, worker, self.address) - stop = time() + self.scheduler_delay + stop = time() if cause: - self.startstops[cause].append(('transfer', start, stop)) + self.startstops[cause].append(( + 'transfer', + start + self.scheduler_delay, + stop + self.scheduler_delay + )) total_bytes = sum(self.nbytes.get(dep, 0) for dep in response) duration = (stop - start) or 0.5 self.incoming_transfer_log.append({ - 'start': start, - 'stop': stop, - 'middle': (start + stop) / 2.0, + 'start': start + self.scheduler_delay, + 'stop': stop + self.scheduler_delay, + 'middle': (start + stop) / 2.0 + self.scheduler_delay, 'duration': duration, 'keys': {dep: self.nbytes.get(dep, None) for dep in response}, 'total': total_bytes, From 41e1a014603c5b63b22f5c83874f8f3b914a0ae2 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Wed, 27 Jun 2018 12:01:08 -0400 Subject: [PATCH 0013/1550] Allow `name` to be explicitly passed in publish_dataset (#1995) Added to the function signature for publish_dataset. 
We can now make datasets that don't have string names --- distributed/client.py | 32 ++++++++++++++++++----- distributed/publish.py | 10 +++---- distributed/tests/test_publish.py | 43 ++++++++++++++++++++++++------- 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 72c08978ee5..ca8123e1e34 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1799,19 +1799,34 @@ def cancel(self, futures, asynchronous=None, force=False): force=force) @gen.coroutine - def _publish_dataset(self, **kwargs): + def _publish_dataset(self, *args, **kwargs): with log_errors(): coroutines = [] - for name, data in kwargs.items(): + + def add_coro(name, data): keys = [tokey(f.key) for f in futures_of(data)] coroutines.append(self.scheduler.publish_put(keys=keys, - name=tokey(name), + name=name, data=to_serialize(data), client=self.id)) + name = kwargs.pop('name', None) + if name: + if len(args) == 0: + raise ValueError( + "If name is provided, expecting call signature like" + " publish_dataset(df, name='ds')") + # in case this is a singleton, collapse it + elif len(args) == 1: + args = args[0] + add_coro(name, args) + + for name, data in kwargs.items(): + add_coro(name, data) + yield coroutines - def publish_dataset(self, **kwargs): + def publish_dataset(self, *args, **kwargs): """ Publish named datasets to scheduler @@ -1824,6 +1839,8 @@ def publish_dataset(self, **kwargs): Parameters ---------- + args : list of objects to publish as name + name : optional name of the dataset to publish kwargs: dict named collections to publish on the scheduler @@ -1835,6 +1852,9 @@ def publish_dataset(self, **kwargs): >>> df = c.persist(df) # doctest: +SKIP >>> c.publish_dataset(my_dataset=df) # doctest: +SKIP + Alternative invocation + >>> c.publish_dataset(df, name='my_dataset') + Receiving client: >>> c.list_datasets() # doctest: +SKIP @@ -1852,7 +1872,7 @@ def publish_dataset(self, **kwargs): Client.unpublish_dataset Client.persist """ - return self.sync(self._publish_dataset, **kwargs) + return self.sync(self._publish_dataset, *args, **kwargs) def unpublish_dataset(self, name, **kwargs): """ @@ -1902,7 +1922,7 @@ def get_dataset(self, name, **kwargs): Client.publish_dataset Client.list_datasets """ - return self.sync(self._get_dataset, tokey(name), **kwargs) + return self.sync(self._get_dataset, name, **kwargs) @gen.coroutine def _run_on_scheduler(self, function, *args, **kwargs): diff --git a/distributed/publish.py b/distributed/publish.py index a275cfeff08..3260e99b0e4 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -1,5 +1,5 @@ from collections import MutableMapping -from distributed.utils import log_errors +from distributed.utils import log_errors, tokey class PublishExtension(object): @@ -27,18 +27,18 @@ def put(self, stream=None, keys=None, data=None, name=None, client=None): with log_errors(): if name in self.datasets: raise KeyError("Dataset %s already exists" % name) - self.scheduler.client_desires_keys(keys, 'published-%s' % name) + self.scheduler.client_desires_keys(keys, 'published-%s' % tokey(name)) self.datasets[name] = {'data': data, 'keys': keys} return {'status': 'OK', 'name': name} def delete(self, stream=None, name=None): with log_errors(): out = self.datasets.pop(name, {'keys': []}) - self.scheduler.client_releases_keys(out['keys'], 'published-%s' % name) + self.scheduler.client_releases_keys(out['keys'], 'published-%s' % tokey(name)) def list(self, *args): with log_errors(): - return 
list(sorted(self.datasets.keys())) + return list(sorted(self.datasets.keys(), key=str)) def get(self, stream, name=None, client=None): with log_errors(): @@ -60,7 +60,7 @@ def __getitem__(self, key): return self.__client.get_dataset(key) def __setitem__(self, key, value): - self.__client.publish_dataset(**{key: value}) + self.__client.publish_dataset(value, name=key) def __delitem__(self, key): self.__client.unpublish_dataset(key) diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index a331b7b957a..a67bfce9887 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -37,6 +37,26 @@ def test_publish_simple(s, a, b): yield f.close() +@gen_cluster(client=False) +def test_publish_non_string_key(s, a, b): + c = yield Client((s.ip, s.port), asynchronous=True) + f = yield Client((s.ip, s.port), asynchronous=True) + + try: + for name in [('a', 'b'), 9.0, 8]: + data = yield c.scatter(range(3)) + out = yield c.publish_dataset(data, name=name) + assert name in s.extensions['publish'].datasets + assert isinstance(s.extensions['publish'].datasets[name]['data'], Serialized) + + datasets = yield c.scheduler.publish_list() + assert name in datasets + + finally: + c.close() + f.close() + + @gen_cluster(client=False) def test_publish_roundtrip(s, a, b): c = yield Client((s.ip, s.port), asynchronous=True) @@ -167,26 +187,29 @@ def test_publish_bag(s, a, b): def test_datasets_setitem(loop): with cluster() as (s, _): with Client(s['address'], loop=loop) as client: - key, value = 'key', 'value' - client.datasets[key] = value - assert client.get_dataset('key') == value + for key in ['key', ('key', 'key'), 1]: + value = 'value' + client.datasets[key] = value + assert client.get_dataset(key) == value def test_datasets_getitem(loop): with cluster() as (s, _): with Client(s['address'], loop=loop) as client: - key, value = 'key', 'value' - client.publish_dataset(key=value) - assert client.datasets[key] == value + for key in ['key', ('key', 'key'), 1]: + value = 'value' + client.publish_dataset(value, name=key) + assert client.datasets[key] == value def test_datasets_delitem(loop): with cluster() as (s, _): with Client(s['address'], loop=loop) as client: - key, value = 'key', 'value' - client.publish_dataset(key=value) - del client.datasets[key] - assert key not in client.list_datasets() + for key in ['key', ('key', 'key'), 1]: + value = 'value' + client.publish_dataset(value, name=key) + del client.datasets[key] + assert key not in client.list_datasets() def test_datasets_keys(loop): From a352a406ea7ad7cc4a8dab83c096bfbd04491d54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 27 Jun 2018 20:32:23 +0200 Subject: [PATCH 0014/1550] Install msgpack-python with conda on AppVeyor. (#2075) This is a cleanup of the mspack-python to msgpack renaming on PyPI. 
--- continuous_integration/setup_conda_environment.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index 4f2674dd65b..f03441c336e 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -30,6 +30,7 @@ call deactivate joblib ^ jupyter_client ^ mock ^ + msgpack-python ^ psutil ^ pytest ^ python=%PYTHON% ^ @@ -48,7 +49,6 @@ call activate %CONDA_ENV% %PIP_INSTALL% git+https://github.com/dask/zict --upgrade %PIP_INSTALL% pytest-repeat pytest-timeout pytest-faulthandler sortedcollections -%PIP_INSTALL% msgpack @rem Display final environment (for reproducing) %CONDA% list From 1167cd0d7c6887645ce7fbf4d0e13fd3a01ebf7f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 28 Jun 2018 14:35:13 -0400 Subject: [PATCH 0015/1550] Provide communication context to serialization functions (#2054) * collect dask-specific type-based serialization to bottom of serialize.py * Generalize has_keyword function and move to utils.py * Provide context to serialization functions This enables comms to provide a context of information to serialization functions *if* they provide a ``context=`` keyword for it. --- distributed/comm/tcp.py | 4 +- distributed/comm/utils.py | 5 +- distributed/core.py | 17 +- distributed/node.py | 7 +- distributed/protocol/__init__.py | 1 + distributed/protocol/core.py | 5 +- distributed/protocol/serialize.py | 216 ++++++++++--------- distributed/protocol/tests/test_serialize.py | 93 +++++++- distributed/utils.py | 11 + docs/source/serialization.rst | 32 ++- 10 files changed, 267 insertions(+), 124 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 9a8fe7bf087..d1dcab7569c 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -216,7 +216,9 @@ def write(self, msg, serializers=None, on_error='message'): frames = yield to_frames(msg, serializers=serializers, - on_error=on_error) + on_error=on_error, + context={'sender': self._local_addr, + 'recipient': self._peer_addr}) try: lengths = ([struct.pack('Q', len(frames))] + diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 32347e3952b..02677b9faba 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -29,7 +29,7 @@ def offload(fn, *args, **kwargs): @gen.coroutine -def to_frames(msg, serializers=None, on_error='message'): +def to_frames(msg, serializers=None, on_error='message', context=None): """ Serialize a message into a list of Distributed protocol frames. 
""" @@ -37,7 +37,8 @@ def _to_frames(): try: return list(protocol.dumps(msg, serializers=serializers, - on_error=on_error)) + on_error=on_error, + context=context)) except Exception as e: logger.info("Unserializable Message: %s", msg) logger.exception(e) diff --git a/distributed/core.py b/distributed/core.py index 7bd6c12960c..337c7a626ef 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -3,7 +3,6 @@ from collections import defaultdict, deque from concurrent.futures import CancelledError from functools import partial -import inspect import logging import six import traceback @@ -17,14 +16,14 @@ from tornado.ioloop import IOLoop from tornado.locks import Event -from .compatibility import PY3, get_thread_identity +from .compatibility import get_thread_identity from .comm import (connect, listen, CommClosedError, normalize_address, unparse_host_port, get_address_host_port) from .metrics import time from .system_monitor import SystemMonitor from .utils import (get_traceback, truncate_exception, ignoring, shutting_down, - PeriodicCallback, parse_timedelta) + PeriodicCallback, parse_timedelta, has_keyword) from . import protocol @@ -310,7 +309,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): logger.warning("No handler %s found in %s", op, type(self).__name__, exc_info=True) else: - if serializers is not None and has_serializers_keyword(handler): + if serializers is not None and has_keyword(handler, 'serializers'): msg['serializers'] = serializers # add back in logger.debug("Calling into handler %s", handler.__name__) @@ -852,13 +851,3 @@ def clean_exception(exception, traceback, **kwargs): elif isinstance(traceback, string_types): traceback = None # happens if the traceback failed serializing return type(exception), exception, traceback - - -def has_serializers_keyword(func): - if PY3: - return 'serializers' in inspect.signature(func).parameters - else: - # https://stackoverflow.com/questions/50100498/determine-keywords-of-a-tornado-coroutine - if gen.is_coroutine_function(func): - func = func.__wrapped__ - return 'serializers' in inspect.getargspec(func).args diff --git a/distributed/node.py b/distributed/node.py index e7fe00484b6..8373c07709c 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -33,11 +33,14 @@ class ServerNode(Node, Server): def __init__(self, handlers=None, stream_handlers=None, connection_limit=512, deserialize=True, - connection_args=None, io_loop=None): + connection_args=None, io_loop=None, serializers=None, + deserializers=None): Node.__init__(self, deserialize=deserialize, connection_limit=connection_limit, connection_args=connection_args, - io_loop=io_loop) + io_loop=io_loop, + serializers=serializers, + deserializers=deserializers) Server.__init__(self, handlers=handlers, stream_handlers=stream_handlers, connection_limit=connection_limit, diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 38c8ce05d95..a6a9afaf324 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -8,6 +8,7 @@ serialize, deserialize, nested_deserialize, Serialize, Serialized, to_serialize, register_serialization, register_serialization_lazy, serialize_bytes, deserialize_bytes, serialize_bytelist, + register_serialization_family, ) from ..utils import ignoring diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 9209aa06184..4033c9be1a9 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -def 
dumps(msg, serializers=None, on_error='message'): +def dumps(msg, serializers=None, on_error='message', context=None): """ Transform Python message to bytestream suitable for communication """ try: data = {} @@ -40,7 +40,8 @@ def dumps(msg, serializers=None, on_error='message'): data = {key: serialize(value.data, serializers=serializers, - on_error=on_error) + on_error=on_error, + context=context) for key, value in data.items() if type(value) is Serialize} diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 773b08d7887..c2a1274afe2 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -12,6 +12,7 @@ from . import pickle from ..compatibility import PY2 +from ..utils import has_keyword from .compression import maybe_compress, decompress from .utils import unpack_frames, pack_frames_prelude, frame_split_size @@ -21,12 +22,15 @@ lazy_registrations = {} -def dask_dumps(x): +def dask_dumps(x, context=None): """Serialise object using the class-based registry""" typ = typename(type(x)) if typ in class_serializers: - dumps, loads = class_serializers[typ] - header, frames = dumps(x) + dumps, loads, has_context = class_serializers[typ] + if has_context: + header, frames = dumps(x, context=context) + else: + header, frames = dumps(x) header['type'] = typ header['serializer'] = 'dask' return header, frames @@ -43,7 +47,7 @@ def dask_loads(header, frames): _find_lazy_registration(typ) try: - dumps, loads = class_serializers[typ] + dumps, loads, _ = class_serializers[typ] except KeyError: raise TypeError("Serialization for type %s not found" % typ) else: @@ -71,82 +75,20 @@ def serialization_error_loads(header, frames): raise TypeError(msg) -families = { - 'dask': (dask_dumps, dask_loads), - 'pickle': (pickle_dumps, pickle_loads), - 'msgpack': (msgpack_dumps, msgpack_loads), - 'error': (None, serialization_error_loads), -} +families = {} -def register_serialization(cls, serialize, deserialize): - """ Register a new class for dask-custom serialization +def register_serialization_family(name, dumps, loads): + families[name] = (dumps, loads, dumps and has_keyword(dumps, 'context')) - Parameters - ---------- - cls: type - serialize: function - deserialize: function - Examples - -------- - >>> class Human(object): - ... def __init__(self, name): - ... self.name = name - - >>> def serialize(human): - ... header = {} - ... frames = [human.name.encode()] - ... return header, frames +register_serialization_family('dask', dask_dumps, dask_loads) +register_serialization_family('pickle', pickle_dumps, pickle_loads) +register_serialization_family('msgpack', msgpack_dumps, msgpack_loads) +register_serialization_family('error', None, serialization_error_loads) - >>> def deserialize(header, frames): - ... return Human(frames[0].decode()) - - >>> register_serialization(Human, serialize, deserialize) - >>> serialize(Human('Alice')) - ({}, [b'Alice']) - - See Also - -------- - serialize - deserialize - """ - if isinstance(cls, type): - name = typename(cls) - elif isinstance(cls, str): - name = cls - class_serializers[name] = (serialize, deserialize) - - -def register_serialization_lazy(toplevel, func): - """Register a registration function to be called if *toplevel* - module is ever loaded. - """ - lazy_registrations[toplevel] = func - - -def typename(typ): - """ Return name of type - Examples - -------- - >>> from distributed import Scheduler - >>> typename(Scheduler) - 'distributed.scheduler.Scheduler' - """ - return typ.__module__ + '.' 
+ typ.__name__ - - -def _find_lazy_registration(typename): - toplevel, _, _ = typename.partition('.') - if toplevel in lazy_registrations: - lazy_registrations.pop(toplevel)() - return True - else: - return False - - -def serialize(x, serializers=None, on_error='message'): +def serialize(x, serializers=None, on_error='message', context=None): r""" Convert object to a header and list of bytestrings @@ -191,9 +133,9 @@ def serialize(x, serializers=None, on_error='message'): tb = '' for name in serializers: - dumps, loads = families[name] + dumps, loads, wants_context = families[name] try: - header, frames = dumps(x) + header, frames = dumps(x, context=context) if wants_context else dumps(x) header['serializer'] = name return header, frames except NotImplementedError: @@ -232,7 +174,7 @@ def deserialize(header, frames, deserializers=None): if deserializers is not None and name not in deserializers: raise TypeError("Data serialized with %s but only able to deserialize " "data with %s" % (name, str(list(deserializers)))) - dumps, loads = families[name] + dumps, loads, wants_context = families[name] return loads(header, frames) @@ -394,29 +336,6 @@ def replace_inner(x): return replace_inner(x) -@partial(normalize_token.register, Serialized) -def normalize_Serialized(o): - return [o.header] + o.frames # for dask.base.tokenize - - -# Teach serialize how to handle bytestrings -def _serialize_bytes(obj): - header = {} # no special metadata - frames = [obj] - return header, frames - - -def _deserialize_bytes(header, frames): - return frames[0] - - -# NOTE: using the same exact serialization means a bytes object may be -# deserialized as bytearray or vice-versa... Not sure this is a problem -# in practice. -register_serialization(bytes, _serialize_bytes, _deserialize_bytes) -register_serialization(bytearray, _serialize_bytes, _deserialize_bytes) - - def serialize_bytelist(x, **kwargs): header, frames = serialize(x, **kwargs) frames = frame_split_size(frames) @@ -448,3 +367,100 @@ def deserialize_bytes(b): header = {} frames = decompress(header, frames) return deserialize(header, frames) + + +################################ +# Class specific serialization # +################################ + + +def register_serialization(cls, serialize, deserialize): + """ Register a new class for dask-custom serialization + + Parameters + ---------- + cls: type + serialize: function + deserialize: function + + Examples + -------- + >>> class Human(object): + ... def __init__(self, name): + ... self.name = name + + >>> def serialize(human): + ... header = {} + ... frames = [human.name.encode()] + ... return header, frames + + >>> def deserialize(header, frames): + ... return Human(frames[0].decode()) + + >>> register_serialization(Human, serialize, deserialize) + >>> serialize(Human('Alice')) + ({}, [b'Alice']) + + See Also + -------- + serialize + deserialize + """ + if isinstance(cls, type): + name = typename(cls) + elif isinstance(cls, str): + name = cls + class_serializers[name] = (serialize, + deserialize, + has_keyword(serialize, 'context')) + + +def register_serialization_lazy(toplevel, func): + """Register a registration function to be called if *toplevel* + module is ever loaded. + """ + lazy_registrations[toplevel] = func + + +def typename(typ): + """ Return name of type + + Examples + -------- + >>> from distributed import Scheduler + >>> typename(Scheduler) + 'distributed.scheduler.Scheduler' + """ + return typ.__module__ + '.' 
+ typ.__name__ + + +def _find_lazy_registration(typename): + toplevel, _, _ = typename.partition('.') + if toplevel in lazy_registrations: + lazy_registrations.pop(toplevel)() + return True + else: + return False + + +@partial(normalize_token.register, Serialized) +def normalize_Serialized(o): + return [o.header] + o.frames # for dask.base.tokenize + + +# Teach serialize how to handle bytestrings +def _serialize_bytes(obj): + header = {} # no special metadata + frames = [obj] + return header, frames + + +def _deserialize_bytes(header, frames): + return frames[0] + + +# NOTE: using the same exact serialization means a bytes object may be +# deserialized as bytearray or vice-versa... Not sure this is a problem +# in practice. +register_serialization(bytes, _serialize_bytes, _deserialize_bytes) +register_serialization(bytearray, _serialize_bytes, _deserialize_bytes) diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index a1cedf5f2f7..dc7377385ea 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -3,14 +3,17 @@ import copy import pickle +import msgpack import numpy as np import pytest from toolz import identity +from distributed import wait from distributed.protocol import (register_serialization, serialize, deserialize, nested_deserialize, Serialize, Serialized, to_serialize, serialize_bytes, - deserialize_bytes, serialize_bytelist,) + deserialize_bytes, serialize_bytelist, + register_serialization_family) from distributed.utils import nbytes from distributed.utils_test import inc, gen_test from distributed.comm.utils import to_frames, from_frames @@ -245,3 +248,91 @@ def test_err_on_bad_deserializer(): with pytest.raises(TypeError) as info: yield from_frames(frames, deserializers=['msgpack']) + + +class MyObject(object): + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +def my_dumps(obj, context=None): + if type(obj).__name__ == 'MyObject': + header = {'serializer': 'my-ser'} + frames = [msgpack.dumps(obj.__dict__, use_bin_type=True), + msgpack.dumps(context, use_bin_type=True)] + return header, frames + else: + raise NotImplementedError() + + +def my_loads(header, frames): + obj = MyObject(**msgpack.loads(frames[0], encoding='utf8')) + + # to provide something to test against, lets just attach the context to + # the object itself + obj.context = msgpack.loads(frames[1], encoding='utf8') + return obj + + +@gen_cluster(client=True, + client_kwargs={'serializers': ['my-ser', 'pickle']}, + worker_kwargs={'serializers': ['my-ser', 'pickle']}) +def test_context_specific_serialization(c, s, a, b): + register_serialization_family('my-ser', my_dumps, my_loads) + + try: + # Create the object on A, force communication to B + x = c.submit(MyObject, x=1, y=2, workers=a.address) + y = c.submit(lambda x: x, x, workers=b.address) + + yield wait(y) + + key = y.key + + def check(dask_worker): + # Get the context from the object stored on B + my_obj = dask_worker.data[key] + return my_obj.context + + result = yield c.run(check, workers=[b.address]) + expected = {'sender': a.address, 'recipient': b.address} + assert result[b.address]['sender'] == a.address # see origin worker + + z = yield y # bring object to local process + + assert z.x == 1 and z.y == 2 + assert z.context['sender'] == b.address + finally: + from distributed.protocol.serialize import families + del families['my-ser'] + + +@gen_cluster(client=True) +def test_context_specific_serialization_class(c, s, a, b): + 
register_serialization(MyObject, my_dumps, my_loads) + + try: + # Create the object on A, force communication to B + x = c.submit(MyObject, x=1, y=2, workers=a.address) + y = c.submit(lambda x: x, x, workers=b.address) + + yield wait(y) + + key = y.key + + def check(dask_worker): + # Get the context from the object stored on B + my_obj = dask_worker.data[key] + return my_obj.context + + result = yield c.run(check, workers=[b.address]) + expected = {'sender': a.address, 'recipient': b.address} + assert result[b.address]['sender'] == a.address # see origin worker + + z = yield y # bring object to local process + + assert z.x == 1 and z.y == 2 + assert z.context['sender'] == b.address + finally: + from distributed.protocol.serialize import class_serializers, typename + del class_serializers[typename(MyObject)] diff --git a/distributed/utils.py b/distributed/utils.py index c3a3ac9cb62..53495ba60b4 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from datetime import timedelta import functools +import inspect import json import logging import multiprocessing @@ -1385,3 +1386,13 @@ def reset_logger_locks(): # Only bother if asyncio has been loaded by Tornado if 'asyncio' in sys.modules: fix_asyncio_event_loop_policy(sys.modules['asyncio']) + + +def has_keyword(func, keyword): + if PY3: + return keyword in inspect.signature(func).parameters + else: + # https://stackoverflow.com/questions/50100498/determine-keywords-of-a-tornado-coroutine + if gen.is_coroutine_function(func): + func = func.__wrapped__ + return keyword in inspect.getargspec(func).args diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 01eb1044dd6..e457681a662 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -86,13 +86,41 @@ dictionary with an appropriate name. Here is the definition of frame = frames[0] return pickle.loads(frame) - from distributed.protocol.serialize import families - families['pickle'] = (pickle_dumps, pickle_loads) + from distributed.protocol.serialize import register_serialization_family + register_serialization_family('pickle', pickle_dumps, pickle_loads) After this the name ``'pickle'`` can be used in the ``serializers=`` and ``deserializers=`` keywords in ``Client`` and other parts of Dask. +Communication Context ++++++++++++++++++++++ + +.. note:: This is an experimental feature and may change without notice + +Dask :doc:`Comms ` can provide additional context to +serialization family functions if they provide a ``context=`` keyword. +This allows serialization to behave differently according to how it is being +used. + +.. code-block:: python + + def my_dumps(x, context=None): + if context and 'recipient' in context: + # check if we're sending to the same host or not + +The context depends on the kind of communication. For example when sending +over TCP, the address of the sender (us) and the recipient are available in a +dictionary. + +.. code-block:: python + + >>> context + {'sender': 'tcp://127.0.0.1:1234', 'recipient': 'tcp://127.0.0.1:5678'} + +Other comms may provide other information. 
+ + Dask Serialization Family ------------------------- From 94f076c92bd3e2ab9340a0876e10eab5b278bbd6 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 1 Jul 2018 20:28:19 -0400 Subject: [PATCH 0016/1550] Use default pygments styling --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 3bb5a13ae7c..8ba681f0a73 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -104,7 +104,7 @@ #show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = 'default' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] From f7c8f339412764367bcd42e01347a80fb25b9822 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Mon, 2 Jul 2018 17:29:32 +0200 Subject: [PATCH 0017/1550] Fix typo in docstring (#2087) --- distributed/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index ca8123e1e34..b066d5cc05d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1165,7 +1165,7 @@ def submit(self, func, *args, **kwargs): key: str Unique identifier for the task. Defaults to function-name and hash allow_other_workers: bool (defaults to False) - Used with `workers`. Inidicates whether or not the computations + Used with `workers`. Indicates whether or not the computations may be performed on workers that are not in the `workers` set(s). retries: int (default to 0) Number of allowed automatic retries if the task fails From fc1312facff8df81f766562fb069ea9865ebe10d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 2 Jul 2018 17:47:01 -0400 Subject: [PATCH 0018/1550] use https to get stylesheet in docs --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8ba681f0a73..3869ec0367b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -385,4 +385,4 @@ } def setup(app): - app.add_stylesheet("http://dask.pydata.org/en/latest/_static/style.css") + app.add_stylesheet("https://dask.pydata.org/en/latest/_static/style.css") From e9d527ab7cf840e9ff40d2f9c3833f3999a198e9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 4 Jul 2018 07:54:15 -0400 Subject: [PATCH 0019/1550] Handle exceptions on deserialized comm with text error (#2093) --- distributed/core.py | 14 ++++++++++---- distributed/tests/test_core.py | 16 +++++++++++++++- distributed/worker.py | 1 - 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 337c7a626ef..05b901cf818 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -408,8 +408,8 @@ def pingpong(comm): @gen.coroutine -def send_recv(comm, reply=True, deserialize=True, serializers=None, - deserializers=None, **kwargs): +def send_recv(comm, reply=True, serializers=None, deserializers=None, + **kwargs): """ Send and recv with a Comm. 
Keyword arguments turn into the message @@ -442,7 +442,10 @@ def send_recv(comm, reply=True, deserialize=True, serializers=None, comm.abort() if isinstance(response, dict) and response.get('status') == 'uncaught-error': - six.reraise(*clean_exception(**response)) + if comm.deserialize: + six.reraise(*clean_exception(**response)) + else: + raise Exception(response['text']) raise gen.Return(response) @@ -834,7 +837,10 @@ def error_message(e, status='error'): else: tb_result = protocol.to_serialize(tb) - return {'status': status, 'exception': e4, 'traceback': tb_result} + return {'status': status, + 'exception': e4, + 'traceback': tb_result, + 'text': str(e2)} def clean_exception(exception, traceback, **kwargs): diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index bf7fa2aa42c..ce67817e6a6 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -23,7 +23,7 @@ assert_can_connect_from_everywhere_4, assert_can_connect_from_everywhere_4_6, assert_can_connect_from_everywhere_6, assert_can_connect_locally_4, assert_can_connect_locally_6, - tls_security, captured_logger, inc) + tls_security, captured_logger, inc, throws) from distributed.utils_test import loop # noqa F401 @@ -658,3 +658,17 @@ def f(): @gen_cluster() def test_thread_id(s, a, b): assert s.thread_id == a.thread_id == b.thread_id == get_thread_identity() + + +@gen_test() +def test_deserialize_error(): + server = Server({'throws': throws}) + server.listen(0) + + comm = yield connect(server.address, deserialize=False) + with pytest.raises(Exception) as info: + yield send_recv(comm, op='throws') + + assert type(info.value) == Exception + for c in str(info.value): + assert c.isalpha() or c in "(',!)" # no crazy bytestrings diff --git a/distributed/worker.py b/distributed/worker.py index 702fc7d9eb7..ab4f1d81570 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2740,7 +2740,6 @@ def get_data_from_worker(rpc, keys, worker, who=None): response = yield send_recv(comm, serializers=rpc.serializers, deserializers=rpc.deserializers, - deserialize=rpc.deserialize, op='get_data', keys=keys, who=who) yield comm.write('OK') finally: From b71e8250b475216a58154d449d4305cc2f1f7f6a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 4 Jul 2018 07:55:01 -0400 Subject: [PATCH 0020/1550] Prefer gathering data from same host (#2090) * Prefer gathering data from same host * add who= to all get_data calls * relax test_retire_many_workers --- distributed/tests/test_client.py | 4 +++- distributed/tests/test_worker.py | 12 ++++++++++++ distributed/utils_comm.py | 5 +++-- distributed/worker.py | 12 +++++++++--- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index e015db9f99c..75558352111 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4082,8 +4082,10 @@ def test_retire_many_workers(c, s, *workers): assert results == list(range(100)) assert len(s.has_what) == len(s.ncores) == 3 + assert all(future.done() for future in futures) + assert all(s.tasks[future.key].state == 'memory' for future in futures) for w, keys in s.has_what.items(): - assert 20 < len(keys) < 50 + assert 15 < len(keys) < 50 @gen_cluster(client=True, diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index e8878029e8e..185278e7e94 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1180,3 +1180,15 @@ def test_wait_for_outgoing(c, s, 
a, b): ratio = aa / bb assert 1 / 3 < ratio < 3 + + +@gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 1), ('127.0.0.2', 1)], + client=True) +def test_prefer_gather_from_local_address(c, s, w1, w2, w3): + x = yield c.scatter(123, workers=[w1.address, w3.address], broadcast=True) + + y = c.submit(inc, x, workers=[w2.address]) + yield wait(y) + + assert any(d['who'] == w2.address for d in w1.outgoing_transfer_log) + assert not any(d['who'] == w2.address for d in w3.outgoing_transfer_log) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 8e4d2ac1300..43dbe49fb0c 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -14,7 +14,7 @@ @gen.coroutine -def gather_from_workers(who_has, rpc, close=True, serializers=None): +def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): """ Gather data directly from peers Parameters @@ -56,7 +56,8 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None): rpcs = {addr: rpc(addr) for addr in d} try: - coroutines = {address: get_data_from_worker(rpc, keys, address) + coroutines = {address: get_data_from_worker(rpc, keys, address, + who=who) for address, keys in d.items()} response = {} for worker, c in coroutines.items(): diff --git a/distributed/worker.py b/distributed/worker.py index ab4f1d81570..df8cd6ad8a9 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -704,7 +704,7 @@ def gather(self, comm=None, who_has=None): for k, v in who_has.items() if k not in self.data} result, missing_keys, missing_workers = yield gather_from_workers( - who_has, rpc=self.rpc) + who_has, rpc=self.rpc, who=self.address) if missing_keys: logger.warning("Could not find data: %s on workers: %s (who_has: %s)", missing_keys, missing_workers, who_has) @@ -1662,7 +1662,12 @@ def ensure_communicating(self): if not workers: in_flight = True continue - worker = random.choice(list(workers)) + host = get_address_host(self.address) + local = [w for w in workers if get_address_host(w) == host] + if local: + worker = random.choice(local) + else: + worker = random.choice(list(workers)) to_gather, total_nbytes = self.select_keys_for_gather(worker, dep) self.comm_nbytes += total_nbytes self.in_flight_workers[worker] = to_gather @@ -1774,7 +1779,8 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): logger.debug("Request %d keys", len(deps)) start = time() - response = yield get_data_from_worker(self.rpc, deps, worker, self.address) + response = yield get_data_from_worker(self.rpc, deps, worker, + who=self.address) stop = time() if cause: From 89992d0403287b6b29398474af166ca43b0717c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Ponte?= Date: Wed, 4 Jul 2018 14:25:34 +0200 Subject: [PATCH 0021/1550] Adjust worker doc after change in config file location and treatment (#2094) --- docs/source/worker.rst | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/source/worker.rst b/docs/source/worker.rst index ddaf5c334c4..deaa0243913 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -156,16 +156,19 @@ Workers use a few different policies to keep memory use beneath this limit: 3. At 80% of memory load, stop accepting new work on local thread pool 4. At 95% of memory load, terminate and restart the worker -These values can be configured by modifying the ``~/.dask/config.yaml`` file +These values can be configured by modifying the ``~/.config/dask/distributed.yaml`` file .. 
code-block:: yaml - # Fractions of worker memory at which we take action to avoid memory blowup - # Set any of the lower three values to False to turn off the behavior entirely - worker-memory-target: 0.60 # target fraction to stay below - worker-memory-spill: 0.70 # fraction at which we spill to disk - worker-memory-pause: 0.80 # fraction at which we pause worker threads - worker-memory-terminate: 0.95 # fraction at which we terminate the worker + distributed: + worker: + # Fractions of worker memory at which we take action to avoid memory blowup + # Set any of the lower three values to False to turn off the behavior entirely + memory: + target: 0.60 # target fraction to stay below + spill: 0.70 # fraction at which we spill to disk + pause: 0.80 # fraction at which we pause worker threads + terminate: 0.95 # fraction at which we terminate the worker Spill data to Disk From 5bbbb3c4370ed05ba226925f47b023960966987e Mon Sep 17 00:00:00 2001 From: Bartosz Marcinkowski Date: Thu, 5 Jul 2018 16:48:06 +0200 Subject: [PATCH 0022/1550] removed hardcoded value of memory terminate fraction from a log message (#2096) --- distributed/nanny.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index afe02431438..6b0c0ec9620 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -249,7 +249,7 @@ def _(): raise gen.Return('OK') def memory_monitor(self): - """ Track worker's memory. Restart if it goes above 95% """ + """ Track worker's memory. Restart if it goes above terminate fraction """ if self.status != 'running': return process = self.process.process @@ -262,7 +262,8 @@ def memory_monitor(self): memory = proc.memory_info().rss frac = memory / self.memory_limit if self.memory_terminate_fraction and frac > self.memory_terminate_fraction: - logger.warning("Worker exceeded 95% memory budget. Restarting") + logger.warning("Worker exceeded %d%% memory budget. Restarting", + 100 * self.memory_terminate_fraction) process.terminate() def is_alive(self): From 1c1c72e79ad47c3209c87ba1e66806d8c1fbb6c2 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Thu, 5 Jul 2018 12:38:18 -0400 Subject: [PATCH 0023/1550] Update example for stopping a worker (#2088) The `remove_worker` syntax is outdated. Update it to use `stop_worker`, which is the current syntax. --- distributed/deploy/local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index b88966a2a02..be063366b6d 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -70,7 +70,7 @@ class LocalCluster(Cluster): Shut down the extra worker - >>> c.remove_worker(w) # doctest: +SKIP + >>> c.stop_worker(w) # doctest: +SKIP Pass extra keyword arguments to Bokeh From 95d55ab2ca2b050b45c08737d902386775bc0940 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 7 Jul 2018 14:08:43 -0400 Subject: [PATCH 0024/1550] Don't forget released keys (#2098) Previously we would allow forgetting keys if a dependency of a forgotten key had no active waiting tasks. Now we properly check dependents, not active waiters. 
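To make the waiter/dependent distinction concrete, here is a rough client-level
sketch of the scenario the new test exercises (a local client and the trivial
inc/dec helpers are assumed purely for illustration):

    from dask.distributed import Client, wait

    def inc(i):
        return i + 1

    def dec(i):
        return i - 1

    client = Client()                    # illustrative local cluster
    x = client.submit(inc, 1, key='x')
    y = client.submit(inc, x, key='y')   # y depends on x
    z = client.submit(dec, x, key='z')   # z also depends on x
    del x                                # the client no longer wants x directly
    wait([y, z])                         # both dependents finish, so x has no active waiters
    del z                                # z is released and eventually forgotten

    # Previously the scheduler could also forget x at this point because it had
    # no waiters; it must survive because y is still among its dependents.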
--- distributed/scheduler.py | 9 +++++---- distributed/tests/test_scheduler.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 33f8263b93d..bbcf0dad08d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3456,7 +3456,8 @@ def transition_processing_released(self, key): recommendations[key] = 'forgotten' elif ts.waiters or ts.who_wants: recommendations[key] = 'waiting' - else: + + if recommendations.get(key) != 'waiting': for dts in ts.dependencies: if dts.state != 'released': s = dts.waiters @@ -3590,7 +3591,7 @@ def _propagate_forgotten(self, ts, recommendations): dts.dependents.remove(ts) s = dts.waiters s.discard(ts) - if not s and not dts.who_wants: + if not dts.dependents and not dts.who_wants: # Task not needed anymore assert dts is not ts recommendations[dts.key] = 'forgotten' @@ -3621,7 +3622,7 @@ def transition_memory_forgotten(self, key): elif ts.has_lost_dependencies: # It's ok to forget a task with forgotten dependencies pass - elif not ts.who_wants and not ts.waiters: + elif not ts.who_wants and not ts.waiters and not ts.dependents: # It's ok to forget a task that nobody needs pass else: @@ -3656,7 +3657,7 @@ def transition_released_forgotten(self, key): elif ts.has_lost_dependencies: # It's ok to forget a task with forgotten dependencies pass - elif not ts.who_wants and not ts.waiters: + elif not ts.who_wants and not ts.waiters and not ts.dependents: # It's ok to forget a task that nobody needs pass else: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4b0cfca25ae..4f6bea05e1c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1300,6 +1300,21 @@ def test_dont_recompute_if_persisted_4(c, s, a, b): assert len(new) > len(old) +@gen_cluster(client=True) +def test_dont_forget_released_keys(c, s, a, b): + x = c.submit(inc, 1, key='x') + y = c.submit(inc, x, key='y') + z = c.submit(dec, x, key='z') + del x + yield wait([y, z]) + del z + + while 'z' in s.tasks: + yield gen.sleep(0.01) + + assert 'x' in s.tasks + + @gen_cluster(client=True) def test_dont_recompute_if_erred(c, s, a, b): x = delayed(inc)(1, dask_key_name='x') From 7a9fa83266cb05382094f45f204d12a69e904144 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 8 Jul 2018 12:17:55 -0400 Subject: [PATCH 0025/1550] Allow worker to refuse data requests with busy signal (#2092) This allows workers to say "I'm too busy right now" when presented with a request for data from another worker. That worker then waits a bit, queries the scheduler to see if anyone else has that data, and then tries again. The wait time is an exponential backoff. Pragmatically this means that when single pieces of data are in high demand that the cluster will informally do a tree scattering. Some workers will get the data directly while others wait on the busy signal. Then other workers will get from them, etc.. We used to ask users to do this explicitly with the following: client.replicate(future) or client.scatter(data, broadcast=True) And now the replicate/broadcast step is no longer strictly necessary. (though some scattering of local data still is). Machines on the same host are given some preference, and so should be able to sneak in more easily. 
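On the requesting side this amounts to a retry loop with exponential backoff.
A simplified sketch of the idea (this is not the actual worker code; get_data
and refresh_who_has stand in for the real comm and scheduler calls):

    import random
    import time

    def fetch_with_backoff(get_data, refresh_who_has, key, holders,
                           base=0.1, factor=1.5):
        """Keep asking peers for `key`, backing off while they answer 'busy'."""
        attempt = 0
        while True:
            worker = random.choice(list(holders))    # pick any worker holding the key
            response = get_data(worker, key)         # ask it for the data
            if response.get('status') != 'busy':
                return response['data'][key]         # success
            attempt += 1
            time.sleep(base * factor ** attempt)     # exponential backoff
            holders = refresh_who_has(key)           # maybe someone new has it by now

In the worker itself the pause happens on the event loop (gen.sleep) and the
refresh goes through query_who_has, but the shape of the logic is the same.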
--- distributed/bokeh/worker.py | 2 +- distributed/distributed.yaml | 3 + distributed/tests/test_worker.py | 21 +++++- distributed/utils_comm.py | 5 +- distributed/worker.py | 109 ++++++++++++++++++++++--------- 5 files changed, 104 insertions(+), 36 deletions(-) diff --git a/distributed/bokeh/worker.py b/distributed/bokeh/worker.py index fcc4ae91995..7e577979a30 100644 --- a/distributed/bokeh/worker.py +++ b/distributed/bokeh/worker.py @@ -159,7 +159,7 @@ def __init__(self, worker, **kwargs): fig = figure(title="Communication History", x_axis_type='datetime', - y_range=[-0.1, worker.total_connections + 0.5], + y_range=[-0.1, worker.total_out_connections + 0.5], height=150, tools='', x_range=x_range, **kwargs) fig.line(source=self.source, x='x', y='in', color='red') fig.line(source=self.source, x='x', y='out', color='blue') diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 35766471e00..31bd73e9663 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -19,6 +19,9 @@ distributed: worker: multiprocessing-method: forkserver use-file-locking: True + connections: # Maximum concurrent connections for data + outgoing: 50 # This helps to control network saturation + incoming: 10 profile: interval: 10ms # Time between statistical profiling queries diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 185278e7e94..1210d5213d0 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -564,7 +564,7 @@ def test_clean_nbytes(c, s, a, b): @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 20) def test_gather_many_small(c, s, a, *workers): - a.total_connections = 2 + a.total_out_connections = 2 futures = yield c._scatter(list(range(100))) assert all(w.data for w in workers) @@ -1192,3 +1192,22 @@ def test_prefer_gather_from_local_address(c, s, w1, w2, w3): assert any(d['who'] == w2.address for d in w1.outgoing_transfer_log) assert not any(d['who'] == w2.address for d in w3.outgoing_transfer_log) + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 20, timeout=30, + config={'distributed.worker.connections.incoming': 1}) +def test_avoid_oversubscription(c, s, *workers): + np = pytest.importorskip('numpy') + x = c.submit(np.random.random, 1000000, workers=[workers[0].address]) + yield wait(x) + + futures = [c.submit(len, x, pure=False, workers=[w.address]) + for w in workers[1:]] + + yield wait(futures) + + # Original worker not responsible for all transfers + assert len(workers[0].outgoing_transfer_log) < len(workers) - 2 + + # Some other workers did some work + assert len([w for w in workers if len(w.outgoing_transfer_log) > 0]) >= 3 diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 43dbe49fb0c..46724973996 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -57,7 +57,8 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): rpcs = {addr: rpc(addr) for addr in d} try: coroutines = {address: get_data_from_worker(rpc, keys, address, - who=who) + who=who, + max_connections=False) for address, keys in d.items()} response = {} for worker, c in coroutines.items(): @@ -66,7 +67,7 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): except EnvironmentError: missing_workers.add(worker) else: - response.update(r) + response.update(r['data']) finally: for r in rpcs.values(): r.close_rpc() diff --git a/distributed/worker.py b/distributed/worker.py index df8cd6ad8a9..5d4bf44e3a0 100644 --- 
a/distributed/worker.py +++ b/distributed/worker.py @@ -601,11 +601,24 @@ def delete_data(self, comm=None, keys=None, report=True): raise Return('OK') @gen.coroutine - def get_data(self, comm, keys=None, who=None, serializers=None): + def get_data(self, comm, keys=None, who=None, serializers=None, + max_connections=None): start = time() + if max_connections is None: + max_connections = self.total_in_connections + + # Allow same-host connections more liberally + if max_connections and comm and get_address_host(comm.peer_address) == get_address_host(self.address): + max_connections = max_connections * 2 + + if max_connections is not False and self.outgoing_current_count > max_connections: + raise gen.Return({'status': 'busy'}) + + self.outgoing_current_count += 1 data = {k: self.data[k] for k in keys if k in self.data} - msg = {k: to_serialize(v) for k, v in data.items()} + msg = {'status': 'OK', + 'data': {k: to_serialize(v) for k, v in data.items()}} nbytes = {k: self.nbytes.get(k) for k in data} stop = time() if self.digests is not None: @@ -620,6 +633,8 @@ def get_data(self, comm, keys=None, who=None, serializers=None): self.address, who, exc_info=True) comm.abort() raise + finally: + self.outgoing_current_count -= 1 stop = time() if self.digests is not None: self.digests['get-data-send-duration'].add(stop - start) @@ -1002,8 +1017,10 @@ class Worker(WorkerBase): * **services:** ``{str: Server}``: Auxiliary web servers running on this worker * **service_ports:** ``{str: port}``: - * **total_connections**: ``int`` - The maximum number of concurrent connections we want to see + * **total_out_connections**: ``int`` + The maximum number of concurrent outgoing requests for data + * **total_in_connections**: ``int`` + The maximum number of concurrent incoming requests for data * **total_comm_nbytes**: ``int`` * **batched_stream**: ``BatchedSend`` A batched stream along which we communicate to the scheduler @@ -1146,7 +1163,8 @@ def __init__(self, *args, **kwargs): self.in_flight_tasks = dict() self.in_flight_workers = dict() - self.total_connections = 50 + self.total_out_connections = dask.config.get('distributed.worker.connections.outgoing') + self.total_in_connections = dask.config.get('distributed.worker.connections.incoming') self.total_comm_nbytes = 10e6 self.comm_nbytes = 0 self.suspicious_deps = defaultdict(lambda: 0) @@ -1211,6 +1229,8 @@ def __init__(self, *args, **kwargs): self.incoming_count = 0 self.outgoing_transfer_log = deque(maxlen=(100000)) self.outgoing_count = 0 + self.outgoing_current_count = 0 + self.repetitively_busy = 0 self._client = None profile_cycle_interval = kwargs.pop('profile_cycle_interval', @@ -1381,20 +1401,21 @@ def transition_dep_waiting_flight(self, dep, worker=None): pdb.set_trace() raise - def transition_dep_flight_waiting(self, dep, worker=None): + def transition_dep_flight_waiting(self, dep, worker=None, remove=True): try: if self.validate: assert dep in self.in_flight_tasks del self.in_flight_tasks[dep] - try: - self.who_has[dep].remove(worker) - except KeyError: - pass - try: - self.has_what[worker].remove(dep) - except KeyError: - pass + if remove: + try: + self.who_has[dep].remove(worker) + except KeyError: + pass + try: + self.has_what[worker].remove(dep) + except KeyError: + pass if not self.who_has.get(dep): if dep not in self._missing_dep_flight: @@ -1402,7 +1423,10 @@ def transition_dep_flight_waiting(self, dep, worker=None): self.loop.add_callback(self.handle_missing_dep, dep) for key in self.dependents.get(dep, ()): if self.task_state[key] == 
'waiting': - self.data_needed.appendleft(key) + if remove: # try a new worker immediately + self.data_needed.appendleft(key) + else: # worker was probably busy, wait a while + self.data_needed.append(key) if not self.dependents[dep]: self.release_dep(dep) @@ -1608,12 +1632,12 @@ def maybe_transition_long_running(self, key, compute_duration=None): def ensure_communicating(self): changed = True try: - while changed and self.data_needed and len(self.in_flight_workers) < self.total_connections: + while changed and self.data_needed and len(self.in_flight_workers) < self.total_out_connections: changed = False logger.debug("Ensure communicating. Pending: %d. Connections: %d/%d", len(self.data_needed), len(self.in_flight_workers), - self.total_connections) + self.total_out_connections) key = self.data_needed[0] @@ -1650,7 +1674,7 @@ def ensure_communicating(self): in_flight = False - while deps and (len(self.in_flight_workers) < self.total_connections + while deps and (len(self.in_flight_workers) < self.total_out_connections or self.comm_nbytes < self.total_comm_nbytes): dep = deps.pop() if self.dep_state[dep] != 'waiting': @@ -1783,6 +1807,12 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): who=self.address) stop = time() + if response['status'] == 'busy': + self.log.append(('busy-gather', worker, deps)) + for dep in deps: + self.transition_dep(dep, 'waiting') + return + if cause: self.startstops[cause].append(( 'transfer', @@ -1790,14 +1820,14 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): stop + self.scheduler_delay )) - total_bytes = sum(self.nbytes.get(dep, 0) for dep in response) + total_bytes = sum(self.nbytes.get(dep, 0) for dep in response['data']) duration = (stop - start) or 0.5 self.incoming_transfer_log.append({ 'start': start + self.scheduler_delay, 'stop': stop + self.scheduler_delay, 'middle': (start + stop) / 2.0 + self.scheduler_delay, 'duration': duration, - 'keys': {dep: self.nbytes.get(dep, None) for dep in response}, + 'keys': {dep: self.nbytes.get(dep, None) for dep in response['data']}, 'total': total_bytes, 'bandwidth': total_bytes / duration, 'who': worker @@ -1805,14 +1835,14 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): if self.digests is not None: self.digests['transfer-bandwidth'].add(total_bytes / duration) self.digests['transfer-duration'].add(duration) - self.counters['transfer-count'].add(len(response)) + self.counters['transfer-count'].add(len(response['data'])) self.incoming_count += 1 - self.log.append(('receive-dep', worker, list(response))) + self.log.append(('receive-dep', worker, list(response['data']))) - if response: + if response['data']: self.batched_stream.send({'op': 'add-keys', - 'keys': list(response)}) + 'keys': list(response['data'])}) except EnvironmentError as e: logger.exception("Worker stream died during communication: %s", worker) @@ -1830,14 +1860,16 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): raise finally: self.comm_nbytes -= total_nbytes + busy = response['status'] == 'busy' for d in self.in_flight_workers.pop(worker): - if d in response: - self.transition_dep(d, 'memory', value=response[d]) + if not busy and d in response['data']: + self.transition_dep(d, 'memory', value=response['data'][d]) elif self.dep_state.get(d) != 'memory': - self.transition_dep(d, 'waiting', worker=worker) + self.transition_dep(d, 'waiting', worker=worker, + remove=not busy) - if d not in response and d in self.dependents: + if not busy and d not in response['data'] and 
d in self.dependents: self.log.append(('missing-dep', d)) self.batched_stream.send({'op': 'missing-data', 'errant_worker': worker, @@ -1847,7 +1879,18 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): self.validate_state() self.ensure_computing() - self.ensure_communicating() + + if not busy: + self.repetitively_busy = 0 + self.ensure_communicating() + else: + # Exponential backoff to avoid hammering scheduler/worker + self.repetitively_busy += 1 + yield gen.sleep(0.100 * 1.5 ** self.repetitively_busy) + + # See if anyone new has the data + yield self.query_who_has(dep) + self.ensure_communicating() def bad_dep(self, dep): exc = ValueError("Could not find dependent %s. Check worker logs" % str(dep)) @@ -2729,7 +2772,7 @@ def parse_memory_limit(memory_limit, ncores): @gen.coroutine -def get_data_from_worker(rpc, keys, worker, who=None): +def get_data_from_worker(rpc, keys, worker, who=None, max_connections=None): """ Get keys from worker The worker has a two step handshake to acknowledge when data has been fully @@ -2746,8 +2789,10 @@ def get_data_from_worker(rpc, keys, worker, who=None): response = yield send_recv(comm, serializers=rpc.serializers, deserializers=rpc.deserializers, - op='get_data', keys=keys, who=who) - yield comm.write('OK') + op='get_data', keys=keys, who=who, + max_connections=max_connections) + if response['status'] == 'OK': + yield comm.write('OK') finally: rpc.reuse(worker, comm) From 86260008d5799c89466f633c8e403608152b00b2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 9 Jul 2018 14:37:31 -0400 Subject: [PATCH 0026/1550] Retire workers from scale (#2104) --- distributed/deploy/cluster.py | 6 ++++- distributed/deploy/tests/test_local.py | 32 ++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 319b969d513..4265a151945 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -93,7 +93,11 @@ def scale(self, n): if n >= len(self.scheduler.workers): self.scheduler.loop.add_callback(self.scale_up, n) else: - to_close = self.scheduler.workers_to_close(n=len(self.scheduler.workers) - n) + to_close = self.scheduler.retire_workers( + remove=False, + close_workers=True, + n=len(self.scheduler.workers) - n + ) logger.debug("Closing workers: %s", to_close) self.scheduler.loop.add_callback(self.scale_down, to_close) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 988711c41fb..7f2f5874d41 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -467,5 +467,37 @@ def test_local_tls(loop): ) +@gen_test() +def test_scale_retires_workers(): + class MyCluster(LocalCluster): + def scale_down(self, *args, **kwargs): + pass + + loop = IOLoop.current() + cluster = yield MyCluster(0, scheduler_port=0, processes=False, + silence_logs=False, diagnostics_port=None, + loop=loop, asynchronous=True) + c = yield Client(cluster, loop=loop, asynchronous=True) + + assert not cluster.workers + + yield cluster.scale(2) + + start = time() + while len(cluster.scheduler.workers) != 2: + yield gen.sleep(0.01) + assert time() < start + 3 + + yield cluster.scale(1) + + start = time() + while len(cluster.scheduler.workers) != 1: + yield gen.sleep(0.01) + assert time() < start + 3 + + yield c._close() + yield cluster._close() + + if sys.version_info >= (3, 5): from distributed.deploy.tests.py3_test_deploy import * # noqa F401 From 
fd8ca5ebc3a781d5f6625d3daa4d09ce806938a3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 9 Jul 2018 18:08:06 -0400 Subject: [PATCH 0027/1550] Be robust to empty response in gather_dep (#2105) --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 5d4bf44e3a0..91100373e46 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1860,7 +1860,7 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): raise finally: self.comm_nbytes -= total_nbytes - busy = response['status'] == 'busy' + busy = response.get('status', '') == 'busy' for d in self.in_flight_workers.pop(worker): if not busy and d in response['data']: From 696030682305b0ef5bb31583b547ffb434c04b04 Mon Sep 17 00:00:00 2001 From: Phil Tooley <32297355+ptooley@users.noreply.github.com> Date: Thu, 12 Jul 2018 22:08:13 +0100 Subject: [PATCH 0028/1550] insert newline by default after TextProgressBar (#1976) --- distributed/diagnostics/progressbar.py | 4 ++++ distributed/diagnostics/tests/test_progressbar.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index d79f91f5e7d..3263503b2a2 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -124,6 +124,10 @@ def _draw_bar(self, remaining, all, **kwargs): sys.stdout.write(msg) sys.stdout.flush() + def _draw_stop(self, **kwargs): + sys.stdout.write('\r') + sys.stdout.flush() + class ProgressWidget(ProgressBar): """ ProgressBar that uses an IPython ProgressBar widget for the notebook diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index 3c25a71b645..16eeeab0464 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -73,7 +73,8 @@ def f(): def check_bar_completed(capsys, width=40): out, err = capsys.readouterr() - bar, percent, time = [i.strip() for i in out.split('\r')[-1].split('|')] + # trailing newline so grab next to last line for final state of bar + bar, percent, time = [i.strip() for i in out.split('\r')[-2].split('|')] assert bar == '[' + '#' * width + ']' assert percent == '100% Completed' From 05a046b1d8aedee4e9fdd338acf5f5314ad9ead9 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 12 Jul 2018 20:31:41 -0500 Subject: [PATCH 0029/1550] TST: Added another nested parallelism test (#1710) --- distributed/tests/test_joblib.py | 37 ++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/distributed/tests/test_joblib.py b/distributed/tests/test_joblib.py index 6ff06a922ff..aa81a45b5f3 100644 --- a/distributed/tests/test_joblib.py +++ b/distributed/tests/test_joblib.py @@ -182,6 +182,43 @@ def test_errors(loop, joblib): assert "create a dask client" in str(info.value).lower() +def test_correct_nested_backend(loop, joblib): + if LooseVersion(joblib.__version__) <= LooseVersion("0.11.0"): + pytest.skip("Requires nested parallelism") + + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as client: + # No requirement, should be us + with joblib.parallel_backend('dask') as (ba, _): + result = joblib.Parallel(n_jobs=2)(joblib.delayed(outer)( + joblib, nested_require=None) for _ in range(1)) + assert isinstance(result[0][0][0], + distributed_joblib.DaskDistributedBackend) + + # Require threads, should be threading + with joblib.parallel_backend('dask') 
as (ba, _): + result = joblib.Parallel(n_jobs=2)(joblib.delayed(outer)( + joblib, nested_require='sharedmem') for _ in range(1)) + assert isinstance(result[0][0][0], + joblib.parallel.ThreadingBackend) + + +def outer(joblib, nested_require): + return joblib.Parallel(n_jobs=2, prefer='threads')( + joblib.delayed(middle)(joblib, nested_require) for _ in range(1) + ) + + +def middle(joblib, require): + return joblib.Parallel(n_jobs=2, require=require)( + joblib.delayed(inner)(joblib) for _ in range(1) + ) + + +def inner(joblib): + return joblib.parallel.Parallel()._backend + + def test_secede_with_no_processes(loop, joblib): # https://github.com/dask/distributed/issues/1775 From 67239aa40c1e939a2744bd210e4dd15e4b1624b6 Mon Sep 17 00:00:00 2001 From: Dave Hirschfeld Date: Fri, 13 Jul 2018 22:34:25 +1000 Subject: [PATCH 0030/1550] Use type hints to further emphasize the custom serialization api (#2116) --- distributed/protocol/serialize.py | 4 ++-- docs/source/serialization.rst | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index c2a1274afe2..cb3b802c504 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -380,8 +380,8 @@ def register_serialization(cls, serialize, deserialize): Parameters ---------- cls: type - serialize: function - deserialize: function + serialize: callable(cls) -> Tuple[Dict, List[bytes]] + deserialize: callable(header: Dict, frames: List[bytes]) -> cls Examples -------- diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index e457681a662..cdb765c380a 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -150,12 +150,12 @@ register them with Dask. def __init__(self, name): self.name = name - def serialize(human): + def serialize(human: Human) -> Tuple[Dict, List[bytes]]: header = {} frames = [human.name.encode()] return header, frames - def deserialize(header, frames): + def deserialize(header: Dict, frames: List[bytes]) -> Human: return Human(frames[0].decode()) from distributed.protocol.serialize import register_serialization From 81a0f9547e61701a9ef65f40436b18d318bd5f2c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 15 Jul 2018 17:19:33 -0500 Subject: [PATCH 0031/1550] Fix cleanup with empty response in gather dep (#2112) --- distributed/worker.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 91100373e46..d6a0398efe3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1861,15 +1861,16 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): finally: self.comm_nbytes -= total_nbytes busy = response.get('status', '') == 'busy' + data = response.get('data', {}) for d in self.in_flight_workers.pop(worker): - if not busy and d in response['data']: - self.transition_dep(d, 'memory', value=response['data'][d]) + if not busy and d in data: + self.transition_dep(d, 'memory', value=data[d]) elif self.dep_state.get(d) != 'memory': self.transition_dep(d, 'waiting', worker=worker, remove=not busy) - if not busy and d not in response['data'] and d in self.dependents: + if not busy and d not in data and d in self.dependents: self.log.append(('missing-dep', d)) self.batched_stream.send({'op': 'missing-data', 'errant_worker': worker, From 808afe23d7bea8e4c4a42f6b91fe49f4da013cb6 Mon Sep 17 00:00:00 2001 From: Matt Nicolls <2540582+nicolls1@users.noreply.github.com> Date: Mon, 16 Jul 2018 
22:38:07 +0200 Subject: [PATCH 0032/1550] Update dask-scheduler cli help text for preload (#2120) --- distributed/cli/dask_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 35eb72bb6b8..b719bbda12b 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -59,8 +59,8 @@ @click.option('--local-directory', default='', type=str, help="Directory to place scheduler files") @click.option('--preload', type=str, multiple=True, is_eager=True, - help='Module that should be loaded by each worker process ' - 'like "foo.bar" or "/path/to/foo.py"') + help='Module that should be loaded by the scheduler process ' + 'like "foo.bar" or "/path/to/foo.py".') @click.argument('preload_argv', nargs=-1, type=click.UNPROCESSED, callback=validate_preload_argv) def main(host, port, bokeh_port, show, _bokeh, bokeh_whitelist, bokeh_prefix, From 82d51e1974c4b11b14b628689ce9651b01065aec Mon Sep 17 00:00:00 2001 From: Dave Hirschfeld Date: Wed, 18 Jul 2018 02:37:35 +1000 Subject: [PATCH 0033/1550] Add custom serialization support for pyarrow (#2115) --- .gitignore | 2 + distributed/protocol/__init__.py | 5 +++ distributed/protocol/arrow.py | 53 ++++++++++++++++++++++++ distributed/protocol/tests/test_arrow.py | 44 ++++++++++++++++++++ distributed/utils.py | 5 ++- 5 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 distributed/protocol/arrow.py create mode 100644 distributed/protocol/tests/test_arrow.py diff --git a/.gitignore b/.gitignore index 7e110237d4b..7510e74bcbf 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ docs/build continuous_integration/hdfs-initialized .cache .#* +.idea/ +.pytest_cache/ diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index a6a9afaf324..01ac7e8464a 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -37,3 +37,8 @@ def _register_keras(): @partial(register_serialization_lazy, "sparse") def _register_sparse(): from . import sparse + + +@partial(register_serialization_lazy, "pyarrow") +def _register_arrow(): + from . 
import arrow diff --git a/distributed/protocol/arrow.py b/distributed/protocol/arrow.py new file mode 100644 index 00000000000..87c5d05c99f --- /dev/null +++ b/distributed/protocol/arrow.py @@ -0,0 +1,53 @@ +from __future__ import print_function, division, absolute_import + +from .serialize import register_serialization + + +def serialize_batch(batch): + import pyarrow as pa + sink = pa.BufferOutputStream() + writer = pa.RecordBatchStreamWriter(sink, batch.schema) + writer.write_batch(batch) + writer.close() + buf = sink.get_result() + header = {} + frames = [buf] + return header, frames + + +def deserialize_batch(header, frames): + import pyarrow as pa + blob = frames[0] + reader = pa.RecordBatchStreamReader(pa.BufferReader(blob)) + return reader.read_next_batch() + + +def serialize_table(tbl): + import pyarrow as pa + sink = pa.BufferOutputStream() + writer = pa.RecordBatchStreamWriter(sink, tbl.schema) + writer.write_table(tbl) + writer.close() + buf = sink.get_result() + header = {} + frames = [buf] + return header, frames + + +def deserialize_table(header, frames): + import pyarrow as pa + blob = frames[0] + reader = pa.RecordBatchStreamReader(pa.BufferReader(blob)) + return reader.read_all() + + +register_serialization( + 'pyarrow.lib.RecordBatch', + serialize_batch, + deserialize_batch +) +register_serialization( + 'pyarrow.lib.Table', + serialize_table, + deserialize_table +) diff --git a/distributed/protocol/tests/test_arrow.py b/distributed/protocol/tests/test_arrow.py new file mode 100644 index 00000000000..6f014bae323 --- /dev/null +++ b/distributed/protocol/tests/test_arrow.py @@ -0,0 +1,44 @@ +import pandas as pd +import pytest + +pa = pytest.importorskip('pyarrow') + +from distributed.utils_test import gen_cluster +from distributed.protocol import deserialize, serialize +from distributed.protocol.serialize import class_serializers, typename + + +df = pd.DataFrame({'A': list('abc'), 'B': [1,2,3]}) +tbl = pa.Table.from_pandas(df, preserve_index=False) +batch = pa.RecordBatch.from_pandas(df, preserve_index=False) + + +@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"]) +def test_roundtrip(obj): + # Test that the serialize/deserialize functions actually + # work independent of distributed + header, frames = serialize(obj) + new_obj = deserialize(header, frames) + assert obj.equals(new_obj) + + +@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"]) +def test_typename(obj): + # The typename used to register the custom serialization is hardcoded + # ensure that the typename hasn't changed + assert typename(type(obj)) in class_serializers + + +def echo(arg): + return arg + + +@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"]) +def test_scatter(obj): + @gen_cluster(client=True) + def run_test(client, scheduler, worker1, worker2): + obj_fut = yield client.scatter(obj) + fut = client.submit(echo, obj_fut) + result = yield fut + assert obj.equals(result) + run_test() diff --git a/distributed/utils.py b/distributed/utils.py index 53495ba60b4..666bb9dd26d 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1266,7 +1266,10 @@ def nbytes(frame, _bytes_like=(bytes, bytearray)): if isinstance(frame, _bytes_like): return len(frame) else: - return frame.nbytes + try: + return frame.nbytes + except AttributeError: + return len(frame) def PeriodicCallback(callback, callback_time, io_loop=None): From 1283d415879c065391e736b9361e9d357588edd9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 19 Jul 2018 
07:20:50 -0400 Subject: [PATCH 0034/1550] XFail test_open_close_many_workers (#2125) We should fix this, but don't have the time right now. Intermittent failures on this test are interrupting development flow. --- distributed/tests/test_client.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 75558352111..c854e2bcb27 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3663,16 +3663,16 @@ def test_reconnect(loop): c.close() -# On Python 2, heavy process spawning can deadlock (e.g. on a logging IO lock) -_params = ([(Worker, 100, 5), (Nanny, 10, 20)] - if sys.version_info >= (3,) - else [(Worker, 100, 5)]) - - @slow @pytest.mark.skipif(sys.platform.startswith('win'), reason="num_fds not supported on windows") -@pytest.mark.parametrize("worker,count,repeat", _params) +@pytest.mark.skipif(sys.version_info[0] == 2, + reason="Semaphore.acquire doesn't support timeout option") +@pytest.mark.xfail(reason='TODO: intermittent failures') +@pytest.mark.parametrize("worker,count,repeat", [ + (Worker, 100, 5), + (Nanny, 10, 20) +]) def test_open_close_many_workers(loop, worker, count, repeat): psutil = pytest.importorskip('psutil') proc = psutil.Process() @@ -3706,7 +3706,7 @@ def start_worker(sleep, duration, repeat=1): sleep(1) for i in range(count): - done.acquire() + done.acquire(timeout=5) gc.collect() if not running: break From ff2602adb26139dc9c0addcb5499466e982827cf Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 19 Jul 2018 14:58:39 -0400 Subject: [PATCH 0035/1550] Expand resources in graph_to_futures (#2131) Previously this was handled only in a few of the submission functions. Now we lower this logic to the core graph_to_futures method, applying it more uniformly across computation. 
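In user terms the resources= keyword should now behave the same whichever
entry point it comes through. A rough sketch of the pattern, mirroring the new
tests (the 'A' resource label is arbitrary and assumes a worker was started
with that resource):

    import dask.array as da
    from dask.distributed import Client

    client = Client()   # assumes at least one worker advertising resources A=1

    x = da.random.random(100, chunks=(10,)) + 1

    # Restrict every task in this collection to workers that provide resource A.
    # compute(), persist() and collection .compute() now all funnel this
    # through graph_to_futures.
    x.compute(resources={tuple(x.dask): {'A': 1}},
              optimize_graph=False)  # resources are keyed by un-optimized graph keys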
--- distributed/client.py | 29 ++++++------------- distributed/tests/test_resources.py | 43 +++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index b066d5cc05d..065a62c65a6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2069,6 +2069,15 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, user_priority=0, resources=None, retries=None, fifo_timeout=0): with self._lock: + if resources: + resources = self._expand_resources(resources, + all_keys=itertools.chain(dsk, keys)) + + if retries: + retries = self._expand_retries(retries, + all_keys=itertools.chain(dsk, keys)) + + print(resources) keyset = set(keys) flatkeys = list(map(tokey, keys)) futures = {key: Future(key, self, inform=False) for key in keyset} @@ -2327,16 +2336,6 @@ def compute(self, collections, sync=False, optimize_graph=True, restrictions, loose_restrictions = self.get_restrictions(collections, workers, allow_other_workers) - if resources: - resources = self._expand_resources(resources, - all_keys=itertools.chain(dsk, dsk2)) - - if retries: - retries = self._expand_retries(retries, - all_keys=itertools.chain(dsk, dsk2)) - else: - retries = None - if not isinstance(priority, Number): priority = {k: p for c, p in priority.items() for k in self._expand_key(c)} @@ -2429,16 +2428,6 @@ def persist(self, collections, optimize_graph=True, workers=None, restrictions, loose_restrictions = self.get_restrictions(collections, workers, allow_other_workers) - if resources: - resources = self._expand_resources(resources, - all_keys=itertools.chain(dsk, names)) - - if retries: - retries = self._expand_retries(retries, - all_keys=itertools.chain(dsk, names)) - else: - retries = None - if not isinstance(priority, Number): priority = {k: p for c, p in priority.items() for k in self._expand_key(c)} diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 55e178eb89d..46e22d0f530 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -6,10 +6,10 @@ import pytest from tornado import gen -from distributed import Worker +from distributed import Worker, Client from distributed.client import wait from distributed.utils import tokey -from distributed.utils_test import (inc, gen_cluster, +from distributed.utils_test import (inc, gen_cluster, cluster, slowinc, slowadd) from distributed.utils_test import loop # noqa: F401 @@ -260,3 +260,42 @@ def test_dont_optimize_out(c, s, a, b): for key in map(tokey, y.__dask_keys__()): assert 'executing' in str(a.story(key)) + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), + ('127.0.0.1', 1, {'resources': {'B': 1}})]) +def test_full_collections(c, s, a, b): + dd = pytest.importorskip('dask.dataframe') + df = dd.demo.make_timeseries(freq='60s', partition_freq='1d', + start='2000-01-01', end='2000-01-31') + z = df.x + df.y # some extra nodes in the graph + + yield c.compute(z, resources={tuple(z.dask): {'A': 1}}) + assert a.log + assert not b.log + + +@pytest.mark.parametrize('optimize_graph', [ + pytest.mark.xfail(True, reason="don't track resources through optimization"), + False +]) +def test_collections_get(loop, optimize_graph): + da = pytest.importorskip('dask.array') + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + def f(dask_worker): + dask_worker.set_resources(**{'A': 1}) + + c.run(f, workers=[a['address']]) + + x = da.random.random(100, chunks=(10,)) + 1 + + 
x.compute(resources={tuple(x.dask): {'A': 1}}, + optimize_graph=optimize_graph) + + def g(dask_worker): + return len(dask_worker.log) + + logs = c.run(g) + assert logs[a['address']] + assert not logs[b['address']] From ca70550daa6765bb45c846cfdc069f70151f1634 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 19 Jul 2018 14:59:25 -0400 Subject: [PATCH 0036/1550] Test that worker restrictions are cleared after cancellation (#2107) --- distributed/tests/test_scheduler.py | 20 ++++++++++++++++++++ distributed/worker.py | 3 +++ 2 files changed, 23 insertions(+) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4f6bea05e1c..5a80faf8f83 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -6,6 +6,7 @@ import json from operator import add, mul import sys +from time import sleep import dask from dask import delayed @@ -1340,3 +1341,22 @@ def test_closing_scheduler_closes_workers(s, a, b): while a.status != 'closed' or b.status != 'closed': yield gen.sleep(0.01) assert time() < start + 2 + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)], + worker_kwargs={'resources': {'A': 1}}) +def test_resources_reset_after_cancelled_task(c, s, w): + future = c.submit(sleep, 0.2, resources={'A': 1}) + + while not w.executing: + yield gen.sleep(0.01) + + yield future.cancel() + + while w.executing: + yield gen.sleep(0.01) + + assert not s.workers[w.address].used_resources['A'] + assert w.available_resources == {'A': 1} + + yield c.submit(inc, 1, resources={'A': 1}) diff --git a/distributed/worker.py b/distributed/worker.py index d6a0398efe3..feb67ce5a92 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2034,6 +2034,9 @@ def release_key(self, key, cause=None, reason=None, report=True): self.executing.remove(key) if key in self.resource_restrictions: + if state == 'executing': + for resource, quantity in self.resource_restrictions[key].items(): + self.available_resources[resource] += quantity del self.resource_restrictions[key] if report and state in PROCESSING: # not finished From 525549938b6765900b6c77a444d43796b73dfa78 Mon Sep 17 00:00:00 2001 From: Dave Hirschfeld Date: Sat, 21 Jul 2018 01:52:20 +1000 Subject: [PATCH 0037/1550] Update .gitignore (#2135) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 7510e74bcbf..7888407d33f 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ continuous_integration/hdfs-initialized .#* .idea/ .pytest_cache/ +dask-worker-space/ From 8a78ec2dc12f9770758409ae7482872ab389086b Mon Sep 17 00:00:00 2001 From: Dror Birkman Date: Sun, 22 Jul 2018 06:33:00 +0300 Subject: [PATCH 0038/1550] Use PID and counter in thread names (#2084) (#2128) --- distributed/tests/test_threadpoolexecutor.py | 6 ++++++ distributed/threadpoolexecutor.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_threadpoolexecutor.py b/distributed/tests/test_threadpoolexecutor.py index f63dfe0244c..8777e574282 100644 --- a/distributed/tests/test_threadpoolexecutor.py +++ b/distributed/tests/test_threadpoolexecutor.py @@ -119,3 +119,9 @@ def f(): future = e.submit(f) result = future.result() + + +def test_thread_name(): + with ThreadPoolExecutor(2) as e: + e.map(id, range(10)) + assert len({thread.name for thread in e._threads}) == 2 diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index 1b31ce125d7..8e9f10cadc8 100644 --- a/distributed/threadpoolexecutor.py +++ 
b/distributed/threadpoolexecutor.py @@ -23,8 +23,10 @@ from __future__ import print_function, division, absolute_import from . import _concurrent_futures_thread as thread +import os import logging import threading +import itertools from .metrics import time @@ -62,6 +64,9 @@ def _worker(executor, work_queue): class ThreadPoolExecutor(thread.ThreadPoolExecutor): + # Used to assign unique thread names + _counter = itertools.count() + def __init__(self, *args, **kwargs): super(ThreadPoolExecutor, self).__init__(*args, **kwargs) self._rejoin_list = [] @@ -70,7 +75,7 @@ def __init__(self, *args, **kwargs): def _adjust_thread_count(self): if len(self._threads) < self._max_workers: t = threading.Thread(target=_worker, - name="ThreadPool worker %d" % len(self._threads,), + name="ThreadPoolExecutor-%d-%d" % (os.getpid(), next(self._counter)), args=(self, self._work_queue)) t.daemon = True self._threads.add(t) From 9a7016de0db2f5ecada1e1290eb0ee9fd91d65a2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 23 Jul 2018 07:52:36 -0400 Subject: [PATCH 0039/1550] Allow client to query the task stream plot (#2122) * Move TaskStreamPlugin to diagnostics * Add Client.get_task_stream function --- distributed/__init__.py | 2 +- distributed/bokeh/scheduler.py | 2 +- distributed/bokeh/tests/test_task_stream.py | 40 ------- distributed/client.py | 96 ++++++++++++++++ .../{bokeh => diagnostics}/task_stream.py | 62 +++++++++-- .../diagnostics/tests/test_task_stream.py | 104 ++++++++++++++++++ distributed/scheduler.py | 17 ++- distributed/utils_test.py | 7 +- docs/source/api.rst | 2 + 9 files changed, 277 insertions(+), 55 deletions(-) delete mode 100644 distributed/bokeh/tests/test_task_stream.py rename distributed/{bokeh => diagnostics}/task_stream.py (64%) create mode 100644 distributed/diagnostics/tests/test_task_stream.py diff --git a/distributed/__init__.py b/distributed/__init__.py index 3cde15ef52a..71e71a79143 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -7,7 +7,7 @@ from .diagnostics import progress from .client import (Client, Executor, CompatibleExecutor, wait, as_completed, default_client, fire_and_forget, - Future, futures_of) + Future, futures_of, get_task_stream) from .lock import Lock from .nanny import Nanny from .pubsub import Pub, Sub diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index 6533eb620ba..cd677e455c2 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -39,7 +39,7 @@ from ..diagnostics.progress_stream import color_of, progress_quads, nbytes_bar from ..diagnostics.progress import AllProgress from ..diagnostics.graph_layout import GraphLayout -from .task_stream import TaskStreamPlugin +from ..diagnostics.task_stream import TaskStreamPlugin try: from cytoolz.curried import map, concat, groupby, valmap, first diff --git a/distributed/bokeh/tests/test_task_stream.py b/distributed/bokeh/tests/test_task_stream.py deleted file mode 100644 index 4b578a89c47..00000000000 --- a/distributed/bokeh/tests/test_task_stream.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import absolute_import, division, print_function - -import pytest -pytest.importorskip('bokeh') - -from toolz import frequencies - -from distributed.utils_test import gen_cluster, div -from distributed.client import wait -from distributed.bokeh.task_stream import TaskStreamPlugin - - -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) -def test_TaskStreamPlugin(c, s, *workers): - es = TaskStreamPlugin(s) - assert not es.buffer - - 
futures = c.map(div, [1] * 10, range(10)) - total = c.submit(sum, futures[1:]) - yield wait(total) - - assert len(es.buffer) == 11 - - workers = dict() - - rects = es.rectangles(0, 10, workers) - assert all(n == 'div' for n in rects['name']) - assert all(d > 0 for d in rects['duration']) - counts = frequencies(rects['color']) - assert counts['black'] == 1 - assert set(counts.values()) == {9, 1} - assert len(set(rects['y'])) == 3 - - rects = es.rectangles(2, 5, workers) - assert all(len(L) == 3 for L in rects.values()) - - starts = sorted(rects['start']) - rects = es.rectangles(2, 5, workers=workers, - start_boundary=(starts[0] + starts[1]) / 2000) - assert set(rects['start']).issubset(set(starts[1:])) diff --git a/distributed/client.py b/distributed/client.py index 065a62c65a6..de4183138a6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3322,6 +3322,59 @@ def get_restrictions(cls, collections, workers, allow_other_workers): def collections_to_dsk(collections, *args, **kwargs): return collections_to_dsk(collections, *args, **kwargs) + def get_task_stream(self, start=None, stop=None, count=None): + """ Get task stream data from scheduler + + This collects the data present in the diagnostic "Task Stream" plot on + the dashboard. It includes the start, stop, transfer, and + deserialization time of every task for a particular duration. + + Note that the task stream diagnostic does not run by default. You may + wish to call this function once before you start work to ensure that + things start recording, and then again after you have completed. + + Parameters + ---------- + start: Number or string + When you want to start recording + If a number it should be the result of calling time() + If a string then it should be a time difference before now, + like '60s' or '500 ms' + stop: Number or string + When you want to stop recording + count: int + The number of desired records, ignored if both start and stop are + specified + + Examples + -------- + >>> client.get_task_stream() # prime plugin if not already connected + >>> x.compute() # do some work + >>> client.get_task_stream() + [{'task': ..., + 'type': ..., + 'thread': ..., + ...}] + + Alternatively consider the context manager + + >>> from dask.distributed import get_task_stream + >>> with get_task_stream() as ts: + ... x.compute() + >>> ts.data + [...] + + Returns + ------- + L: List[Dict] + + See Also + -------- + get_task_stream: a dontext manager version of this method + """ + return self.sync(self.scheduler.get_task_stream, start=start, + stop=stop, count=count) + class Executor(Client): """ Deprecated: see Client """ @@ -3691,6 +3744,49 @@ def fire_and_forget(obj): 'client': 'fire-and-forget'}) +class get_task_stream(object): + """ + Collect task stream within a context block + + This provides diagnostic information about every task that was run during + the time when this block was active. + + This must be used as a context manager. + + Examples + -------- + >>> with get_task_stream() as ts: + ... x.compute() + >>> ts.data + [...] 
+ + See Also + -------- + Client.get_task_stream: Function version of this context manager + """ + def __init__(self, client=None): + self.data = [] + self.client = client or default_client() + self.client.get_task_stream(start=0, stop=0) # ensure plugin + + def __enter__(self): + self.start = time() + return self + + def __exit__(self, typ, value, traceback): + L = self.client.get_task_stream(start=self.start) + self.data.extend(L) + + @gen.coroutine + def __aenter__(self): + raise gen.Return(self) + + @gen.coroutine + def __aexit__(self, typ, value, traceback): + L = yield self.client.get_task_stream(start=self.start) + self.data.extend(L) + + @contextmanager def temp_default_client(c): """ Set the default client for the duration of the context diff --git a/distributed/bokeh/task_stream.py b/distributed/diagnostics/task_stream.py similarity index 64% rename from distributed/bokeh/task_stream.py rename to distributed/diagnostics/task_stream.py index efa378c3dbf..7cabcf96311 100644 --- a/distributed/bokeh/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -1,22 +1,23 @@ from __future__ import print_function, division, absolute_import +from collections import deque import logging -from ..diagnostics.progress_stream import color_of -from ..diagnostics.plugin import SchedulerPlugin -from ..utils import key_split, format_time +from .progress_stream import color_of +from .plugin import SchedulerPlugin +from ..utils import key_split, format_time, parse_timedelta +from ..metrics import time logger = logging.getLogger(__name__) class TaskStreamPlugin(SchedulerPlugin): - def __init__(self, scheduler): - self.buffer = [] + def __init__(self, scheduler, maxlen=100000): + self.buffer = deque(maxlen=maxlen) self.scheduler = scheduler scheduler.add_plugin(self) self.index = 0 - self.maxlen = 100000 def transition(self, key, start, finish, *args, **kwargs): if start == 'processing': @@ -26,8 +27,48 @@ def transition(self, key, start, finish, *args, **kwargs): if finish == 'memory' or finish == 'erred': self.buffer.append(kwargs) self.index += 1 - if len(self.buffer) > self.maxlen: - self.buffer = self.buffer[len(self.buffer):] + + def collect(self, start=None, stop=None, count=None): + def bisect(target, left, right): + if left == right: + return left + + mid = (left + right) // 2 + value = max(stop for _, start, stop in self.buffer[mid]['startstops']) + + if value < target: + return bisect(target, mid + 1, right) + else: + return bisect(target, left, mid) + + if isinstance(start, str): + start = time() - parse_timedelta(start) + if start is not None: + start = bisect(start, 0, len(self.buffer)) + + if isinstance(stop, str): + stop = time() - parse_timedelta(stop) + if stop is not None: + stop = bisect(stop, 0, len(self.buffer)) + + if count is not None: + if start is None and stop is None: + stop = len(self.buffer) + start = stop - count + elif start is None and stop is not None: + start = stop - count + elif start is not None and stop is None: + stop = start + count + + if stop is None: + stop = len(self.buffer) + if start is None: + start = 0 + + start = max(0, start) + stop = min(stop, len(self.buffer)) + + return [self.buffer[i] for i in range(start, stop)] def rectangles(self, istart, istop=None, workers=None, start_boundary=0): L_start = [] @@ -42,7 +83,10 @@ def rectangles(self, istart, istop=None, workers=None, start_boundary=0): L_y = [] diff = self.index - len(self.buffer) - for msg in self.buffer[istart - diff: istop - diff if istop else istop]: + if istop is None: + istop = 
len(self.buffer) + for i in range((istart or 0) - diff, istop - diff if istop else istop): + msg = self.buffer[i] key = msg['key'] name = key_split(key) startstops = msg.get('startstops', []) diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py new file mode 100644 index 00000000000..eccb0a9db8e --- /dev/null +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import, division, print_function + +from time import sleep + +from toolz import frequencies + +from distributed import Client, get_task_stream +from distributed.utils_test import gen_cluster, div, inc, slowinc, cluster +from distributed.utils_test import loop # noqa F401 +from distributed.client import wait +from distributed.diagnostics.task_stream import TaskStreamPlugin +from distributed.metrics import time + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +def test_TaskStreamPlugin(c, s, *workers): + es = TaskStreamPlugin(s) + assert not es.buffer + + futures = c.map(div, [1] * 10, range(10)) + total = c.submit(sum, futures[1:]) + yield wait(total) + + assert len(es.buffer) == 11 + + workers = dict() + + rects = es.rectangles(0, 10, workers) + assert all(n == 'div' for n in rects['name']) + assert all(d > 0 for d in rects['duration']) + counts = frequencies(rects['color']) + assert counts['black'] == 1 + assert set(counts.values()) == {9, 1} + assert len(set(rects['y'])) == 3 + + rects = es.rectangles(2, 5, workers) + assert all(len(L) == 3 for L in rects.values()) + + starts = sorted(rects['start']) + rects = es.rectangles(2, 5, workers=workers, + start_boundary=(starts[0] + starts[1]) / 2000) + assert set(rects['start']).issubset(set(starts[1:])) + + +@gen_cluster(client=True) +def test_maxlen(c, s, a, b): + tasks = TaskStreamPlugin(s, maxlen=5) + futures = c.map(inc, range(10)) + yield wait(futures) + assert len(tasks.buffer) == 5 + + +@gen_cluster(client=True) +def test_collect(c, s, a, b): + tasks = TaskStreamPlugin(s) + start = time() + futures = c.map(slowinc, range(10), delay=0.1) + yield wait(futures) + + L = tasks.collect() + assert len(L) == len(futures) + L = tasks.collect(start=start) + assert len(L) == len(futures) + + L = tasks.collect(start=start + 0.2) + assert 4 <= len(L) <= len(futures) + + L = tasks.collect(start='20 s') + assert len(L) == len(futures) + + L = tasks.collect(start='500ms') + assert 0 < len(L) <= len(futures) + + L = tasks.collect(count=3) + assert len(L) == 3 + assert L == list(tasks.buffer)[-3:] + + assert tasks.collect(stop=start + 100, count=3) == tasks.collect(count=3) + assert tasks.collect(start=start, count=3) == list(tasks.buffer)[:3] + + +@gen_cluster(client=True) +def test_client(c, s, a, b): + L = yield c.get_task_stream() + assert L == () + + futures = c.map(slowinc, range(10), delay=0.1) + yield wait(futures) + + tasks = [p for p in s.plugins if isinstance(p, TaskStreamPlugin)][0] + L = yield c.get_task_stream() + assert L == tuple(tasks.buffer) + + +def test_client_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + with get_task_stream(client=c) as ts: + sleep(0.1) # to smooth over time differences on the scheduler + # to smooth over time differences on the scheduler + futures = c.map(inc, range(10)) + wait(futures) + + assert len(ts.data) == 10 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bbcf0dad08d..3fd802b117d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ 
-924,7 +924,8 @@ def __init__( 'get_metadata': self.get_metadata, 'set_metadata': self.set_metadata, 'heartbeat_worker': self.heartbeat_worker, - 'get_task_status': self.get_task_status + 'get_task_status': self.get_task_status, + 'get_task_stream': self.get_task_stream, } self._transitions = { @@ -2100,12 +2101,18 @@ def handle_worker(self, comm=None, worker=None): worker_comm.abort() self.remove_worker(address=worker) - def add_plugin(self, plugin): + def add_plugin(self, plugin=None, idempotent=True, **kwargs): """ Add external plugin to scheduler See https://distributed.readthedocs.io/en/latest/plugins.html """ + if isinstance(plugin, type): + plugin = plugin(self, **kwargs) + + if idempotent and any(isinstance(p, type(plugin)) for p in self.plugins): + return + self.plugins.append(plugin) def remove_plugin(self, plugin): @@ -2960,6 +2967,12 @@ def get_task_status(self, stream=None, keys=None): if key in self.tasks else None) for key in keys} + def get_task_stream(self, comm=None, start=None, stop=None, count=None): + from distributed.diagnostics.task_stream import TaskStreamPlugin + self.add_plugin(TaskStreamPlugin, idempotent=True) + ts = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] + return ts.collect(start=start, stop=stop, count=count) + ##################### # State Transitions # ##################### diff --git a/distributed/utils_test.py b/distributed/utils_test.py index d6f713cd430..1e65c4e94cb 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -765,6 +765,7 @@ def test_func(): @gen.coroutine def coro(): with dask.config.set(config): + s = False for i in range(5): try: s, ws = yield start_cluster( @@ -774,9 +775,11 @@ def coro(): except Exception as e: logger.error("Failed to start gen_cluster, retryng", exc_info=True) else: + workers[:] = ws + args = [s] + workers break - workers[:] = ws - args = [s] + workers + if s is False: + raise Exception("Could not start cluster") if client: c = yield Client(s.address, loop=loop, security=security, asynchronous=True, **client_kwargs) diff --git a/docs/source/api.rst b/docs/source/api.rst index eff1b5c3409..687b87bebb3 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -19,6 +19,7 @@ API Client.get_executor Client.get_metadata Client.get_scheduler_logs + Client.get_task_stream Client.get_worker_logs Client.has_what Client.list_datasets @@ -92,6 +93,7 @@ API wait fire_and_forget futures_of + get_task_stream Asynchronous methods From 9206f32bbc18128e7a3aaaaa8aec91ac85e43cbc Mon Sep 17 00:00:00 2001 From: Dror Birkman Date: Tue, 24 Jul 2018 16:31:08 +0300 Subject: [PATCH 0040/1550] Remove extra print (#2141) --- distributed/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index de4183138a6..a0c3028c1b0 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2077,7 +2077,6 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, retries = self._expand_retries(retries, all_keys=itertools.chain(dsk, keys)) - print(resources) keyset = set(keys) flatkeys = list(map(tokey, keys)) futures = {key: Future(key, self, inform=False) for key in keyset} From 4979df6aa4acb16c2e8c694fa9d1478d6fe48b07 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 29 Jul 2018 08:57:48 -0700 Subject: [PATCH 0041/1550] Make bokeh coloring deterministic using hash function (#2143) Previously we would assign colors based on the order of tasks arrival in the scheduler. Now we use the has so that this is consistent across sessions. 
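The gist of the change, as a rough standalone sketch (the palette is shortened here for brevity; the actual implementation in the diff below memoizes the function and uses an 18-color viridis palette):

    from hashlib import md5

    # Shortened palette for illustration only; the patch ships 18 viridis hex colors.
    palette = ['#440154', '#3C4D8A', '#23898D', '#45BF6F', '#FDE724']

    def color_of(x, palette=palette):
        # Stable hash of the stringified input mapped onto the palette, so the
        # same task-name prefix gets the same color in every session.
        n = int(md5(str(x).encode()).hexdigest()[:8], 16)
        return palette[n % len(palette)]

    assert color_of('inc') == color_of('inc')  # deterministic across calls and sessions
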
--- distributed/diagnostics/progress_stream.py | 26 +++---------------- .../diagnostics/tests/test_progress_stream.py | 4 +-- distributed/profile.py | 21 ++------------- distributed/utils.py | 21 ++++++++++++--- 4 files changed, 24 insertions(+), 48 deletions(-) diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index 84d87f4c157..60704a5670a 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -1,26 +1,20 @@ from __future__ import print_function, division, absolute_import -import itertools import logging -import random -from bokeh.palettes import viridis -from toolz import valmap, merge, memoize +from toolz import valmap, merge from tornado import gen from .progress import AllProgress from ..core import connect, coerce_to_address from ..scheduler import Scheduler -from ..utils import key_split +from ..utils import key_split, color_of from ..worker import dumps_function logger = logging.getLogger(__name__) -task_stream_palette = list(viridis(25)) -random.shuffle(task_stream_palette) - def counts(scheduler, allprogress): return merge({'all': valmap(len, allprogress.all), @@ -29,20 +23,6 @@ def counts(scheduler, allprogress): for state in ['memory', 'erred', 'released', 'processing']}) -counter = itertools.count() - -_incrementing_index_cache = dict() - - -@memoize(cache=_incrementing_index_cache) -def incrementing_index(o): - return next(counter) - - -def color_of(o, palette=task_stream_palette): - return palette[incrementing_index(o) % len(palette)] - - @gen.coroutine def progress_stream(address, interval): """ Open a TCP connection to scheduler, receive progress messages @@ -206,7 +186,7 @@ def color_of_message(msg): 'compute': ''} -def task_stream_append(lists, msg, workers, palette=task_stream_palette): +def task_stream_append(lists, msg, workers): key = msg['key'] name = key_split(key) startstops = msg.get('startstops', []) diff --git a/distributed/diagnostics/tests/test_progress_stream.py b/distributed/diagnostics/tests/test_progress_stream.py index 34bc6f2c9c1..ce21bb34193 100644 --- a/distributed/diagnostics/tests/test_progress_stream.py +++ b/distributed/diagnostics/tests/test_progress_stream.py @@ -7,7 +7,7 @@ from dask import delayed from distributed.client import wait from distributed.diagnostics.progress_stream import (progress_quads, - nbytes_bar, progress_stream, _incrementing_index_cache) + nbytes_bar, progress_stream) from distributed.utils_test import div, gen_cluster, inc @@ -18,8 +18,6 @@ def test_progress_quads(): 'released': {'inc': 1, 'dec': 0, 'add': 1}, 'processing': {'inc': 1, 'dec': 0, 'add': 2}} - _incrementing_index_cache.clear() - d = progress_quads(msg, nrows=2) color = d.pop('color') assert len(set(color)) == 3 diff --git a/distributed/profile.py b/distributed/profile.py index 3cb1c838d9b..46a4e441ebb 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -28,10 +28,8 @@ from collections import defaultdict import linecache -import itertools -import toolz -from .utils import format_time +from .utils import format_time, color_of def identifier(frame): @@ -190,7 +188,7 @@ def traverse(state, start, stop, height): try: colors.append(color_of(desc['filename'])) except IndexError: - colors.append(palette[-1]) + colors.append('gray') delta = (stop - start) / state['count'] @@ -216,18 +214,3 @@ def traverse(state, start, stop, height): 'name': names, 'time': times, 'percentage': percentages} - - -try: - from bokeh.palettes import viridis -except 
ImportError: - palette = ['red', 'green', 'blue', 'yellow'] -else: - palette = viridis(10) - -counter = itertools.count() - - -@toolz.memoize -def color_of(x): - return palette[next(counter) % len(palette)] diff --git a/distributed/utils.py b/distributed/utils.py index 666bb9dd26d..81ea0d67d92 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from datetime import timedelta import functools +from hashlib import md5 import inspect import json import logging @@ -35,7 +36,7 @@ import dask from dask import istask -from toolz import memoize +import toolz import tornado from tornado import gen from tornado.ioloop import IOLoop, PollIOLoop @@ -113,7 +114,7 @@ def get_fileno_limit(): return 512 -@memoize +@toolz.memoize def _get_ip(host, port, family, default): # By using a UDP socket, we don't actually try to connect but # simply select the local address through which *host* is reachable. @@ -670,7 +671,7 @@ def silence_logging(level, root='distributed'): return old -@memoize +@toolz.memoize def ensure_ip(hostname): """ Ensure that address is an IP address @@ -1399,3 +1400,17 @@ def has_keyword(func, keyword): if gen.is_coroutine_function(func): func = func.__wrapped__ return keyword in inspect.getargspec(func).args + + +# from bokeh.palettes import viridis +# palette = viridis(18) +palette = ['#440154', '#471669', '#472A79', '#433C84', '#3C4D8A', '#355D8C', + '#2E6C8E', '#287A8E', '#23898D', '#1E978A', '#20A585', '#2EB27C', + '#45BF6F', '#64CB5D', '#88D547', '#AFDC2E', '#D7E219', '#FDE724'] + + +@toolz.memoize +def color_of(x, palette=palette): + h = md5(str(x).encode()) + n = int(h.hexdigest()[:8], 16) + return palette[n % len(palette)] From da7eb711992b3a715f1d7a697ab825c088525a58 Mon Sep 17 00:00:00 2001 From: Ray Bell Date: Tue, 31 Jul 2018 16:58:24 -0400 Subject: [PATCH 0042/1550] DOC: typos (#2148) --- docs/source/quickstart.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 7f5157fdb25..3d1e326f528 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -14,7 +14,7 @@ See :doc:`installation ` document for more information. Setup Dask.distributed the Easy Way ----------------------------------- -If you create an client without providing an address it will start up a local +If you create a client without providing an address it will start up a local scheduler and worker for you. .. code-block:: python @@ -41,7 +41,7 @@ Set up scheduler and worker processes on your local computer:: .. note:: At least one ``dask-worker`` must be running after launching a scheduler. -Launch an Client and point it to the IP/port of the scheduler. +Launch a Client and point it to the IP/port of the scheduler. .. code-block:: python From e4827100a814f3747a3d0aa01831da5053303bef Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 1 Aug 2018 10:11:04 -0700 Subject: [PATCH 0043/1550] Cleanup intermittent failures (#2146) * make synchronous clients awaitable * clean up handling of bare ioloop in tests * test that the test suite doesn't leak threads * reuse memory monitor's psutil.Process Profiling showed that creating a new one each cycle generated a nontrivial amount of overhead * Be robust to missing data file this happens during cleanup * don't ask workers to report closed when removing them from scheduler This is unnecessary and causes delays if the scheudler is going down as well. 
* move client/worker cleanup within test function Previously we did this at decorator call time, rather than testing time * test that no new threads are created during test * allow process watching threads to leak * xfail test_quiet_client_close * Don't test threads on python 2 * cleanup bokeh test * improve reporting * improve reporting * don't check threads on windows * clean up stealing test * log del data error --- .../bokeh/tests/test_scheduler_bokeh.py | 3 +- distributed/client.py | 8 ++- distributed/scheduler.py | 4 +- distributed/tests/test_client.py | 23 +++--- distributed/tests/test_steal.py | 6 +- distributed/tests/test_worker.py | 2 +- distributed/threadpoolexecutor.py | 6 +- distributed/utils_test.py | 70 ++++++++++++------- distributed/worker.py | 9 ++- 9 files changed, 87 insertions(+), 44 deletions(-) diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index a9cc0208801..235d85d0c5c 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -258,7 +258,8 @@ def test_ProcessingHistogram(c, s, a, b): assert (ph.source.data['top'] != 0).sum() == 1 futures = c.map(slowinc, range(10), delay=0.050) - yield gen.sleep(0.100) + while not s.tasks: + yield gen.sleep(0.01) ph.update() assert ph.source.data['right'][-1] > 2 diff --git a/distributed/client.py b/distributed/client.py index a0c3028c1b0..6b5de8ae3d4 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -739,7 +739,13 @@ def start(self, **kwargs): sync(self.loop, self._start, **kwargs) def __await__(self): - return self._started.__await__() + if hasattr(self, '_started'): + return self._started.__await__() + else: + @gen.coroutine + def _(): + raise gen.Return(self) + return _().__await__() def _send_to_scheduler_safe(self, msg): if self.status in ('running', 'closing'): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3fd802b117d..a2412068c7b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1160,7 +1160,7 @@ def close_worker(self, stream=None, worker=None, safe=None): nanny_addr = self.get_worker_service_addr(worker, 'nanny') address = nanny_addr or worker - self.worker_send(worker, {'op': 'close'}) + self.worker_send(worker, {'op': 'close', 'report': False}) self.remove_worker(address=worker, safe=safe) def _setup_logging(self): @@ -1613,7 +1613,7 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): logger.info("Remove worker %s", address) if close: with ignoring(AttributeError, CommClosedError): - self.stream_comms[address].send({'op': 'close'}) + self.stream_comms[address].send({'op': 'close', 'report': False}) self.remove_resources(address) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c854e2bcb27..57b9c3aca01 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -46,7 +46,7 @@ randominc, inc, dec, div, throws, geninc, asyncinc, gen_cluster, gen_test, double, deep, popen, captured_logger, varying, map_varying, - wait_for, async_wait_for) + wait_for, async_wait_for, pristine_loop) from distributed.utils_test import loop, loop_in_thread, nodebug # noqa F401 @@ -2084,7 +2084,6 @@ def test_multi_client(s, a, b): def long_running_client_connection(address): - from distributed.utils_test import pristine_loop with pristine_loop(): c = Client(address) x = c.submit(lambda x: x + 1, 10) @@ -3425,7 +3424,7 @@ def test_get_foo_lost_keys(c, s, u, v, w): @slow 
-@gen_cluster(client=True, Worker=Nanny) +@gen_cluster(client=True, Worker=Nanny, check_new_threads=False) def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 1) with pytest.raises(KilledWorker): @@ -4128,7 +4127,7 @@ def f(x, y=0): assert len(b.data) > 2 * len(a.data) -@gen_cluster(client=True) +@gen_cluster(client=True, check_new_threads=False) def test_add_done_callback(c, s, a, b): S = set() @@ -4628,6 +4627,7 @@ def test_fire_and_forget_err(c, s, a, b): assert time() < start + 1 +@pytest.mark.xfail(reason='Other tests bleed into the logs of this one') def test_quiet_client_close(loop): with captured_logger(logging.getLogger('distributed')) as logger: with Client(loop=loop, processes=False, threads_per_worker=4) as c: @@ -5103,12 +5103,10 @@ def test_future_auto_inform(c, s, a, b): def test_client_async_before_loop_starts(): - loop = IOLoop() - client = Client(asynchronous=True, loop=loop) - assert client.asynchronous - client.close() - # Avoid long wait for cluster close at shutdown - loop.close() + with pristine_loop() as loop: + client = Client(asynchronous=True, loop=loop) + assert client.asynchronous + client.close() @slow @@ -5433,5 +5431,10 @@ def bad_fn(x): assert y.status == 'error' # not cancelled +def test_no_threads_lingering(): + active = dict(threading._active) + assert threading.active_count() < 30, list(active.values()) + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3309efc7240..01193b61e20 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -330,9 +330,11 @@ def test_steal_when_more_tasks(c, s, a, *rest): futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(20)] - yield gen.sleep(0.1) - assert any(w.task_state for w in rest) + start = time() + while not any(w.task_state for w in rest): + yield gen.sleep(0.01) + assert time() < start + 1 @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 1210d5213d0..8c6f0ee752b 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -181,7 +181,7 @@ def g(): assert not os.path.exists(os.path.join(a.local_dir, 'foobar.py')) -@pytest.mark.xfail(reason="don't yet support uploading pyc files") +@pytest.mark.skip(reason="don't yet support uploading pyc files") @gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) def test_upload_file_pyc(c, s, w): with tmpfile() as dirname: diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index 8e9f10cadc8..c5c953ce0b9 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -23,6 +23,7 @@ from __future__ import print_function, division, absolute_import from . 
import _concurrent_futures_thread as thread +from .compatibility import Empty import os import logging import threading @@ -48,7 +49,10 @@ def _worker(executor, work_queue): executor._threads.remove(threading.current_thread()) rejoin_event.set() break - task = work_queue.get() + try: + task = work_queue.get(timeout=1) + except Empty: + continue if task is not None: # sentinel task.run() del task diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 1e65c4e94cb..be4b241a9b7 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -39,8 +39,8 @@ from tornado.gen import TimeoutError from tornado.ioloop import IOLoop -from .client import default_client -from .compatibility import PY3, iscoroutinefunction, Empty +from .client import default_client, _global_clients +from .compatibility import PY3, iscoroutinefunction, Empty, WINDOWS from .config import initialize_logging from .core import connect, rpc, CommClosedError from .metrics import time @@ -97,6 +97,7 @@ def cleanup_global_workers(): @pytest.fixture def loop(): del _global_workers[:] + _global_clients.clear() with pristine_loop() as loop: # Monkey-patch IOLoop.start to wait for loop stop orig_start = loop.start @@ -125,6 +126,7 @@ def start(): else: is_stopped.wait() del _global_workers[:] + _global_clients.clear() @pytest.fixture @@ -727,7 +729,7 @@ def gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 2)], scheduler='127.0.0.1', timeout=10, security=None, Worker=Worker, client=False, scheduler_kwargs={}, worker_kwargs={}, client_kwargs={}, active_rpc_timeout=1, - config={}): + config={}, check_new_threads=True): from distributed import Client """ Coroutine test with small cluster @@ -739,11 +741,6 @@ def test_foo(scheduler, worker1, worker2): start end """ - del _global_workers[:] - - reset_config() - - dask.config.set({'distributed.comm.timeouts.connect': '5s'}) worker_kwargs = merge({'memory_limit': TOTAL_MEMORY, 'death_timeout': 5}, worker_kwargs) @@ -752,6 +749,13 @@ def _(func): func = gen.coroutine(func) def test_func(): + del _global_workers[:] + _global_clients.clear() + active_threads_start = set(threading._active) + + reset_config() + + dask.config.set({'distributed.comm.timeouts.connect': '5s'}) # Restore default logging levels # XXX use pytest hooks/fixtures instead? 
for name, level in logging_levels.items(): @@ -810,22 +814,40 @@ def coro(): result = loop.run_sync(coro, timeout=timeout * 2 if timeout else timeout) - for w in workers: - if getattr(w, 'data', None): - try: - w.data.clear() - except EnvironmentError: - # zict backends can fail if their storage directory - # was already removed - pass - del w.data - DequeHandler.clear_all_instances() - for w in _global_workers: - w = w() - w._close(report=False, executor_wait=False) - if w.status == 'running': - w.close() - del _global_workers[:] + for w in workers: + if getattr(w, 'data', None): + try: + w.data.clear() + except EnvironmentError: + # zict backends can fail if their storage directory + # was already removed + pass + del w.data + DequeHandler.clear_all_instances() + for w in _global_workers: + w = w() + w._close(report=False, executor_wait=False) + if w.status == 'running': + w.close() + del _global_workers[:] + + if PY3 and not WINDOWS and check_new_threads: + start = time() + while True: + bad = [t for t, v in threading._active.items() + if t not in active_threads_start and + "Threaded" not in v.name and + "watch message queue" not in v.name] + if not bad: + break + else: + sleep(0.01) + if time() > start + 2: + from distributed import profile + tid = bad[0] + thread = threading._active[tid] + call_stacks = profile.call_stack(sys._current_frames()[tid]) + assert False, (thread, call_stacks) return result return test_func diff --git a/distributed/worker.py b/distributed/worker.py index feb67ce5a92..891b413c302 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -478,6 +478,7 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): self.scheduler.unregister(address=self.contact_address)) self.scheduler.close_rpc() if isinstance(self.executor, ThreadPoolExecutor): + self.executor._work_queue.queue.clear() self.executor.shutdown(wait=executor_wait, timeout=timeout) else: self.executor.shutdown(wait=False) @@ -2004,7 +2005,11 @@ def release_key(self, key, cause=None, reason=None, report=True): self.log.append((key, 'release-key')) del self.tasks[key] if key in self.data and key not in self.dep_state: - del self.data[key] + try: + del self.data[key] + except FileNotFoundError: + logger.error("Tried to delete %s but no file found", + exc_info=True) del self.nbytes[key] del self.types[key] @@ -2266,7 +2271,7 @@ def memory_monitor(self): self._memory_monitoring = True total = 0 - proc = psutil.Process() + proc = self.monitor.proc memory = proc.memory_info().rss frac = memory / self.memory_limit From 444e7bb8eb7fb258628ff9bdeecca9babf9e71fe Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 1 Aug 2018 13:14:18 -0700 Subject: [PATCH 0044/1550] Include serializers in Scheduler.gather calls (#2151) --- distributed/tests/test_client.py | 74 +++++++++++++++++--------------- distributed/utils_comm.py | 1 + distributed/worker.py | 12 ++++-- 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 57b9c3aca01..6ab903ffdee 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5317,41 +5317,45 @@ def test_client_active_bad_port(): http_server.stop() -@gen_cluster() -def test_turn_off_pickle(s, a, b): - import numpy as np - c = yield Client(s.address, asynchronous=True, - serializers=['dask', 'msgpack']) - try: - assert (yield c.submit(inc, 1)) == 2 - yield c.submit(np.ones, 5) - yield c.scatter(1) - - # Can't send complex data - with 
pytest.raises(TypeError): - future = yield c.scatter(inc) - - # can send complex tasks (this uses pickle regardless) - future = c.submit(lambda x: x, inc) - yield wait(future) - - # but can't receive complex results - with pytest.raises(TypeError): - yield future - - # Run works - result = yield c.run(lambda: 1) - assert list(result.values()) == [1, 1] - result = yield c.run_on_scheduler(lambda: 1) - assert result == 1 - - # But not with complex return values - with pytest.raises(TypeError): - yield c.run(lambda: inc) - with pytest.raises(TypeError): - yield c.run_on_scheduler(lambda: inc) - finally: - yield c._close() +@pytest.mark.parametrize('direct', [True, False]) +def test_turn_off_pickle(direct): + @gen_cluster() + def test(s, a, b): + import numpy as np + c = yield Client(s.address, asynchronous=True, + serializers=['dask', 'msgpack']) + try: + assert (yield c.submit(inc, 1)) == 2 + yield c.submit(np.ones, 5) + yield c.scatter(1) + + # Can't send complex data + with pytest.raises(TypeError): + future = yield c.scatter(inc) + + # can send complex tasks (this uses pickle regardless) + future = c.submit(lambda x: x, inc) + yield wait(future) + + # but can't receive complex results + with pytest.raises(TypeError): + yield c.gather(future, direct=direct) + + # Run works + result = yield c.run(lambda: 1) + assert list(result.values()) == [1, 1] + result = yield c.run_on_scheduler(lambda: 1) + assert result == 1 + + # But not with complex return values + with pytest.raises(TypeError): + yield c.run(lambda: inc) + with pytest.raises(TypeError): + yield c.run_on_scheduler(lambda: inc) + finally: + yield c._close() + + test() @gen_cluster() diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 46724973996..7e8702e40a2 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -58,6 +58,7 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): try: coroutines = {address: get_data_from_worker(rpc, keys, address, who=who, + serializers=serializers, max_connections=False) for address, keys in d.items()} response = {} diff --git a/distributed/worker.py b/distributed/worker.py index 891b413c302..48e4bb8a2e8 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2781,7 +2781,8 @@ def parse_memory_limit(memory_limit, ncores): @gen.coroutine -def get_data_from_worker(rpc, keys, worker, who=None, max_connections=None): +def get_data_from_worker(rpc, keys, worker, who=None, max_connections=None, + serializers=None, deserializers=None): """ Get keys from worker The worker has a two step handshake to acknowledge when data has been fully @@ -2793,11 +2794,16 @@ def get_data_from_worker(rpc, keys, worker, who=None, max_connections=None): Worker.gather_deps utils_comm.gather_data_from_workers """ + if serializers is None: + serializers = rpc.serializers + if deserializers is None: + deserializers = rpc.deserializers + comm = yield rpc.connect(worker) try: response = yield send_recv(comm, - serializers=rpc.serializers, - deserializers=rpc.deserializers, + serializers=serializers, + deserializers=deserializers, op='get_data', keys=keys, who=who, max_connections=max_connections) if response['status'] == 'OK': From 9919543ca62f1c7ac80c97559bfa2b245899b634 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 2 Aug 2018 15:55:39 +0200 Subject: [PATCH 0045/1550] Fix msgpack PendingDeprecationWarning for encoding='utf-8' (#2153) --- distributed/protocol/core.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git 
a/distributed/protocol/core.py b/distributed/protocol/core.py index 4033c9be1a9..f7df6597752 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -19,6 +19,14 @@ _deserialize = deserialize +try: + msgpack.loads(msgpack.dumps(''), raw=False) + msgpack_raw_false = {'raw': False} +except TypeError: + # Backward compat with old msgpack (prior to 0.5.2) + msgpack_raw_false = {'encoding': 'utf-8'} + + logger = logging.getLogger(__name__) @@ -102,7 +110,7 @@ def loads(frames, deserialize=True, deserializers=None): return msg header = frames.pop() - header = msgpack.loads(header, encoding='utf8', use_list=False) + header = msgpack.loads(header, use_list=False, **msgpack_raw_false) keys = header['keys'] headers = header['headers'] bytestrings = set(header['bytestrings']) @@ -174,7 +182,7 @@ def loads_msgpack(header, payload): dumps_msgpack """ if header: - header = msgpack.loads(header, encoding='utf8', use_list=False) + header = msgpack.loads(header, use_list=False, **msgpack_raw_false) else: header = {} @@ -186,4 +194,4 @@ def loads_msgpack(header, payload): raise ValueError("Data is compressed as %s but we don't have this" " installed" % str(header['compression'])) - return msgpack.loads(payload, encoding='utf8', use_list=False) + return msgpack.loads(payload, use_list=False, **msgpack_raw_false) From 0c9ee7328656b175bdfe4f310c5ceeb7897b692a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 2 Aug 2018 16:40:02 -0400 Subject: [PATCH 0046/1550] Cleanup recent cleanup PR (#2152) * Cleanup recent cleanup PR Some fallout from the recent intermittent testing PR * extend timeout * xfail test_diskutils Raised at https://github.com/dask/distributed/issues/2155 --- distributed/cli/tests/test_dask_scheduler.py | 2 +- distributed/tests/test_diskutils.py | 5 ++++- distributed/utils_test.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 88507429cdd..456f44c55c4 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -30,7 +30,7 @@ def test_defaults(loop): def f(): # Default behaviour is to listen on all addresses yield [ - assert_can_connect_from_everywhere_4_6(8786, 2.0), # main port + assert_can_connect_from_everywhere_4_6(8786, 5.0), # main port ] loop.run_sync(f) diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index 598c2506b0d..d7079ca2039 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -9,9 +9,10 @@ from time import sleep import mock +import pytest import dask -from distributed.compatibility import Empty +from distributed.compatibility import Empty, WINDOWS from distributed.diskutils import WorkSpace from distributed.metrics import time from distributed.utils import mp_context @@ -257,6 +258,8 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): def test_workspace_concurrency(tmpdir): + if WINDOWS: + raise pytest.xfail.Exception('TODO: unknown failure on windows') _test_workspace_concurrency(tmpdir, 2.0, 6) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index be4b241a9b7..cce9ad75dda 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -837,7 +837,7 @@ def coro(): bad = [t for t, v in threading._active.items() if t not in active_threads_start and "Threaded" not in v.name and - "watch message queue" not in v.name] + "watch message" not in v.name] if not bad: 
break else: From 770b4afa51bbc8e5fbe77d6f6b5d2fbe1f69a7bc Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 3 Aug 2018 10:07:15 -0400 Subject: [PATCH 0047/1550] Worker class (#2147) * Start: make worker class an attribute of Nanny * Make Nanny subclass test * as keyword argument * fix tests --- distributed/nanny.py | 17 +++++++++-------- distributed/tests/test_nanny.py | 22 +++++++++++++++++++++- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 6b0c0ec9620..c973985db78 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -23,8 +23,7 @@ from .security import Security from .utils import (get_ip, mp_context, silence_logging, json_load_robust, PeriodicCallback) -from .worker import _ncores, run, parse_memory_limit - +from .worker import _ncores, run, parse_memory_limit, Worker logger = logging.getLogger(__name__) @@ -44,7 +43,8 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, name=None, memory_limit='auto', reconnect=True, validate=False, quiet=False, resources=None, silence_logs=None, death_timeout=None, preload=(), preload_argv=[], security=None, - contact_address=None, listen_address=None, **kwargs): + contact_address=None, listen_address=None, worker_class=None, + **kwargs): if scheduler_file: cfg = json_load_robust(scheduler_file) self.scheduler_addr = cfg['address'] @@ -62,6 +62,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.death_timeout = death_timeout self.preload = preload self.preload_argv = preload_argv + self.Worker = Worker if worker_class is None else worker_class self.contact_address = contact_address self.memory_terminate_fraction = dask.config.get('distributed.worker.memory.terminate') @@ -214,6 +215,7 @@ def instantiate(self, comm=None): worker_start_args=(start_arg,), silence_logs=self.silence_logs, on_exit=self._on_exit, + worker=self.Worker ) self.auto_restart = True @@ -320,7 +322,7 @@ def _close(self, comm=None, timeout=5, report=None): class WorkerProcess(object): def __init__(self, worker_args, worker_kwargs, worker_start_args, - silence_logs, on_exit): + silence_logs, on_exit, worker): self.status = 'init' self.silence_logs = silence_logs self.worker_args = worker_args @@ -328,6 +330,7 @@ def __init__(self, worker_args, worker_kwargs, worker_start_args, self.worker_start_args = worker_start_args self.on_exit = on_exit self.process = None + self.Worker = worker # Initialized when worker is ready self.worker_dir = None @@ -357,7 +360,7 @@ def start(self): silence_logs=self.silence_logs, init_result_q=self.init_result_q, child_stop_q=self.child_stop_q, - uid=uid), + uid=uid, Worker=self.Worker), ) self.process.daemon = True self.process.set_exit_callback(self._on_exit) @@ -485,9 +488,7 @@ def _wait_until_connected(self, uid): @classmethod def _run(cls, worker_args, worker_kwargs, worker_start_args, - silence_logs, init_result_q, child_stop_q, uid): # pragma: no cover - from distributed import Worker - + silence_logs, init_result_q, child_stop_q, uid, Worker): # pragma: no cover try: from dask.multiprocessing import initialize_worker_process except ImportError: # old Dask version diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 7f0e703b6c4..9a9feabf114 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -13,7 +13,7 @@ from tornado import gen import dask -from distributed import Nanny, rpc, Scheduler +from distributed import Nanny, rpc, Scheduler, Worker from distributed.core import 
CommClosedError from distributed.metrics import time from distributed.protocol.pickle import dumps @@ -145,6 +145,26 @@ def test_close_on_disconnect(s, w): assert time() < start + 9 +class Something(Worker): + # a subclass of Worker which is not Worker + pass + + +@gen_cluster(client=True, Worker=Nanny) +def test_nanny_worker_class(c, s, w1, w2): + out = yield c._run(lambda dask_worker=None: str(dask_worker.__class__)) + assert 'Worker' in list(out.values())[0] + assert w1.Worker is Worker + + +@gen_cluster(client=True, Worker=Nanny, + worker_kwargs={'worker_class': Something}) +def test_nanny_alt_worker_class(c, s, w1, w2): + out = yield c._run(lambda dask_worker=None: str(dask_worker.__class__)) + assert 'Something' in list(out.values())[0] + assert w1.Worker is Something + + @slow @gen_cluster(client=False, ncores=[]) def test_nanny_death_timeout(s): From 4bb66acae20f0939e221f441b9bf00410358a79b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 Aug 2018 18:56:52 -0400 Subject: [PATCH 0048/1550] Support lack of PollIOLoop in Tornado --- distributed/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 81ea0d67d92..9fb7480290f 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -39,7 +39,11 @@ import toolz import tornado from tornado import gen -from tornado.ioloop import IOLoop, PollIOLoop +from tornado.ioloop import IOLoop +try: + from tornado.ioloop import PollIOLoop +except ImportError: + PollIOLoop = None # dropped in tornado 6.0 from .compatibility import Queue, PY3, PY2, get_thread_identity, unicode from .metrics import time @@ -234,7 +238,7 @@ def sync(loop, func, *args, **kwargs): Run coroutine in loop running in separate thread. """ # Tornado's PollIOLoop doesn't raise when using closed, do it ourselves - if ((isinstance(loop, PollIOLoop) and getattr(loop, '_closing', False)) or + if PollIOLoop and ((isinstance(loop, PollIOLoop) and getattr(loop, '_closing', False)) or (hasattr(loop, 'asyncio_loop') and loop.asyncio_loop._closed)): raise RuntimeError("IOLoop is closed") From 3b8f67498adb9cf2965a8b1163b5bbeaa8af4964 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 Aug 2018 12:50:24 -0400 Subject: [PATCH 0049/1550] bump version to 1.22.1 --- distributed/tests/test_worker.py | 2 ++ distributed/utils.py | 2 +- docs/source/changelog.rst | 48 ++++++++++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 8c6f0ee752b..77730fac85c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1182,6 +1182,8 @@ def test_wait_for_outgoing(c, s, a, b): assert 1 / 3 < ratio < 3 +@pytest.mark.skipif(not sys.platform.startswith('linux'), + reason="Need 127.0.0.2 to mean localhost") @gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 1), ('127.0.0.2', 1)], client=True) def test_prefer_gather_from_local_address(c, s, w1, w2, w3): diff --git a/distributed/utils.py b/distributed/utils.py index 9fb7480290f..6257a9ba83b 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -239,7 +239,7 @@ def sync(loop, func, *args, **kwargs): """ # Tornado's PollIOLoop doesn't raise when using closed, do it ourselves if PollIOLoop and ((isinstance(loop, PollIOLoop) and getattr(loop, '_closing', False)) or - (hasattr(loop, 'asyncio_loop') and loop.asyncio_loop._closed)): + (hasattr(loop, 'asyncio_loop') and loop.asyncio_loop._closed)): raise RuntimeError("IOLoop 
is closed") timeout = kwargs.pop('callback_timeout', None) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 5dd96ede6e2..13e0f537294 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,10 +1,48 @@ Changelog ========= -1.21.1 - 2018-XX-XX +X.XX.X - 2018-XX-XX ------------------- -- + +1.22.1 - 2018-08-03 +------------------- + +- Add worker_class= keyword to Nanny to support different worker types (:pr:`2147`) `Martin Durant`_ +- Cleanup intermittent worker failures (:pr:`2152`) (:pr:`2146`) `Matthew Rocklin`_ +- Fix msgpack PendingDeprecationWarning for encoding='utf-8' (:pr:`2153`) `Olivier Grisel`_ +- Make bokeh coloring deterministic using hash function (:pr:`2143`) `Matthew Rocklin`_ +- Allow client to query the task stream plot (:pr:`2122`) `Matthew Rocklin`_ +- Use PID and counter in thread names (:pr:`2084`) (:pr:`2128`) `Dror Birkman`_ +- Test that worker restrictions are cleared after cancellation (:pr:`2107`) `Matthew Rocklin`_ +- Expand resources in graph_to_futures (:pr:`2131`) `Matthew Rocklin`_ +- Add custom serialization support for pyarrow (:pr:`2115`) `Dave Hirschfeld`_ +- Update dask-scheduler cli help text for preload (:pr:`2120`) `Matt Nicolls`_ +- Added another nested parallelism test (:pr:`1710`) `Tom Augspurger`_ +- insert newline by default after TextProgressBar (:pr:`1976`) `Phil Tooley`_ +- Retire workers from scale (:pr:`2104`) `Matthew Rocklin`_ +- Allow worker to refuse data requests with busy signal (:pr:`2092`) `Matthew Rocklin`_ +- Don't forget released keys (:pr:`2098`) `Matthew Rocklin`_ +- Update example for stopping a worker (:pr:`2088`) `John A Kirkham`_ +- removed hardcoded value of memory terminate fraction from a log message (:pr:`2096`) `Bartosz Marcinkowski`_ +- Adjust worker doc after change in config file location and treatment (:pr:`2094`) `Aurélien Ponte`_ +- Prefer gathering data from same host (:pr:`2090`) `Matthew Rocklin`_ +- Handle exceptions on deserialized comm with text error (:pr:`2093`) `Matthew Rocklin`_ +- Fix typo in docstring (:pr:`2087`) `Loïc Estève`_ +- Provide communication context to serialization functions (:pr:`2054`) `Matthew Rocklin`_ +- Allow `name` to be explicitly passed in publish_dataset (:pr:`1995`) `Marius van Niekerk`_ +- Avoid accessing Worker.scheduler_delay around yield point (:pr:`2074`) `Matthew Rocklin`_ +- Support TB and PB in format bytes (:pr:`2072`) `Matthew Rocklin`_ +- Add test for as_completed for loops in Python 2 (:pr:`2071`) `Matthew Rocklin`_ +- Allow adaptive to exist without a cluster (:pr:`2064`) `Matthew Rocklin`_ +- Have worker data transfer wait until recipient acknowledges (:pr:`2052`) `Matthew Rocklin`_ +- Support async def functions in Client.sync (:pr:`2070`) `Matthew Rocklin`_ +- Add asynchronous parameter to docstring of LocalCluster `Matthew Rocklin`_ +- Normalize address before comparison (:pr:`2066`) `Tom Augspurger`_ +- Use ConnectionPool for Worker.scheduler `Matthew Rocklin`_ +- Avoid reference cycle in str_graph `Matthew Rocklin`_ +- Pull data outside of while loop in gather (:pr:`2059`) `Matthew Rocklin`_ + 1.22.0 - 2018-06-14 ------------------- @@ -683,3 +721,9 @@ significantly without many new features. .. _`@bmaisson`: https://github.com/bmaisson .. _`Martin Durant`: https://github.com/martindurant .. _`Grant Jenks`: https://github.com/grantjenks +.. _`Dror Birkman`: https://github.com/Dror-LightCyber +.. _`Dave Hirschfeld`: https://github.com/dhirschfeld +.. _`Matt Nicolls`: https://github.com/nicolls1 +.. 
_`Phil Tooley`: https://github.com/ptooley +.. _`Bartosz Marcinkowski`: https://github.com/bm371613 +.. _`Aurélien Ponte`: https://github.com/apatlpo From 257b4402275911a0498576aec105dc3e3fa5714c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 19 Jul 2018 09:05:14 -0400 Subject: [PATCH 0050/1550] add direct_to_workers to Client --- distributed/client.py | 10 +++++++++- distributed/tests/test_client.py | 22 ++++++++++++++++++++++ distributed/worker.py | 1 + 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 6b5de8ae3d4..25dd2925538 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -482,6 +482,9 @@ class resembles executors in ``concurrent.futures`` but also allows name: string (optional) Gives the client a name that will be included in logs generated on the scheduler for matters relating to this client + direct_to_workers: bool (optional) + Can this client connect directly to workers or should it proxy through + the scheduler? heartbeat_interval: int Time in milliseconds between heartbeats to scheduler @@ -514,7 +517,7 @@ def __init__(self, address=None, loop=None, timeout=no_default, security=None, asynchronous=False, name=None, heartbeat_interval=None, serializers=None, deserializers=None, - extensions=DEFAULT_EXTENSIONS, + extensions=DEFAULT_EXTENSIONS, direct_to_workers=False, **kwargs): if timeout == no_default: timeout = dask.config.get('distributed.comm.timeouts.connect') @@ -544,6 +547,7 @@ def __init__(self, address=None, loop=None, timeout=no_default, if deserializers is None: deserializers = serializers self._deserializers = deserializers + self.direct_to_workers = direct_to_workers # Communication self.security = security or Security() @@ -1408,6 +1412,8 @@ def _gather(self, futures, errors='raise', direct=None, local_worker=None): else: if w.scheduler.address == self.scheduler.address: direct = True + if direct is None: + direct = self.direct_to_workers @gen.coroutine def wait(k): @@ -1610,6 +1616,8 @@ def _scatter(self, data, workers=None, broadcast=False, direct=None, else: if w.scheduler.address == self.scheduler.address: direct = True + if direct is None: + direct = self.direct_to_workers if local_worker: # running within task local_worker.update_data(data=data, report=False) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 6ab903ffdee..83ce26ec734 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5440,5 +5440,27 @@ def test_no_threads_lingering(): assert threading.active_count() < 30, list(active.values()) +@gen_cluster() +def test_direct_async(s, a, b): + c = yield Client(s.address, asynchronous=True, direct_to_workers=True) + assert c.direct_to_workers + yield c.close() + + c = yield Client(s.address, asynchronous=True, direct_to_workers=False) + assert not c.direct_to_workers + yield c.close() + + +def test_direct_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + assert not c.direct_to_workers + + def f(): + return get_client().direct_to_workers + + assert c.submit(f).result() + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 diff --git a/distributed/worker.py b/distributed/worker.py index 48e4bb8a2e8..ee0a6628046 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2605,6 +2605,7 @@ def _get_client(self, timeout=3): security=self.security, set_as_default=True, asynchronous=asynchronous, + 
direct_to_workers=True, name='worker', timeout=timeout) if not asynchronous: From 0fc888cfa63920c8782423f905a713de96ed4324 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 19 Jul 2018 09:10:24 -0400 Subject: [PATCH 0051/1550] add Scheduler.proxy to workers --- distributed/core.py | 2 +- distributed/scheduler.py | 8 ++++++++ distributed/tests/test_client.py | 6 ++++++ distributed/worker.py | 2 +- 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 05b901cf818..56360c7ce4d 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -233,7 +233,7 @@ def port(self): _, self._port = get_address_host_port(self.address) return self._port - def identity(self, comm): + def identity(self, comm=None): return {'type': type(self).__name__, 'id': self.id} def listen(self, port_or_addr=None, listen_args=None): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a2412068c7b..4ef146d06c6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -903,6 +903,7 @@ def __init__( 'feed': self.feed, 'terminate': self.close, 'broadcast': self.broadcast, + 'proxy': self.proxy, 'ncores': self.get_ncores, 'has_what': self.get_has_what, 'who_has': self.get_who_has, @@ -2323,6 +2324,13 @@ def send_message(addr): raise Return(dict(zip(workers, results))) + @gen.coroutine + def proxy(self, comm=None, msg=None, worker=None, serializers=None): + """ Proxy a communication through the scheduler to some other worker """ + d = yield self.broadcast(comm=comm, msg=msg, workers=[worker], + serializers=serializers) + raise gen.Return(d[worker]) + @gen.coroutine def rebalance(self, comm=None, keys=None, workers=None): """ Rebalance keys so that each worker stores roughly equal bytes diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 83ce26ec734..0e099ccb678 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2200,6 +2200,12 @@ def test_broadcast(loop): b['address']: {x.key, y.key}} +@gen_cluster(client=True) +def test_proxy(c, s, a, b): + msg = yield c.scheduler.proxy(msg={'op': 'identity'}, worker=a.address) + assert msg['id'] == a.identity()['id'] + + @gen_cluster(client=True) def test__cancel(c, s, a, b): x = c.submit(slowinc, 1) diff --git a/distributed/worker.py b/distributed/worker.py index ee0a6628046..6d98ebc4659 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -450,7 +450,7 @@ def _start(self, addr_or_port=0): def start(self, port=0): self.loop.add_callback(self._start, port) - def identity(self, comm): + def identity(self, comm=None): return {'type': type(self).__name__, 'id': self.id, 'scheduler': self.scheduler.address, From b16ee25506fd20ec5daa7d77e338e2725e778562 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 18 Jul 2018 15:02:25 -0400 Subject: [PATCH 0052/1550] Implement Actors --- distributed/__init__.py | 1 + distributed/actor.py | 209 +++++++++++++ distributed/client.py | 45 ++- distributed/core.py | 12 +- distributed/scheduler.py | 76 ++++- distributed/tests/test_actor.py | 521 +++++++++++++++++++++++++++++++ distributed/tests/test_worker.py | 15 - distributed/utils.py | 8 + distributed/utils_test.py | 11 +- distributed/worker.py | 150 +++++++-- docs/source/actors.rst | 235 ++++++++++++++ docs/source/index.rst | 1 + 12 files changed, 1198 insertions(+), 86 deletions(-) create mode 100644 distributed/actor.py create mode 100644 distributed/tests/test_actor.py create mode 100644 docs/source/actors.rst diff --git 
a/distributed/__init__.py b/distributed/__init__.py index 71e71a79143..ac324592dd2 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -2,6 +2,7 @@ from . import config from dask.config import config +from .actor import Actor, ActorFuture from .core import connect, rpc from .deploy import LocalCluster, Adaptive from .diagnostics import progress diff --git a/distributed/actor.py b/distributed/actor.py new file mode 100644 index 00000000000..b97c79a6041 --- /dev/null +++ b/distributed/actor.py @@ -0,0 +1,209 @@ +from tornado import gen +import functools + +from .client import Future, default_client +from .compatibility import get_thread_identity, Queue +from .protocol import to_serialize +from .utils import sync +from .utils_comm import WrappedKey +from .worker import get_worker + + +class Actor(WrappedKey): + """ Controls an object on a remote worker + + An actor allows remote control of a stateful object living on a remote + worker. Method calls on this object trigger operations on the remote + object and return ActorFutures on which we can block to get results. + + Examples + -------- + >>> class Counter: + ... def __init__(self): + ... self.n = 0 + ... def increment(self): + ... self.n += 1 + ... return self.n + + >>> from dask.distributed import Client + >>> client = Client() + + You can create an actor by submitting a class with the keyword + ``actor=True``. + + >>> future = client.submit(Counter, actor=True) + >>> counter = future.result() + >>> counter + + + Calling methods on this object immediately returns deferred ``ActorFuture`` + objects. You can call ``.result()`` on these objects to block and get the + result of the function call. + + >>> future = counter.increment() + >>> future.result() + 1 + >>> future = counter.increment() + >>> future.result() + 2 + """ + def __init__(self, cls, address, key, worker=None): + self._cls = cls + self._address = address + self.key = key + self._future = None + if worker: + self._worker = worker + self._client = None + else: + try: + self._worker = get_worker() + except ValueError: + self._worker = None + try: + self._client = default_client() + self._future = Future(key) + except ValueError: + self._client = None + + def __repr__(self): + return '' % (self._cls.__name__, self.key) + + def __reduce__(self): + return (Actor, (self._cls, self._address, self.key)) + + @property + def _io_loop(self): + if self._worker: + return self._worker.io_loop + else: + return self._client.io_loop + + @property + def _scheduler_rpc(self): + if self._worker: + return self._worker.scheduler + else: + return self._client.scheduler + + @property + def _worker_rpc(self): + if self._worker: + return self._worker.rpc(self._address) + else: + if self._client.direct_to_workers: + return self._client.rpc(self._address) + else: + return ProxyRPC(self._client.scheduler, self._address) + + @property + def _asynchronous(self): + if self._client: + return self._client.asynchronous + else: + return get_thread_identity() == self._worker.thread_id + + def _sync(self, func, *args, **kwargs): + if self._client: + return self._client.sync(func, *args, **kwargs) + else: + # TODO support sync operation by checking against thread ident of loop + return sync(self._worker.loop, func, *args, **kwargs) + + def __dir__(self): + o = set(dir(type(self))) + o.update(attr for attr in dir(self._cls) if not attr.startswith('_')) + return sorted(o) + + def __getattr__(self, key): + attr = getattr(self._cls, key) + + if self._future and not self._future.status == 'finished': + 
raise ValueError("Worker holding Actor was lost") + + if callable(attr): + @functools.wraps(attr) + def func(*args, **kwargs): + @gen.coroutine + def run_actor_function_on_worker(): + try: + result = yield self._worker_rpc.actor_execute( + function=key, + actor=self.key, + args=[to_serialize(arg) for arg in args], + kwargs={k: to_serialize(v) for k, v in kwargs.items()}, + ) + except OSError: + if self._future: + yield self._future + else: + raise OSError("Unable to contact Actor's worker") + raise gen.Return(result['result']) + + if self._asynchronous: + return run_actor_function_on_worker() + else: + # TODO: this mechanism is error prone + # we should endeavor to make dask's standard code work here + q = Queue() + + @gen.coroutine + def wait_then_add_to_queue(): + x = yield run_actor_function_on_worker() + q.put(x) + self._io_loop.add_callback(wait_then_add_to_queue) + + return ActorFuture(q, self._io_loop) + return func + + else: + @gen.coroutine + def get_actor_attribute_from_worker(): + x = yield self._worker_rpc.actor_attribute(attribute=key, actor=self.key) + raise gen.Return(x['result']) + + return self._sync(get_actor_attribute_from_worker) + + +class ProxyRPC(object): + """ + An rpc-like object that uses the scheduler's rpc to connect to a worker + """ + def __init__(self, rpc, address): + self.rpc = rpc + self._address = address + + def __getattr__(self, key): + @gen.coroutine + def func(**msg): + msg['op'] = key + result = yield self.rpc.proxy(worker=self._address, msg=msg) + raise gen.Return(result) + + return func + + +class ActorFuture(object): + """ Future to an actor's method call + + Whenever you call a method on an Actor you get an ActorFuture immediately + while the computation happens in the background. You can call ``.result`` + to block and collect the full result + + See Also + -------- + Actor + """ + def __init__(self, q, io_loop): + self.q = q + self.io_loop = io_loop + + def result(self, timeout=None): + try: + return self._cached_result + except AttributeError: + self._cached_result = self.q.get(timeout=timeout) + return self._cached_result + + def __repr__(self): + return '' diff --git a/distributed/client.py b/distributed/client.py index 25dd2925538..f136b91bcd6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -401,19 +401,7 @@ def lose(self): self._get_event().clear() def set_error(self, exception, traceback): - if isinstance(exception, bytes): - try: - exception = loads(exception) - except TypeError: - exception = Exception("Undeserializable exception", exception) - if traceback: - if isinstance(traceback, bytes): - try: - traceback = loads(traceback) - except (TypeError, AttributeError): - traceback = None - else: - traceback = None + _, exception, traceback = clean_exception(exception, traceback) self.status = 'error' self.exception = exception @@ -1201,13 +1189,14 @@ def submit(self, func, *args, **kwargs): raise TypeError("First input to submit must be a callable function") key = kwargs.pop('key', None) - pure = kwargs.pop('pure', True) workers = kwargs.pop('workers', None) resources = kwargs.pop('resources', None) retries = kwargs.pop('retries', None) priority = kwargs.pop('priority', 0) fifo_timeout = kwargs.pop('fifo_timeout', '100ms') allow_other_workers = kwargs.pop('allow_other_workers', False) + actor = kwargs.pop('actor', kwargs.pop('actors', False)) + pure = kwargs.pop('pure', not actor) if allow_other_workers not in (True, False, None): raise TypeError("allow_other_workers= must be True or False") @@ -1246,7 +1235,8 @@ def 
submit(self, func, *args, **kwargs): user_priority=priority, resources={skey: resources} if resources else None, retries=retries, - fifo_timeout=fifo_timeout) + fifo_timeout=fifo_timeout, + actors=actor) logger.debug("Submit %s(...), %s", funcname(func), key) @@ -1328,13 +1318,14 @@ def map(self, func, *iterables, **kwargs): key = kwargs.pop('key', None) key = key or funcname(func) - pure = kwargs.pop('pure', True) workers = kwargs.pop('workers', None) retries = kwargs.pop('retries', None) resources = kwargs.pop('resources', None) user_priority = kwargs.pop('priority', 0) allow_other_workers = kwargs.pop('allow_other_workers', False) fifo_timeout = kwargs.pop('fifo_timeout', '100ms') + actor = kwargs.pop('actor', kwargs.pop('actors', False)) + pure = kwargs.pop('pure', not actor) if allow_other_workers and workers is None: raise ValueError("Only use allow_other_workers= if using workers=") @@ -1392,7 +1383,8 @@ def map(self, func, *iterables, **kwargs): resources=resources, retries=retries, user_priority=user_priority, - fifo_timeout=fifo_timeout) + fifo_timeout=fifo_timeout, + actors=actor) logger.debug("map(%s, ...)", funcname(func)) return [futures[tokey(k)] for k in keys] @@ -2081,7 +2073,7 @@ def run_coroutine(self, function, *args, **kwargs): def _graph_to_futures(self, dsk, keys, restrictions=None, loose_restrictions=None, priority=None, user_priority=0, resources=None, retries=None, - fifo_timeout=0): + fifo_timeout=0, actors=None): with self._lock: if resources: resources = self._expand_resources(resources, @@ -2091,6 +2083,9 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, retries = self._expand_retries(retries, all_keys=itertools.chain(dsk, keys)) + if actors is not None and actors is not True and actors is not False: + actors = list(self._expand_key(actors)) + keyset = set(keys) flatkeys = list(map(tokey, keys)) futures = {key: Future(key, self, inform=False) for key in keyset} @@ -2145,7 +2140,8 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, 'resources': resources, 'submitting_task': getattr(thread_state, 'key', None), 'retries': retries, - 'fifo_timeout': fifo_timeout}) + 'fifo_timeout': fifo_timeout, + 'actors': actors}) return futures def get(self, dsk, keys, restrictions=None, loose_restrictions=None, @@ -2265,7 +2261,8 @@ def normalize_collection(self, collection): def compute(self, collections, sync=False, optimize_graph=True, workers=None, allow_other_workers=False, resources=None, - retries=0, priority=0, fifo_timeout='60s', **kwargs): + retries=0, priority=0, fifo_timeout='60s', actors=None, + **kwargs): """ Compute dask collections on cluster Parameters @@ -2358,7 +2355,8 @@ def compute(self, collections, sync=False, optimize_graph=True, resources=resources, retries=retries, user_priority=priority, - fifo_timeout=fifo_timeout) + fifo_timeout=fifo_timeout, + actors=actors) i = 0 futures = [] @@ -2381,7 +2379,7 @@ def compute(self, collections, sync=False, optimize_graph=True, def persist(self, collections, optimize_graph=True, workers=None, allow_other_workers=None, resources=None, retries=None, - priority=0, fifo_timeout='60s', **kwargs): + priority=0, fifo_timeout='60s', actors=None, **kwargs): """ Persist dask collections on cluster Starts computation of the collection on the cluster in the background. 
@@ -2450,7 +2448,8 @@ def persist(self, collections, optimize_graph=True, workers=None, resources=resources, retries=retries, user_priority=priority, - fifo_timeout=fifo_timeout) + fifo_timeout=fifo_timeout, + actors=actors) postpersists = [c.__dask_postpersist__() for c in collections] result = [func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args) diff --git a/distributed/core.py b/distributed/core.py index 56360c7ce4d..aa8b77984d7 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -851,9 +851,17 @@ def clean_exception(exception, traceback, **kwargs): error_message: create and serialize errors into message """ if isinstance(exception, bytes): - exception = protocol.pickle.loads(exception) + try: + exception = protocol.pickle.loads(exception) + except Exception: + exception = Exception(exception) + elif isinstance(exception, str): + exception = Exception(exception) if isinstance(traceback, bytes): - traceback = protocol.pickle.loads(traceback) + try: + traceback = protocol.pickle.loads(traceback) + except (TypeError, AttributeError): + traceback = None elif isinstance(traceback, string_types): traceback = None # happens if the traceback failed serializing return type(exception), exception, traceback diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4ef146d06c6..034ea3eeadb 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -182,6 +182,12 @@ class WorkerState(object): The last time we received a heartbeat from this worker, in local scheduler time. + .. attribute:: actors: {TaskState} + + A set of all TaskStates on this worker that are actors. This only + includes those actors whose state actually lives on this worker, not + actors to which this worker has a reference. + """ # XXX need a state field to signal active/removed? @@ -200,6 +206,7 @@ class WorkerState(object): 'used_resources', 'status', 'last_seen', + 'actors', ) def __init__(self, worker, ncores, memory_limit, name=None): @@ -214,6 +221,7 @@ def __init__(self, worker, ncores, memory_limit, name=None): self.resources = {} self.used_resources = {} self.last_seen = 0 + self.actors = set() self.info = {'name': name, 'memory_limit': memory_limit, @@ -453,10 +461,13 @@ class TaskState(object): into the "processing" state and be sent for execution to another connected worker. - """ + .. attribute: actor: bool + Whether or not this task is an Actor. 
+ """ __slots__ = ( # === General description === + 'actor', # Key name 'key', # Key prefix (see key_split()) @@ -518,6 +529,7 @@ def __init__(self, key, run_spec): self.worker_restrictions = None self.resource_restrictions = None self.loose_restrictions = False + self.actor = None def get_nbytes(self): nbytes = self.nbytes @@ -1301,7 +1313,7 @@ def update_graph(self, client=None, tasks=None, keys=None, dependencies=None, restrictions=None, priority=None, loose_restrictions=None, resources=None, submitting_task=None, retries=None, user_priority=0, - fifo_timeout=0): + actors=None, fifo_timeout=0): """ Add new computations to the internal dask graph @@ -1401,6 +1413,12 @@ def update_graph(self, client=None, tasks=None, keys=None, if isinstance(user_priority, Number): user_priority = {k: user_priority for k in tasks} + # Add actors + if actors is True: + actors = list(keys) + for actor in actors or []: + self.tasks[actor].actor = True + priority = priority or dask.order.order(tasks) # TODO: define order wrt old graph if submitting_task: # sub-tasks get better priority than parent tasks @@ -1965,6 +1983,8 @@ def send_task_to_worker(self, worker, key): 'duration': self.get_task_duration(ts)} if ts.resource_restrictions: msg['resource_restrictions'] = ts.resource_restrictions + if ts.actor: + msg['actor'] = True deps = ts.dependencies if deps: @@ -3206,6 +3226,9 @@ def transition_waiting_processing(self, key): self.check_idle_saturated(ws) self.n_tasks += 1 + if ts.actor: + ws.actors.add(ts) + # logger.debug("Send job to worker: %s, %s", worker, key) self.send_task_to_worker(worker, key) @@ -3346,6 +3369,14 @@ def transition_memory_released(self, key, safe=False): if safe: assert not ts.waiters + if ts.actor: + for ws in ts.who_has: + ws.actors.discard(ts) + if ts.who_wants: + ts.exception_blame = ts + ts.exception = "Worker holding Actor was lost" + return {ts.key: 'erred'} # don't try to recreate + recommendations = OrderedDict() for dts in ts.waiters: @@ -3509,6 +3540,10 @@ def transition_processing_erred(self, key, cause=None, exception=None, assert not ts.who_has assert not ts.waiting_on + if ts.actor: + ws = ts.processing_on + ws.actors.remove(ts) + self._remove_from_processing(ts) if exception is not None: @@ -3650,6 +3685,11 @@ def transition_memory_forgotten(self, key): assert 0, (ts,) recommendations = {} + + if ts.actor: + for ws in ts.who_has: + ws.actors.discard(ts) + self._propagate_forgotten(ts, recommendations) self.report_on_key(ts=ts) @@ -3991,7 +4031,11 @@ def worker_objective(self, ts, ws): if ws not in dts.who_has]) stack_time = ws.occupancy / ws.ncores start_time = comm_bytes / BANDWIDTH + stack_time - return (start_time, ws.nbytes) + + if ts.actor: + return (len(ws.actors), start_time, ws.nbytes) + else: + return (start_time, ws.nbytes) @gen.coroutine def get_profile(self, comm=None, workers=None, merge_workers=True, @@ -4161,8 +4205,11 @@ def decide_worker(ts, all_workers, valid_workers, objective): """ deps = ts.dependencies assert all(dts.who_has for dts in deps) - candidates = frequencies([ws for dts in deps - for ws in dts.who_has]) + if ts.actor: + candidates = all_workers + else: + candidates = frequencies([ws for dts in deps + for ws in dts.who_has]) if valid_workers is True: if not candidates: candidates = all_workers @@ -4239,6 +4286,21 @@ def validate_task_state(ts): assert ts in cs.wants_what, \ ("not in who_wants' wants_what", str(ts), str(cs), str(cs.wants_what)) + if ts.actor: + if ts.state == 'memory': + assert sum([ts in ws.actors for ws in ts.who_has]) 
== 1 + if ts.state == 'processing': + assert ts in ts.processing_on.actors + + +def validate_worker_state(ws): + for ts in ws.has_what: + assert ws in ts.who_has, \ + ("not in has_what' who_has", str(ws), str(ts), str(ts.who_has)) + + for ts in ws.actors: + assert ts.state in ('memory', 'processing') + def validate_state(tasks, workers, clients): """ @@ -4251,9 +4313,7 @@ def validate_state(tasks, workers, clients): validate_task_state(ts) for ws in workers.values(): - for ts in ws.has_what: - assert ws in ts.who_has, \ - ("not in has_what' who_has", str(ws), str(ts), str(ts.who_has)) + validate_worker_state(ws) for cs in clients.values(): for ts in cs.wants_what: diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py new file mode 100644 index 00000000000..580b884b554 --- /dev/null +++ b/distributed/tests/test_actor.py @@ -0,0 +1,521 @@ +import operator +from time import sleep +from tornado import gen + +import pytest + +import dask +from distributed import Actor, ActorFuture, Client, Future, wait, Nanny +from distributed.utils_test import gen_cluster, cluster +from distributed.utils_test import loop # noqa: F401 +from distributed.metrics import time + + +class Counter(object): + n = 0 + + def __init__(self): + self.n = 0 + + def increment(self): + self.n += 1 + return self.n + + def add(self, x): + self.n += x + return self.n + + +class List(object): + L = [] + + def __init__(self, dummy=None): + self.L = [] + + def append(self, x): + self.L.append(x) + + +class ParameterServer(object): + def __init__(self): + self.data = {} + + def put(self, key, value): + self.data[key] = value + + def get(self, key): + return self.data[key] + + +@pytest.mark.parametrize('direct_to_workers', [True, False]) +def test_client_actions(direct_to_workers): + + @gen_cluster(client=True) + def test(c, s, a, b): + c = yield Client(s.address, asynchronous=True, + direct_to_workers=direct_to_workers) + + counter = c.submit(Counter, workers=[a.address], actor=True) + assert isinstance(counter, Future) + counter = yield counter + assert counter._address + assert hasattr(counter, 'increment') + assert hasattr(counter, 'add') + assert hasattr(counter, 'n') + + n = yield counter.n + assert n == 0 + + assert counter._address == a.address + + assert isinstance(a.actors[counter.key], Counter) + assert s.tasks[counter.key].actor + + yield [counter.increment(), counter.increment()] + + n = yield counter.n + assert n == 2 + + counter.add(10) + while (yield counter.n) != 10 + 2: + n = yield counter.n + yield gen.sleep(0.01) + + yield c.close() + + test() + + +@pytest.mark.parametrize('separate_thread', [False, True]) +def test_worker_actions(separate_thread): + + @gen_cluster(client=True) + def test(c, s, a, b): + counter = c.submit(Counter, workers=[a.address], actor=True) + a_address = a.address + + def f(counter): + start = counter.n + + assert type(counter) is Actor + assert counter._address == a_address + + future = counter.increment(separate_thread=separate_thread) + assert isinstance(future, ActorFuture) + assert "Future" in type(future).__name__ + end = future.result(timeout=1) + assert end > start + + futures = [c.submit(f, counter, pure=False) for _ in range(10)] + yield futures + + counter = yield counter + assert (yield counter.n) == 10 + + test() + + +@gen_cluster(client=True) +def test_Actor(c, s, a, b): + counter = yield c.submit(Counter, actor=True) + + assert counter._cls == Counter + + assert hasattr(counter, 'n') + assert hasattr(counter, 'increment') + assert hasattr(counter, 
'add') + + assert not hasattr(counter, 'abc') + + +@pytest.mark.xfail(reason="Tornado can pass things out of order" + + "Should rely on sending small messages rather than rpc") +@gen_cluster(client=True) +def test_linear_access(c, s, a, b): + start = time() + future = c.submit(sleep, 0.2) + actor = c.submit(List, actor=True, dummy=future) + actor = yield actor + + for i in range(100): + actor.append(i) + + while True: + yield gen.sleep(0.1) + L = yield actor.L + if len(L) == 100: + break + + L = yield actor.L + stop = time() + assert L == tuple(range(100)) + + assert stop - start > 0.2 + + +@gen_cluster(client=True) +def test_exceptions_create(c, s, a, b): + class Foo(object): + x = 0 + + def __init__(self): + raise ValueError('bar') + + with pytest.raises(ValueError) as info: + future = yield c.submit(Foo, actor=True) + + assert "bar" in str(info.value) + + +@gen_cluster(client=True) +def test_exceptions_method(c, s, a, b): + class Foo(object): + def throw(self): + 1 / 0 + + foo = yield c.submit(Foo, actor=True) + with pytest.raises(ZeroDivisionError): + yield foo.throw() + + +@gen_cluster(client=True) +def test_gc(c, s, a, b): + actor = c.submit(Counter, actor=True) + yield wait(actor) + del actor + + while a.actors or b.actors: + yield gen.sleep(0.01) + + +@gen_cluster(client=True) +def test_track_dependencies(c, s, a, b): + actor = c.submit(Counter, actor=True) + yield wait(actor) + x = c.submit(sleep, 0.5) + y = c.submit(lambda x, y: x, x, actor) + del actor + + yield gen.sleep(0.3) + + assert a.actors or b.actors + + +@gen_cluster(client=True) +def test_future(c, s, a, b): + counter = c.submit(Counter, actor=True, workers=[a.address]) + assert isinstance(counter, Future) + yield wait(counter) + assert isinstance(a.actors[counter.key], Counter) + + counter = yield counter + assert isinstance(counter, Actor) + assert counter._address + + yield gen.sleep(0.1) + assert counter.key in c.futures # don't lose future + + +@gen_cluster(client=True) +def test_future_dependencies(c, s, a, b): + counter = c.submit(Counter, actor=True, workers=[a.address]) + + def f(a): + assert isinstance(a, Actor) + assert a._cls == Counter + + x = c.submit(f, counter, workers=[b.address]) + yield x + + assert {ts.key for ts in s.tasks[x.key].dependencies} == {counter.key} + assert {ts.key for ts in s.tasks[counter.key].dependents} == {x.key} + + y = c.submit(f, counter, workers=[a.address], pure=False) + yield y + + assert {ts.key for ts in s.tasks[y.key].dependencies} == {counter.key} + assert {ts.key for ts in s.tasks[counter.key].dependents} == {x.key, y.key} + + +def test_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + counter = c.submit(Counter, actor=True) + counter = counter.result() + + assert counter.n == 0 + + future = counter.increment() + n = future.result() + assert n == 1 + assert counter.n == 1 + + assert future.result() == future.result() + + assert 'ActorFuture' in repr(future) + assert 'distributed.actor' not in repr(future) + + +@gen_cluster(client=True, config={'distributed.comm.timeouts.connect': '1s'}) +def test_failed_worker(c, s, a, b): + future = c.submit(Counter, actor=True, workers=[a.address]) + yield wait(future) + counter = yield future + + yield a._close() + + with pytest.raises(Exception) as info: + yield counter.increment() + + assert "actor" in str(info.value).lower() + assert "worker" in str(info.value).lower() + assert "lost" in str(info.value).lower() + + +@gen_cluster(client=True) +def bench(c, s, a, b): + counter = yield 
c.submit(Counter, actor=True) + + for i in range(1000): + yield counter.increment() + + +@gen_cluster(client=True) +def test_numpy_roundtrip(c, s, a, b): + np = pytest.importorskip('numpy') + + server = yield c.submit(ParameterServer, actor=True) + + x = np.random.random(1000) + yield server.put('x', x) + + y = yield server.get('x') + + assert (x == y).all() + + +@gen_cluster(client=True) +def test_numpy_roundtrip_getattr(c, s, a, b): + np = pytest.importorskip('numpy') + + counter = yield c.submit(Counter, actor=True) + + x = np.random.random(1000) + + yield counter.add(x) + + y = yield counter.n + + assert (x == y).all() + + +@gen_cluster(client=True) +def test_repr(c, s, a, b): + counter = yield c.submit(Counter, actor=True) + + assert 'Counter' in repr(counter) + assert 'Actor' in repr(counter) + assert counter.key in repr(counter) + assert 'distributed.actor' not in repr(counter) + + +@gen_cluster(client=True) +def test_dir(c, s, a, b): + counter = yield c.submit(Counter, actor=True) + + d = set(dir(counter)) + + for attr in dir(Counter): + if not attr.startswith('_'): + assert attr in d + + +@gen_cluster(client=True) +def test_many_computations(c, s, a, b): + counter = yield c.submit(Counter, actor=True) + + def add(n, counter): + for i in range(n): + counter.increment().result() + + futures = c.map(add, range(10), counter=counter) + done = c.submit(lambda x: None, futures) + + while not done.done(): + assert len(s.processing) <= a.ncores + b.ncores + yield gen.sleep(0.01) + + yield done + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 5)] * 2) +def test_thread_safety(c, s, a, b): + class Unsafe(object): + def __init__(self): + self.n = 0 + + def f(self): + assert self.n == 0 + self.n += 1 + + for i in range(20): + sleep(0.002) + assert self.n == 1 + self.n = 0 + + unsafe = yield c.submit(Unsafe, actor=True) + + futures = [unsafe.f() for i in range(10)] + yield futures + + +@gen_cluster(client=True) +def test_Actors_create_dependencies(c, s, a, b): + counter = yield c.submit(Counter, actor=True) + future = c.submit(lambda x: None, counter) + yield wait(future) + assert s.tasks[future.key].dependencies == {s.tasks[counter.key]} + + +@gen_cluster(client=True) +def test_load_balance(c, s, a, b): + class Foo(object): + def __init__(self, x): + pass + + b = c.submit(operator.mul, 'b', 1000000) + yield wait(b) + [ws] = s.tasks[b.key].who_has + + x = yield c.submit(Foo, b, actor=True) + y = yield c.submit(Foo, b, actor=True) + assert x.key != y.key # actors assumed not pure + + assert s.tasks[x.key].who_has == {ws} # first went to best match + assert s.tasks[x.key].who_has != s.tasks[y.key].who_has # second load balanced + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 5) +def test_load_balance_map(c, s, *workers): + class Foo(object): + def __init__(self, x, y=None): + pass + + b = c.submit(operator.mul, 'b', 1000000) + yield wait(b) + + actors = c.map(Foo, range(10), y=b, actor=True) + yield wait(actors) + + assert all(len(w.actors) == 2 for w in workers) + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4, Worker=Nanny) +def bench_param_server(c, s, *workers): + import dask.array as da + import numpy as np + x = da.random.random((500000, 1000), chunks=(1000, 1000)) + x = x.persist() + yield wait(x) + + class ParameterServer: + data = None + + def __init__(self, n): + self.data = np.random.random(n) + + def update(self, x): + self.data += x + self.data /= 2 + + def get_data(self): + return self.data + + def f(block, ps=None): + start = time() + params = 
ps.get_data(separate_thread=False).result() + stop = time() + update = (block - params).mean(axis=0) + ps.update(update, separate_thread=False) + print(format_time(stop - start)) + return np.array([[stop - start]]) + + from distributed.utils import format_time + start = time() + ps = yield c.submit(ParameterServer, x.shape[1], actor=True) + y = x.map_blocks(f, ps=ps, dtype=x.dtype) + # result = yield c.compute(y.mean()) + yield wait(y.persist()) + end = time() + print(format_time(end - start)) + + +@gen_cluster(client=True) +def test_compute(c, s, a, b): + + @dask.delayed + def f(n, counter): + assert isinstance(counter, Actor) + for i in range(n): + counter.increment().result() + + @dask.delayed + def check(counter, blanks): + return counter.n + + counter = dask.delayed(Counter)() + values = [f(i, counter) for i in range(5)] + final = check(counter, values) + + result = yield c.compute(final, actors=counter) + assert result == 0 + 1 + 2 + 3 + 4 + + start = time() + while a.data or b.data or a.actors or b.actors: + yield gen.sleep(0.01) + assert time() < start + 2 + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)], + config={'distributed.worker.profile.interval': '1ms'}) +def test_actors_in_profile(c, s, a): + class Sleeper(object): + def sleep(self, time): + sleep(time) + + sleeper = yield c.submit(Sleeper, actor=True) + + for i in range(5): + yield sleeper.sleep(0.200) + if (list(a.profile_recent['children'])[0].startswith('sleep') or + 'Sleeper.sleep' in a.profile_keys): + return + assert False, list(a.profile_keys) + + +@gen_cluster(client=True) +def test_waiter(c, s, a, b): + from tornado.locks import Event + + class Waiter(object): + def __init__(self): + self.event = Event() + + @gen.coroutine + def set(self): + self.event.set() + + @gen.coroutine + def wait(self): + yield self.event.wait() + + waiter = yield c.submit(Waiter, actor=True) + + futures = [waiter.wait() for i in range(5)] # way more than we have actor threads + + yield gen.sleep(0.1) + assert not any(future.done() for future in futures) + + yield waiter.set() + + yield futures diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 77730fac85c..4ebfd9869c9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -705,21 +705,6 @@ def test_stop_doing_unnecessary_work(c, s, a, b): @gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) def test_priorities(c, s, w): - a = delayed(slowinc)(1, dask_key_name='a', delay=0.05) - b = delayed(slowinc)(2, dask_key_name='b', delay=0.05) - a1 = delayed(slowinc)(a, dask_key_name='a1', delay=0.05) - a2 = delayed(slowinc)(a1, dask_key_name='a2', delay=0.05) - b1 = delayed(slowinc)(b, dask_key_name='b1', delay=0.05) - - z = delayed(add)(a2, b1) - future = yield c.compute(z) - - log = [t for t in w.log if t[1] == 'executing' and t[2] == 'memory'] - assert [t[0] for t in log[:5]] == ['a', 'b', 'a1', 'b1', 'a2'] - - -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) -def test_priorities_2(c, s, w): values = [] for i in range(10): a = delayed(slowinc)(i, dask_key_name='a-%d' % i, delay=0.01) diff --git a/distributed/utils.py b/distributed/utils.py index 6257a9ba83b..193005dbc83 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1418,3 +1418,11 @@ def color_of(x, palette=palette): h = md5(str(x).encode()) n = int(h.hexdigest()[:8], 16) return palette[n % len(palette)] + + +def iscoroutinefunction(f): + if gen.is_coroutine_function(f): + return True + if sys.version_info >= (3, 5) and 
inspect.iscoroutinefunction(f): + return True + return False diff --git a/distributed/utils_test.py b/distributed/utils_test.py index cce9ad75dda..10e5329f4d4 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -8,7 +8,6 @@ import gc from glob import glob import itertools -import inspect import logging import logging.config import os @@ -40,14 +39,14 @@ from tornado.ioloop import IOLoop from .client import default_client, _global_clients -from .compatibility import PY3, iscoroutinefunction, Empty, WINDOWS +from .compatibility import PY3, Empty, WINDOWS from .config import initialize_logging from .core import connect, rpc, CommClosedError from .metrics import time from .proctitle import enable_proctitle_on_children from .security import Security from .utils import (ignoring, log_errors, mp_context, get_ip, get_ipv6, - DequeHandler, reset_logger_locks, sync) + DequeHandler, reset_logger_locks, sync, iscoroutinefunction) from .worker import Worker, TOTAL_MEMORY, _global_workers try: @@ -719,12 +718,6 @@ def end_worker(w): s.stop() -def iscoroutinefunction(f): - if sys.version_info >= (3, 5) and inspect.iscoroutinefunction(f): - return True - return False - - def gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 2)], scheduler='127.0.0.1', timeout=10, security=None, Worker=Worker, client=False, scheduler_kwargs={}, diff --git a/distributed/worker.py b/distributed/worker.py index 6d98ebc4659..f614c8e0b34 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -47,7 +47,7 @@ ignoring, mp_context, import_file, silence_logging, thread_state, json_load_robust, key_split, format_bytes, DequeHandler, PeriodicCallback, - parse_bytes, parse_timedelta) + parse_bytes, parse_timedelta, iscoroutinefunction) from .utils_comm import pack_data, gather_from_workers from .utils_perf import ThrottledGC, enable_gc_diagnosis, disable_gc_diagnosis @@ -160,11 +160,13 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.data = Buffer({}, storage, target, weight) else: self.data = dict() + self.actors = {} self.loop = loop or IOLoop.current() self.status = None self._closed = Event() self.reconnect = reconnect self.executor = executor or ThreadPoolExecutor(self.ncores) + self.actor_executor = ThreadPoolExecutor(1) self.name = name self.scheduler_delay = 0 self.stream_comms = dict() @@ -195,6 +197,8 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, 'get_logs': self.get_logs, 'keys': self.keys, 'versions': self.versions, + 'actor_execute': self.actor_execute, + 'actor_attribute': self.actor_attribute, } stream_handlers = { @@ -477,11 +481,13 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): yield gen.with_timeout(timedelta(seconds=timeout), self.scheduler.unregister(address=self.contact_address)) self.scheduler.close_rpc() + self.actor_executor._work_queue.queue.clear() if isinstance(self.executor, ThreadPoolExecutor): self.executor._work_queue.queue.clear() self.executor.shutdown(wait=executor_wait, timeout=timeout) else: self.executor.shutdown(wait=False) + self.actor_executor.shutdown(wait=executor_wait, timeout=timeout) self._workdir.release() for k, v in self.services.items(): @@ -527,7 +533,8 @@ def wait_until_closed(self): assert self.status == 'closed' @gen.coroutine - def executor_submit(self, key, function, *args, **kwargs): + def executor_submit(self, key, function, args=(), kwargs=None, + executor=None): """ Safely run function in thread pool executor We've run into issues running concurrent.future futures within @@ 
-535,9 +542,11 @@ def executor_submit(self, key, function, *args, **kwargs): callbacks to ensure things run smoothly. This can get tricky, so we pull it off into an separate method. """ + executor = executor or self.executor job_counter[0] += 1 # logger.info("%s:%d Starts job %d, %s", self.ip, self.port, i, key) - future = self.executor.submit(function, *args, **kwargs) + kwargs = kwargs or {} + future = executor.submit(function, *args, **kwargs) pc = PeriodicCallback(lambda: logger.debug("future state: %s - %s", key, future._state), 1000) pc.start() @@ -558,6 +567,33 @@ def run_coroutine(self, comm, function, args=(), kwargs={}, wait=True): return run(self, comm, function=function, args=args, kwargs=kwargs, is_coro=True, wait=wait) + @gen.coroutine + def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={}): + separate_thread = kwargs.pop('separate_thread', True) + key = actor + actor = self.actors[key] + func = getattr(actor, function) + name = key_split(key) + '.' + function + + if iscoroutinefunction(func): + result = yield func(*args, **kwargs) + elif separate_thread: + result = yield self.executor_submit(name, + apply_function_actor, + args=(func, args, kwargs, + self.execution_state, + name, + self.active_threads, + self.active_threads_lock), + executor=self.actor_executor) + else: + result = func(*args, **kwargs) + raise gen.Return({'status': 'OK', 'result': to_serialize(result)}) + + def actor_attribute(self, comm=None, actor=None, attribute=None): + value = getattr(self.actors[actor], attribute) + return {'status': 'OK', 'result': to_serialize(value)} + def update_data(self, comm=None, data=None, report=True, serializers=None): for key, value in data.items(): if key in self.task_state: @@ -618,6 +654,13 @@ def get_data(self, comm, keys=None, who=None, serializers=None, self.outgoing_current_count += 1 data = {k: self.data[k] for k in keys if k in self.data} + + if len(data) < len(keys): + for k in set(keys) - set(data): + if k in self.actors: + from .actor import Actor + data[k] = Actor(type(self.actors[k]), self.address, k) + msg = {'status': 'OK', 'data': {k: to_serialize(v) for k, v in data.items()}} nbytes = {k: self.nbytes.get(k) for k in data} @@ -625,6 +668,7 @@ def get_data(self, comm, keys=None, who=None, serializers=None, if self.digests is not None: self.digests['get-data-load-duration'].add(stop - start) start = time() + try: compressed = yield comm.write(msg, serializers=serializers) response = yield comm.read(deserializers=serializers) @@ -878,6 +922,30 @@ def apply_function(function, args, kwargs, execution_state, key, return msg +def apply_function_actor(function, args, kwargs, execution_state, key, + active_threads, active_threads_lock): + """ Run a function, collect information + + Returns + ------- + msg: dictionary with status, result/error, timings, etc.. 
+ """ + ident = get_thread_identity() + + with active_threads_lock: + active_threads[ident] = key + + thread_state.execution_state = execution_state + thread_state.key = key + + result = function(*args, **kwargs) + + with active_threads_lock: + del active_threads[ident] + + return result + + def get_msg_safe_str(msg): """ Make a worker msg, which contains args and kwargs, safe to cast to str: allowing for some arguments to raise exceptions during conversion and @@ -1267,13 +1335,13 @@ def __repr__(self): def add_task(self, key, function=None, args=None, kwargs=None, task=None, who_has=None, nbytes=None, priority=None, duration=None, - resource_restrictions=None, **kwargs2): + resource_restrictions=None, actor=False, **kwargs2): try: if key in self.tasks: state = self.task_state[key] if state in ('memory', 'error'): if state == 'memory': - assert key in self.data + assert key in self.data or key in self.actors logger.debug("Asked to compute pre-existing result: %s: %s", key, state) self.send_task_state_to_scheduler(key) @@ -1298,6 +1366,8 @@ def add_task(self, key, function=None, args=None, kwargs=None, task=None, try: start = time() self.tasks[key] = _deserialize(function, args, kwargs, task) + if actor: + self.actors[key] = None stop = time() if stop - start > 0.010: @@ -1490,7 +1560,7 @@ def transition_waiting_ready(self, key): assert self.task_state[key] == 'waiting' assert key in self.waiting_for_data assert not self.waiting_for_data[key] - assert all(dep in self.data for dep in self.dependencies[key]) + assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) assert key not in self.executing assert key not in self.ready @@ -1532,7 +1602,7 @@ def transition_ready_executing(self, key): # assert key not in self.data assert self.task_state[key] in READY assert key not in self.ready - assert all(dep in self.data for dep in self.dependencies[key]) + assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) self.executing.add(key) self.loop.add_callback(self.execute, key) @@ -1712,9 +1782,14 @@ def ensure_communicating(self): raise def send_task_state_to_scheduler(self, key): - if key in self.data: - nbytes = self.nbytes[key] or sizeof(self.data[key]) - typ = self.types.get(key) or type(self.data[key]) + if key in self.data or self.actors.get(key): + try: + value = self.data[key] + except KeyError: + value = self.actors[key] + nbytes = self.nbytes[key] or sizeof(value) + typ = self.types.get(key) or type(value) + del value try: typ = dumps_function(typ) except PicklingError: @@ -1747,11 +1822,15 @@ def put_key_in_memory(self, key, value, transition=True): if key in self.data: return - start = time() - self.data[key] = value - stop = time() - if stop - start > 0.020: - self.startstops[key].append(('disk-write', start, stop)) + if key in self.actors: + self.actors[key] = value + + else: + start = time() + self.data[key] = value + stop = time() + if stop - start > 0.020: + self.startstops[key].append(('disk-write', start, stop)) if key not in self.nbytes: self.nbytes[key] = sizeof(value) @@ -2012,6 +2091,10 @@ def release_key(self, key, cause=None, reason=None, report=True): exc_info=True) del self.nbytes[key] del self.types[key] + if key in self.actors and key not in self.dep_state: + del self.actors[key] + del self.nbytes[key] + del self.types[key] if key in self.waiting_for_data: del self.waiting_for_data[key] @@ -2075,6 +2158,9 @@ def release_dep(self, dep, report=False): if dep in self.data: del self.data[dep] del self.types[dep] + if 
dep in self.actors: + del self.actors[dep] + del self.types[dep] del self.nbytes[dep] if dep in self.in_flight_tasks: @@ -2176,7 +2262,13 @@ def execute(self, key, report=False): function, args, kwargs = self.tasks[key] start = time() - data = {k: self.data[k] for k in self.dependencies[key]} + data = {} + for k in self.dependencies[key]: + try: + data[k] = self.data[k] + except KeyError: + from .actor import Actor # TODO: create local actor + data[k] = Actor(type(self.actors[k]), self.address, k, self) args2 = pack_data(args, data, key_types=(bytes, unicode)) kwargs2 = pack_data(kwargs, data, key_types=(bytes, unicode)) stop = time() @@ -2187,12 +2279,12 @@ def execute(self, key, report=False): logger.debug("Execute key: %s worker: %s", key, self.address) # TODO: comment out? try: - result = yield self.executor_submit(key, apply_function, function, - args2, kwargs2, - self.execution_state, key, - self.active_threads, - self.active_threads_lock, - self.scheduler_delay) + result = yield self.executor_submit(key, apply_function, + args=(function, args2, kwargs2, + self.execution_state, key, + self.active_threads, + self.active_threads_lock, + self.scheduler_delay)) except RuntimeError as e: executor_error = e raise @@ -2353,9 +2445,9 @@ def trigger_profile(self): if frame is not None: key = key_split(active_threads[ident]) profile.process(frame, None, self.profile_recent, - stop='_concurrent_futures_thread.py') + stop='distributed/worker.py') profile.process(frame, None, self.profile_keys[key], - stop='_concurrent_futures_thread.py') + stop='distributed/worker.py') stop = time() if self.digests is not None: self.digests['profile-duration'].add(stop - start) @@ -2438,7 +2530,7 @@ def get_logs(self, comm=None, n=None): return [(msg.levelname, deque_handler.format(msg)) for msg in L] def validate_key_memory(self, key): - assert key in self.data + assert key in self.data or key in self.actors assert key in self.nbytes assert key not in self.waiting_for_data assert key not in self.executing @@ -2450,14 +2542,14 @@ def validate_key_executing(self, key): assert key in self.executing assert key not in self.data assert key not in self.waiting_for_data - assert all(dep in self.data for dep in self.dependencies[key]) + assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) def validate_key_ready(self, key): assert key in pluck(1, self.ready) assert key not in self.data assert key not in self.executing assert key not in self.waiting_for_data - assert all(dep in self.data for dep in self.dependencies[key]) + assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) def validate_key_waiting(self, key): assert key not in self.data @@ -2495,7 +2587,7 @@ def validate_dep_flight(self, dep): assert dep in self.in_flight_workers[peer] def validate_dep_memory(self, dep): - assert dep in self.data + assert dep in self.data or dep in self.actors assert dep in self.nbytes assert dep in self.types if dep in self.task_state: @@ -2548,7 +2640,7 @@ def validate_state(self): if self.task_state[key] == 'memory': assert isinstance(self.nbytes[key], int) assert key not in self.waiting_for_data - assert key in self.data + assert key in self.data or key in self.actors except Exception as e: logger.exception(e) diff --git a/docs/source/actors.rst b/docs/source/actors.rst new file mode 100644 index 00000000000..109b9d907aa --- /dev/null +++ b/docs/source/actors.rst @@ -0,0 +1,235 @@ +Actors +====== + +.. 
note:: This is an experimental feature and is subject to change without notice +.. note:: This is an advanced feature and may not be suitable for beginning users. + It is rarely necessary for common workloads. + +Actors enable stateful computations within a Dask workflow. They are useful +for some rare algorithms that require additional performance and are willing to +sacrifice resilience. + +An actor is a pointer to a user-defined-object living on a remote worker. +Anyone with that actor can call methods on that remote object. + +Example +------- + +Here we create a simple ``Counter`` class, instantiate that class on one worker, +and then call methods on that class remotely. + +.. code-block:: python + + class Counter: + """ A simple class to manage an incrementing counter """ + n = 0 + + def __init__(self): + self.n = 0 + + def increment(self): + self.n += 1 + return self.n + + def add(self, x): + self.n += x + return self.n + + from dask.distributed import Client # Start a Dask Client + client = Client() + + future = client.submit(Counter, actor=True) # Create a Counter on a worker + counter = future.result() # Get back a pointer to that object + + counter + # + + future = counter.increment() # Call remote method + future.result() # Get back result + # 1 + + future = counter.add(10) # Call remote method + future.result() # Get back result + # 11 + +Motivation +---------- + +Actors are motivated by some of the challenges of using pure task graphs. + +Normal Dask computations are composed of a graph of functions. +This approach has a few limitations that are good for resilience, but can +negatively affect performance: + +1. **State**: The functions should not mutate their inputs in-place or rely on + global state. They should instead operate in a pure-functional manner, + consuming inputs and producing separate outputs. +2. **Central Overhead**: The execution location and order is determined by the + centralized scheduler. Because the scheduler is involved in every decision + it can sometimes create a central bottleneck. + +Some workloads may need to update state directly, or may involve more tiny +tasks than the scheduler can handle (the scheduler can coordinate about 4000 +tasks per second). + +Actors side-step both of these limitations: + +1. **State**: Actors can hold on to and mutate state. They are allowed to + update their state in-place. +2. **Overhead**: Operations on actors do not inform the central scheduler, and + so do not contribute to the 4000 task/second overhead. They also avoid an + extra network hop and so have lower latencies. + +Create an Actor +--------------- + +You create an actor by submitting a Class to run on a worker using normal Dask +computation functions like ``submit``, ``map``, ``compute``, or ``persist``, +and using the ``actors=`` keyword (or ``actor=`` on ``submit``). + +.. code-block:: python + + future = client.submit(Counter, actors=True) + +You can use all other keywords to these functions like ``workers=``, +``resources=``, and so on to control where this actor ends up. + +This creates a normal Dask future on which you can call ``.result()`` to get +the Actor once it has successfully run on a worker. + +.. code-block:: python + + >>> counter = future.result() + >>> counter + + +A ``Counter`` object has been instantiated on one of the workers, and this +``Actor`` object serves as our proxy to that remote object. It has the same +methods and attributes. + +.. 
code-block:: python + + >>> dir(counter) + ['add', 'increment', 'n'] + +Call Remote Methods +------------------- + +However accessing an attribute or calling a method will trigger a communication +to the remote worker, run the method on the remote worker in a separate thread +pool, and then communicate the result back to the calling side. For attribute +access these operations block and return when finished, for method calls they +return an ``ActorFuture`` immediately. + +.. code-block:: python + + >>> future = counter.increment() # Immediately returns an ActorFuture + >>> future.result() # Block until finished and result arrives + 1 + +``ActorFuture`` are similar to normal Dask ``Future`` objects, but not as fully +featured. They curently *only* support the ``result`` method and nothing else. +They don't currently work with any other Dask functions that expect futures, +like ``as_completed``, ``wait``, or ``client.gather``. They can't be placed +into additional submit or map calls to form dependencies. They communicate +their results immediately (rather than waiting for result to be called) and +cache the result on the future itself. + +Access Attributes +----------------- + +If you define an attribute at the class level then that attribute will be +accessible to the actor. + +.. code-block:: python + + class Counter: + n = 0 # Recall that we defined our class with `n` as a class variable + + ... + + >>> counter.n # Blocks until finished + 1 + +Attribute access blocks automatically. It's as though you called ``.result()``. + + +Execution on the Worker +----------------------- + +When you call a method on an actor, your arguments get serialized and sent +to the worker that owns the actor's object. If you do this from a worker this +communication is direct. If you do this from a Client then this will be direct +if the Client has direct access to the workers (create a client with +``Client(..., direct_to_workers=True)`` if direct connections are possible) or +by proxying through the scheduler if direct connections from the client to the +workers are not possible. + +The appropriate method of the Actor's object is then called in a separate +thread, the result captured, and then sent back to the calling side. Currently +workers have only a single thread for actors, but this may change in the +future. + +The result is sent back immediately to the calling side, and is not stored on +the worker with the actor. It is cached on the ``ActorFuture`` object. + + +Calling from coroutines and async/await +-------------------------- + +If you use actors within a coroutine or async/await function then actor methods +and attrbute access will return Tornado futures + +.. code-block:: python + + async def f(): + counter = await client.submit(Counter, actor=True) + + await counter.increment() + n = await counter.n + + +Coroutines and async/await on the Actor +--------------------------------------- + +If you define an ``async def`` function on the actor class then that method +will run on the Worker's event loop thread rather than a separate thread. + +.. 
code-block:: python + + def Waiter(object): + def __init__(self): + self.event = tornado.locks.Event() + + async def set(self): + self.event.set() + + async def wait(self): + await self.event.wait() + + waiter = client.submit(Waiter, actor=True).result() + waiter.wait().result() # waits until set, without consuming a worker thread + + +Performance +----------- + +Worker operations currently have about 1ms of latency, on top of any network +latency that may exist. However other activity in a worker may easily increase +these latencies if enough other activities are present. + + +Limitations +----------- + +Actors offer advanced capabilities, but with some cost: + +1. **No Resilience:** No effort is made to make actor workloads resilient to + worker failure. If the worker dies while holding an actor that actor is + lost forever. +2. **No Diagnostics:** Because the scheduler is not informed about actor + computations no diagnostics are available about these computations. +3. **No Load balancing:** Actors are allocated onto workers evenly, without + serious consideration given to avoiding communication. +4. **Experimental:** Actors are a new feature and subject to change without + warning diff --git a/docs/source/index.rst b/docs/source/index.rst index 2d1b6328aa6..211e94a4c3b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -104,6 +104,7 @@ Contents :maxdepth: 1 :caption: Additional Features + actors adaptive asynchronous configuration From c5a0b2359cd1e35ae9f42bf35a811dd1f5e1cb0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 8 Aug 2018 16:27:15 +0200 Subject: [PATCH 0053/1550] Fix tooltip (#2168) --- distributed/bokeh/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index cd677e455c2..52e693c75d5 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -971,7 +971,7 @@ def __init__(self, scheduler, width=800, **kwargs): point_policy="follow_mouse", tooltips="""
- @host: + @worker: @memory_percent
""" From 8ae684816916e1b0e8131f12d1f54b89f7f4d07f Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Wed, 8 Aug 2018 16:47:37 -0600 Subject: [PATCH 0054/1550] fix scale / avoid returning coroutines (#2171) --- distributed/deploy/cluster.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 4265a151945..0c647d7adc9 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -93,12 +93,10 @@ def scale(self, n): if n >= len(self.scheduler.workers): self.scheduler.loop.add_callback(self.scale_up, n) else: - to_close = self.scheduler.retire_workers( - remove=False, - close_workers=True, - n=len(self.scheduler.workers) - n - ) + to_close = self.scheduler.workers_to_close( + n=len(self.scheduler.workers) - n) logger.debug("Closing workers: %s", to_close) + self.scheduler.loop.add_callback(self.scheduler.retire_workers, workers=to_close) self.scheduler.loop.add_callback(self.scale_down, to_close) def _widget_status(self): From 2cfaa2d179c14165eb77cc2abd1469d92acdb6bc Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 8 Aug 2018 17:12:27 -0700 Subject: [PATCH 0055/1550] Clarify dask-worker --nprocs (#2173) --- distributed/cli/dask_worker.py | 4 ++-- docs/source/worker.rst | 22 +++++++++++++--------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 05531a6f4a1..2c0676361ae 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -62,13 +62,13 @@ @click.option('--nthreads', type=int, default=0, help="Number of threads per process.") @click.option('--nprocs', type=int, default=1, - help="Number of worker processes. Defaults to one.") + help="Number of worker processes to launch. Defaults to one.") @click.option('--name', type=str, default='', help="A unique name for this worker like 'worker-1'. " "If used with --nprocs then the process number " "will be appended like name-0, name-1, name-2, ...") @click.option('--memory-limit', default='auto', - help="Bytes of memory that the worker can use. " + help="Bytes of memory per process that the worker can use. " "This can be an integer (bytes), " "float (fraction of total system memory), " "string (like 5GB or 5000M), " diff --git a/docs/source/worker.rst b/docs/source/worker.rst index deaa0243913..5ac95c197c9 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -65,15 +65,18 @@ communication costs and generally simplifies deployment. If your computations are mostly Python code and don't release the GIL then it is advisable to run ``dask-worker`` processes with many processes and one -thread per core:: +thread per process:: - $ dask-worker scheduler:8786 --nprocs 8 + $ dask-worker scheduler:8786 --nprocs 8 --nthreads 1 + +This will launch 8 worker processes each of which has its own +ThreadPoolExecutor of size 1. If your computations are external to Python and long-running and don't release the GIL then beware that while the computation is running the worker process will not be able to communicate to other workers or to the scheduler. This situation should be avoided. If you don't link in your own custom C/Fortran -code then this topic probably doesn't apply to you. +code then this topic probably doesn't apply. Command Line tool ----------------- @@ -93,9 +96,9 @@ are the available options:: hopefully be visible from the scheduler network. --nthreads INTEGER Number of threads per process. 
Defaults to number of cores - --nprocs INTEGER Number of worker processes. Defaults to one. + --nprocs INTEGER Number of worker processes to launch. Defaults to one. --name TEXT Alias - --memory-limit TEXT Number of bytes before spilling data to disk + --memory-limit TEXT Number of bytes (per worker process) before spilling data to disk --no-nanny --help Show this message and exit. @@ -143,12 +146,13 @@ Memory Management Workers are given a target memory limit to stay under with the command line ``--memory-limit`` keyword or the ``memory_limit=`` Python -keyword argument.:: +keyword argument, which sets the memory limit per worker processes launched +by dask-workder :: - $ dask-worker tcp://scheduler:port --memory-limit=auto # total available RAM - $ dask-worker tcp://scheduler:port --memory-limit=4e9 # four gigabytes + $ dask-worker tcp://scheduler:port --memory-limit=auto # total available RAM on the machine + $ dask-worker tcp://scheduler:port --memory-limit=4e9 # four gigabytes per worker process. -Workers use a few different policies to keep memory use beneath this limit: +Workers use a few different heuristics to keep memory use beneath this limit: 1. At 60% of memory load (as estimated by ``sizeof``), spill least recently used data to disk 2. At 70% of memory load, spill least recently used data to disk regardless of From a1d4a9d698d760496e06c0790cf4a617065ccbb3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 9 Aug 2018 08:42:19 -0400 Subject: [PATCH 0056/1550] Concatenate all bytes of small messages in TCP comms (#2172) Previously we would write lengths to a socket, and then follow up with frames. This causes additional socket.send calls, which can be costly. Now for small messages we just bundle everything together and suffer a memory copy, but avoid the extra socket.send calls. 
--- distributed/comm/tcp.py | 35 ++++++++++++++++++++--------------- distributed/utils_test.py | 4 ++-- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index d1dcab7569c..a1785d907eb 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -221,21 +221,26 @@ def write(self, msg, serializers=None, on_error='message'): 'recipient': self._peer_addr}) try: - lengths = ([struct.pack('Q', len(frames))] + - [struct.pack('Q', nbytes(frame)) for frame in frames]) - stream.write(b''.join(lengths)) - - for frame in frames: - # Can't wait for the write() Future as it may be lost - # ("If write is called again before that Future has resolved, - # the previous future will be orphaned and will never resolve") - if not self._iostream_allows_memoryview: - frame = ensure_bytes(frame) - future = stream.write(frame) - bytes_since_last_yield += nbytes(frame) - if bytes_since_last_yield > 32e6: - yield future - bytes_since_last_yield = 0 + lengths = [nbytes(frame) for frame in frames] + length_bytes = ([struct.pack('Q', len(frames))] + + [struct.pack('Q', x) for x in lengths]) + if PY3 and sum(lengths) < 2**17: # 128kiB + b = b''.join(length_bytes + frames) # small enough, send in one go + stream.write(b) + else: + stream.write(b''.join(length_bytes)) # avoid large memcpy, send in many + + for frame in frames: + # Can't wait for the write() Future as it may be lost + # ("If write is called again before that Future has resolved, + # the previous future will be orphaned and will never resolve") + if not self._iostream_allows_memoryview: + frame = ensure_bytes(frame) + future = stream.write(frame) + bytes_since_last_yield += nbytes(frame) + if bytes_since_last_yield > 32e6: + yield future + bytes_since_last_yield = 0 except StreamClosedError as e: stream = None convert_stream_closed_error(self, e) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 10e5329f4d4..289987b2641 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -985,7 +985,7 @@ def assert_can_connect(addr, timeout=None, connection_args=None): within the given *timeout*. """ if timeout is None: - timeout = 0.2 + timeout = 0.5 comm = yield connect(addr, timeout=timeout, connection_args=connection_args) comm.abort() @@ -998,7 +998,7 @@ def assert_cannot_connect(addr, timeout=None, connection_args=None, exception_cl within the given *timeout*. 
""" if timeout is None: - timeout = 0.2 + timeout = 0.5 with pytest.raises(exception_class): comm = yield connect(addr, timeout=timeout, connection_args=connection_args) From 03341fd025276231423ed9d25fa302101a429b50 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 9 Aug 2018 17:05:09 -0600 Subject: [PATCH 0057/1550] Add dashboard_link property (#2176) --- distributed/deploy/cluster.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 0c647d7adc9..3bc2b2d9124 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -72,6 +72,13 @@ def adapt(self, **kwargs): def scheduler_address(self): return self.scheduler.address + @property + def dashboard_link(self): + template = dask.config.get('distributed.dashboard.link') + host = self.scheduler.address.split('://')[1].split(':')[0] + port = self.scheduler.services['bokeh'].port + return template.format(host=host, port=port, **os.environ) + def scale(self, n): """ Scale cluster to n workers @@ -140,11 +147,7 @@ def _widget(self): layout = Layout(width='150px') if 'bokeh' in self.scheduler.services: - template = dask.config.get('distributed.dashboard.link') - - host = self.scheduler.address.split('://')[1].split(':')[0] - port = self.scheduler.services['bokeh'].port - link = template.format(host=host, port=port, **os.environ) + link = self.dashboard_link link = '

<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>
\n' % (link, link) else: link = '' From cb10d6b10ad7800648b44ab3b194220d1e6fc5f8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 10 Aug 2018 07:55:21 -0400 Subject: [PATCH 0058/1550] always offload to_frames (#2170) --- distributed/bokeh/tests/test_components.py | 4 ++-- distributed/comm/utils.py | 11 +++++------ distributed/tests/test_asyncprocess.py | 2 +- distributed/utils_test.py | 4 ++++ 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/distributed/bokeh/tests/test_components.py b/distributed/bokeh/tests/test_components.py index 89606eaf553..741c90c8d49 100644 --- a/distributed/bokeh/tests/test_components.py +++ b/distributed/bokeh/tests/test_components.py @@ -25,7 +25,7 @@ def test_basic(Component): c.update(messages) -@gen_cluster(client=True) +@gen_cluster(client=True, check_new_threads=False) def test_profile_plot(c, s, a, b): p = ProfilePlot() assert len(p.source.data['left']) <= 1 @@ -34,7 +34,7 @@ def test_profile_plot(c, s, a, b): assert len(p.source.data['left']) > 1 -@gen_cluster(client=True) +@gen_cluster(client=True, check_new_threads=False) def test_profile_time_plot(c, s, a, b): from bokeh.io import curdoc sp = ProfileTimePlot(s, doc=curdoc()) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 02677b9faba..d3e758a0741 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -7,9 +7,8 @@ from tornado import gen from .. import protocol -from ..compatibility import finalize -from ..sizeof import sizeof -from ..utils import get_ip, get_ipv6, mp_context, nbytes +from ..compatibility import finalize, PY3 +from ..utils import get_ip, get_ipv6, nbytes logger = logging.getLogger(__name__) @@ -20,7 +19,7 @@ FRAME_OFFLOAD_THRESHOLD = 10 * 1024 ** 2 # 10 MB -_offload_executor = ThreadPoolExecutor(max_workers=min(4, mp_context.cpu_count())) +_offload_executor = ThreadPoolExecutor(max_workers=1) finalize(_offload_executor, _offload_executor.shutdown) @@ -44,9 +43,9 @@ def _to_frames(): logger.exception(e) raise - if sizeof(msg) > FRAME_OFFLOAD_THRESHOLD: + if PY3: res = yield offload(_to_frames) - else: + else: # distributed/deploy/tests/test_adaptive.py::test_get_scale_up_kwargs fails on Py27. 
Don't know why res = _to_frames() raise gen.Return(res) diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 9c8a8da2531..af5a07acea0 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -267,7 +267,7 @@ def test_child_main_thread(): yield proc.join() n_threads = q.get() main_name = q.get() - assert n_threads == 2 + assert n_threads <= 3 assert main_name == "MainThread" q.close() q._reader.close() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 289987b2641..93ce281bd4f 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -40,6 +40,7 @@ from .client import default_client, _global_clients from .compatibility import PY3, Empty, WINDOWS +from .comm.utils import offload from .config import initialize_logging from .core import connect, rpc, CommClosedError from .metrics import time @@ -63,6 +64,9 @@ if isinstance(logger, logging.Logger)} +offload(lambda: None).result() # create thread during import + + @pytest.fixture(scope='session') def valid_python_script(tmpdir_factory): local_file = tmpdir_factory.mktemp('data').join('file.py') From 88a1c0a69eadd302388dde984a4e6e45851eee22 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 19 Aug 2018 11:52:52 -0400 Subject: [PATCH 0059/1550] Warn if desired port is already in use (#2191) * Warn if desired port is already in use * remove diagnostics_port from LocalCluster * cleanup processes during tests --- distributed/bokeh/core.py | 12 +++++++-- distributed/bokeh/tests/test_worker_bokeh.py | 17 ------------ distributed/cli/dask_worker.py | 4 +-- distributed/deploy/local.py | 6 +++-- distributed/deploy/tests/test_local.py | 27 +++++++++++++++----- distributed/deploy/utils_test.py | 7 ++--- distributed/scheduler.py | 6 +++-- distributed/tests/py3_test_asyncio.py | 2 +- distributed/tests/test_core.py | 2 +- distributed/utils_test.py | 17 +++++++++++- 10 files changed, 62 insertions(+), 38 deletions(-) diff --git a/distributed/bokeh/core.py b/distributed/bokeh/core.py index 1cd8ed58a25..351901a3386 100644 --- a/distributed/bokeh/core.py +++ b/distributed/bokeh/core.py @@ -39,8 +39,16 @@ def listen(self, addr): self.server._tornado.add_handlers(r'.*', handlers) return - except (SystemExit, EnvironmentError): - port = 0 + except (SystemExit, EnvironmentError) as exc: + if port != 0: + if ("already in use" in str(exc) or # Unix/Mac + "Only one usage of" in str(exc)): # Windows + msg = ("Port %d is already in use. " + "Perhaps you already have a cluster running?" + % port) + else: + msg = "Failed to start diagnostics server on port %d. 
" % port + str(exc) + raise type(exc)(msg) if i == 4: raise diff --git a/distributed/bokeh/tests/test_worker_bokeh.py b/distributed/bokeh/tests/test_worker_bokeh.py index cdb97689778..01242b1c0b3 100644 --- a/distributed/bokeh/tests/test_worker_bokeh.py +++ b/distributed/bokeh/tests/test_worker_bokeh.py @@ -118,20 +118,3 @@ def test_CommunicatingStream(c, s, a, b): len(first(bb.outgoing.data.values()))) assert (len(first(aa.incoming.data.values())) and len(first(bb.incoming.data.values()))) - - -@pytest.mark.skipif(sys.version_info[0] == 2, - reason='https://github.com/bokeh/bokeh/issues/5494') -@gen_cluster(client=True) -def test_port_overlap(c, s, a, b): - # When the given port is unavailable, another one is chosen automatically - sa = BokehWorker(a) - sa.listen(57384) - sb = BokehWorker(b) - sb.listen(57384) - assert sa.port - assert sb.port - assert sa.port != sb.port - - sa.stop() - sb.stop() diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 2c0676361ae..645869fc0eb 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -40,8 +40,8 @@ help="Serving computation port, defaults to random") @click.option('--nanny-port', type=int, default=0, help="Serving nanny port, defaults to random") -@click.option('--bokeh-port', type=int, default=8789, - help="Bokeh port, defaults to 8789") +@click.option('--bokeh-port', type=int, default=0, + help="Bokeh port, defaults to random port") @click.option('--bokeh/--no-bokeh', 'bokeh', default=True, show_default=True, required=False, help="Launch Bokeh Web UI") @click.option('--listen-address', type=str, default=None, diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index be063366b6d..a277f6d53da 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -79,7 +79,7 @@ class LocalCluster(Cluster): def __init__(self, n_workers=None, threads_per_worker=None, processes=True, loop=None, start=None, ip=None, scheduler_port=0, silence_logs=logging.WARN, diagnostics_port=8787, - services={}, worker_services={}, service_kwargs=None, + services=None, worker_services=None, service_kwargs=None, asynchronous=False, security=None, **worker_kwargs): if start is not None: msg = ("The start= parameter is deprecated. 
" @@ -93,6 +93,8 @@ def __init__(self, n_workers=None, threads_per_worker=None, processes=True, self.silence_logs = silence_logs self._asynchronous = asynchronous self.security = security + services = services or {} + worker_services = worker_services or {} if silence_logs: self._old_logging_level = silence_logging(level=silence_logs) if n_workers is None and threads_per_worker is None: @@ -116,7 +118,7 @@ def __init__(self, n_workers=None, threads_per_worker=None, processes=True, self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop - if diagnostics_port is not None: + if diagnostics_port is not False and diagnostics_port is not None: try: from distributed.bokeh.scheduler import BokehScheduler from distributed.bokeh.worker import BokehWorker diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 7f2f5874d41..25536222ce0 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -125,6 +125,7 @@ def test_transports(): @pytest.mark.skipif('sys.version_info[0] == 2', reason='') class LocalTest(ClusterTest, unittest.TestCase): Cluster = partial(LocalCluster, silence_logs=False, diagnostics_port=None) + kwargs = {'diagnostics_port': None} @pytest.mark.skipif('sys.version_info[0] == 2', reason='') @@ -142,6 +143,18 @@ def test_Client_solo(loop): assert c.cluster.status == 'closed' +@gen_test() +def test_duplicate_clients(): + c1 = yield Client(processes=False, silence_logs=False, diagnostics_port=9876) + with pytest.warns(Exception) as info: + yield Client(processes=False, silence_logs=False, diagnostics_port=9876) + + assert any(all(word in str(msg.message).lower() + for word in ['9876', 'running', 'already in use']) + for msg in info.list) + yield c1.close() + + def test_Client_kwargs(loop): with Client(loop=loop, processes=False, n_workers=2, silence_logs=False) as c: assert len(c.cluster.workers) == 2 @@ -150,8 +163,8 @@ def test_Client_kwargs(loop): def test_Client_twice(loop): - with Client(loop=loop, silence_logs=False) as c: - with Client(loop=loop, silence_logs=False) as f: + with Client(loop=loop, silence_logs=False, diagnostics_port=None) as c: + with Client(loop=loop, silence_logs=False, diagnostics_port=None) as f: assert c.cluster.scheduler.port != f.cluster.scheduler.port @@ -367,7 +380,7 @@ def test_logging(): def test_ipywidgets(loop): ipywidgets = pytest.importorskip('ipywidgets') with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - diagnostics_port=0, processes=False) as cluster: + diagnostics_port=False, processes=False) as cluster: cluster._ipython_display_() box = cluster._cached_widget assert isinstance(box, ipywidgets.Widget) @@ -376,7 +389,7 @@ def test_ipywidgets(loop): def test_scale(loop): """ Directly calling scale both up and down works as expected """ with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - diagnostics_port=0, processes=False, n_workers=0) as cluster: + diagnostics_port=False, processes=False, n_workers=0) as cluster: assert not cluster.scheduler.workers cluster.scale(3) @@ -397,7 +410,7 @@ def test_scale(loop): def test_adapt(loop): with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - diagnostics_port=0, processes=False, n_workers=0) as cluster: + diagnostics_port=False, processes=False, n_workers=0) as cluster: cluster.adapt(minimum=0, maximum=2, interval='10ms') assert cluster._adaptive.minimum == 0 assert cluster._adaptive.maximum == 2 @@ -423,7 +436,7 @@ def 
test_adapt(loop): def test_adapt_then_manual(loop): """ We can revert from adaptive, back to manual """ with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - diagnostics_port=0, processes=False, n_workers=8) as cluster: + diagnostics_port=False, processes=False, n_workers=8) as cluster: sleep(0.1) cluster.adapt(minimum=0, maximum=4, interval='10ms') @@ -454,7 +467,7 @@ def test_local_tls(loop): from distributed.utils_test import tls_only_security security = tls_only_security() with LocalCluster(scheduler_port=8786, silence_logs=False, security=security, - diagnostics_port=0, ip='tls://0.0.0.0', loop=loop) as c: + diagnostics_port=False, ip='tls://0.0.0.0', loop=loop) as c: sync(loop, assert_can_connect_from_everywhere_4, c.scheduler.port, connection_args=security.get_connection_args('client'), protocol='tls', timeout=3) diff --git a/distributed/deploy/utils_test.py b/distributed/deploy/utils_test.py index 375612edc26..9bc8cacccad 100644 --- a/distributed/deploy/utils_test.py +++ b/distributed/deploy/utils_test.py @@ -3,9 +3,10 @@ class ClusterTest(object): Cluster = None + kwargs = {} def setUp(self): - self.cluster = self.Cluster(2, scheduler_port=0) + self.cluster = self.Cluster(2, scheduler_port=0, **self.kwargs) self.client = Client(self.cluster.scheduler_address) def tearDown(self): @@ -33,10 +34,10 @@ def test_start_worker(self): assert c == a def test_context_manager(self): - with self.Cluster() as c: + with self.Cluster(**self.kwargs) as c: with Client(c) as e: assert e.ncores() def test_no_workers(self): - with self.Cluster(0, scheduler_port=0): + with self.Cluster(0, scheduler_port=0, **self.kwargs): pass diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 034ea3eeadb..e5396486761 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -12,6 +12,7 @@ import pickle import random import six +import warnings import psutil import sortedcontainers @@ -1033,8 +1034,9 @@ def start_services(self, listen_ip): service.listen((listen_ip, port)) self.services[k] = service except Exception as e: - logger.info("Could not launch service: %r", (k, port), - exc_info=True) + warnings.warn("\nCould not launch service '%s' on port %d. " % (k, port) + + "Got the following message:\n\n" + str(e), + stacklevel=3) def stop_services(self): for service in self.services.values(): diff --git a/distributed/tests/py3_test_asyncio.py b/distributed/tests/py3_test_asyncio.py index a57e1bd7f12..cf60b945b80 100644 --- a/distributed/tests/py3_test_asyncio.py +++ b/distributed/tests/py3_test_asyncio.py @@ -49,7 +49,7 @@ async def test_coro_test(): @coro_test async def test_asyncio_start_close(): - async with AioClient(processes=False) as c: + async with AioClient(processes=False, diagnostics_port=False) as c: assert c.status == 'running' # AioClient has installed its AioLoop shim. 
assert isinstance(IOLoop.current(instance=False), BaseAsyncIOLoop) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index ce67817e6a6..66bd7d20058 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -401,7 +401,7 @@ def test_identity_inproc(): def test_ports(loop): - port = 9876 + port = 9877 server = Server({}, io_loop=loop) server.listen(port) try: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 93ce281bd4f..a23f821f929 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -11,6 +11,7 @@ import logging import logging.config import os +import psutil import re import shutil import signal @@ -39,11 +40,12 @@ from tornado.ioloop import IOLoop from .client import default_client, _global_clients -from .compatibility import PY3, Empty, WINDOWS +from .compatibility import PY3, Empty, WINDOWS, PY2 from .comm.utils import offload from .config import initialize_logging from .core import connect, rpc, CommClosedError from .metrics import time +from .process import _cleanup_dangling from .proctitle import enable_proctitle_on_children from .security import Security from .utils import (ignoring, log_errors, mp_context, get_ip, get_ipv6, @@ -129,6 +131,18 @@ def start(): else: is_stopped.wait() del _global_workers[:] + + start = time() + while set(_global_clients): + sleep(0.1) + assert time() < start + 5 + + _cleanup_dangling() + + if PY2: # no forkserver, so no extra procs + for child in psutil.Process().children(recursive=True): + child.terminate() + _global_clients.clear() @@ -845,6 +859,7 @@ def coro(): thread = threading._active[tid] call_stacks = profile.call_stack(sys._current_frames()[tid]) assert False, (thread, call_stacks) + _cleanup_dangling() return result return test_func From dbb529ba6c0acb03fb915eec4c1ff846bb2c8c72 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 19 Aug 2018 18:29:31 -0400 Subject: [PATCH 0060/1550] Add profile page for event loop thread (#2144) This adds a thread to each event loop that periodically polls the state of the event loop and maintains a time series of profile information. 
This data is served on the profile-server route of the bokeh servers --- .../setup_conda_environment.cmd | 2 +- distributed/actor.py | 2 +- distributed/bokeh/components.py | 134 ++++++++++++++++++ distributed/bokeh/scheduler.py | 16 ++- .../bokeh/tests/test_scheduler_bokeh.py | 13 +- distributed/bokeh/worker.py | 18 ++- distributed/cli/tests/test_dask_scheduler.py | 2 +- distributed/client.py | 5 +- distributed/comm/tests/test_comms.py | 7 +- distributed/core.py | 20 +++ distributed/deploy/local.py | 5 +- distributed/process.py | 5 +- distributed/profile.py | 95 ++++++++++++- distributed/scheduler.py | 10 +- distributed/tests/py3_test_pubsub.py | 4 +- distributed/tests/test_client.py | 6 +- distributed/tests/test_metrics.py | 2 +- distributed/tests/test_profile.py | 33 ++++- distributed/tests/test_queues.py | 8 +- distributed/tests/test_scheduler.py | 12 +- distributed/tests/test_utils.py | 6 +- distributed/utils.py | 5 +- distributed/utils_test.py | 5 +- 23 files changed, 369 insertions(+), 46 deletions(-) diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index f03441c336e..d8fc2445d27 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -37,7 +37,7 @@ call deactivate requests ^ toolz ^ tblib ^ - tornado=4.5 ^ + tornado=5 ^ zict ^ -c conda-forge diff --git a/distributed/actor.py b/distributed/actor.py index b97c79a6041..85f08dd9efd 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -119,7 +119,7 @@ def __getattr__(self, key): attr = getattr(self._cls, key) if self._future and not self._future.status == 'finished': - raise ValueError("Worker holding Actor was lost") + raise ValueError("Worker holding Actor was lost. Status: " + self._future.status) if callable(attr): @functools.wraps(attr) diff --git a/distributed/bokeh/components.py b/distributed/bokeh/components.py index fe82d52ce80..c2f37844e6e 100644 --- a/distributed/bokeh/components.py +++ b/distributed/bokeh/components.py @@ -13,6 +13,7 @@ from bokeh.plotting import figure import dask from tornado import gen +import toolz from ..diagnostics.progress_stream import nbytes_bar from .. 
import profile @@ -437,6 +438,7 @@ def cb(attr, old, new): self.ts_source = ColumnDataSource({'time': [], 'count': []}) self.ts_plot = figure(title='Activity over time', height=100, x_axis_type='datetime', active_drag='xbox_select', + y_range=[0, 1 / profile_interval], tools='xpan,xwheel_zoom,xbox_select,reset', **kwargs) self.ts_plot.line('time', 'count', source=self.ts_source) @@ -511,3 +513,135 @@ def cb(): self.doc().add_next_tick_callback(lambda: self.update(prof, metadata)) self.server.loop.add_callback(cb) + + +class ProfileServer(DashboardComponent): + """ Time plots of the current resource usage on the cluster + + This is two plots, one for CPU and Memory and another for Network I/O + """ + + def __init__(self, server, doc=None, **kwargs): + if doc is not None: + self.doc = weakref.ref(doc) + self.server = server + self.log = self.server.io_loop.profile + self.start = None + self.stop = None + self.ts = {'count': [], 'time': []} + self.state = profile.get_profile(self.log) + data = profile.plot_data(self.state, profile_interval) + self.states = data.pop('states') + self.source = ColumnDataSource(data=data) + + changing = [False] # avoid repeated changes from within callback + + def cb(attr, old, new): + if changing[0]: + return + with log_errors(): + try: + ind = new['1d']['indices'][0] + except IndexError: + return + data = profile.plot_data(self.states[ind], profile_interval) + del self.states[:] + self.states.extend(data.pop('states')) + changing[0] = True # don't recursively trigger callback + self.source.data.update(data) + self.source.selected = old + changing[0] = False + + self.source.on_change('selected', cb) + + self.profile_plot = figure(tools='tap', height=400, **kwargs) + r = self.profile_plot.quad('left', 'right', 'top', 'bottom', color='color', + line_color='black', source=self.source) + r.selection_glyph = None + r.nonselection_glyph = None + + hover = HoverTool( + point_policy="follow_mouse", + tooltips=""" +
+                    <span style="font-size: 14px; font-weight: bold;">Name:</span>&nbsp;
+                    <span style="font-size: 10px; font-family: Monaco, monospace;">@name</span>
+                </div>
+                <div>
+                    <span style="font-size: 14px; font-weight: bold;">Filename:</span>&nbsp;
+                    <span style="font-size: 10px; font-family: Monaco, monospace;">@filename</span>
+                </div>
+                <div>
+                    <span style="font-size: 14px; font-weight: bold;">Line number:</span>&nbsp;
+                    <span style="font-size: 10px; font-family: Monaco, monospace;">@line_number</span>
+                </div>
+                <div>
+                    <span style="font-size: 14px; font-weight: bold;">Line:</span>&nbsp;
+                    <span style="font-size: 10px; font-family: Monaco, monospace;">@line</span>
+                </div>
+                <div>
+                    <span style="font-size: 14px; font-weight: bold;">Time:</span>&nbsp;
+                    <span style="font-size: 10px; font-family: Monaco, monospace;">@time</span>
+                </div>
+                <div>
+                    <span style="font-size: 14px; font-weight: bold;">Percentage:</span>&nbsp;
+                    <span style="font-size: 10px; font-family: Monaco, monospace;">@percentage</span>
+                </div>
+ """ + ) + self.profile_plot.add_tools(hover) + + self.profile_plot.xaxis.visible = False + self.profile_plot.yaxis.visible = False + self.profile_plot.grid.visible = False + + self.ts_source = ColumnDataSource({'time': [], 'count': []}) + self.ts_plot = figure(title='Activity over time', height=100, + x_axis_type='datetime', active_drag='xbox_select', + y_range=[0, 1 / profile_interval], + tools='xpan,xwheel_zoom,xbox_select,reset', + **kwargs) + self.ts_plot.line('time', 'count', source=self.ts_source) + self.ts_plot.circle('time', 'count', source=self.ts_source, color=None, + selection_color='orange') + self.ts_plot.yaxis.visible = False + self.ts_plot.grid.visible = False + + def ts_change(attr, old, new): + with log_errors(): + selected = self.ts_source.selected['1d']['indices'] + if selected: + start = self.ts_source.data['time'][min(selected)] / 1000 + stop = self.ts_source.data['time'][max(selected)] / 1000 + self.start, self.stop = min(start, stop), max(start, stop) + else: + self.start = self.stop = None + self.trigger_update() + + self.ts_source.on_change('selected', ts_change) + + self.reset_button = Button(label="Reset", button_type="success") + self.reset_button.on_click(lambda: self.update(self.state)) + + self.update_button = Button(label="Update", button_type="success") + self.update_button.on_click(self.trigger_update) + + self.root = column(row(self.reset_button, self.update_button, + sizing_mode='scale_width'), + self.profile_plot, self.ts_plot, **kwargs) + + def update(self, state): + with log_errors(): + self.state = state + data = profile.plot_data(self.state, profile_interval) + self.states = data.pop('states') + self.source.data.update(data) + + def trigger_update(self): + self.state = profile.get_profile(self.log, start=self.start, stop=self.stop) + data = profile.plot_data(self.state, profile_interval) + self.states = data.pop('states') + self.source.data.update(data) + times = [t * 1000 for t, _ in self.log] + counts = list(toolz.pluck('count', toolz.pluck(1, self.log))) + self.ts_source.data.update({'time': times, 'count': counts}) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index 52e693c75d5..27c39674871 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -30,7 +30,7 @@ np = False from . 
import components -from .components import DashboardComponent, ProfileTimePlot +from .components import (DashboardComponent, ProfileTimePlot, ProfileServer) from .core import BokehServer from .worker import SystemMonitor, counters_doc from .utils import transpose @@ -1164,6 +1164,18 @@ def profile_doc(scheduler, extra, doc): prof.trigger_update() +def profile_server_doc(scheduler, extra, doc): + with log_errors(): + doc.title = "Dask: Profile of Event Loop" + prof = ProfileServer(scheduler, sizing_mode='scale_width', doc=doc) + doc.add_root(prof.root) + doc.template = template + # doc.template_variables['active_page'] = 'profile' + doc.template_variables.update(extra) + + prof.trigger_update() + + class BokehScheduler(BokehServer): def __init__(self, scheduler, io_loop=None, prefix='', **kwargs): self.scheduler = scheduler @@ -1184,6 +1196,7 @@ def __init__(self, scheduler, io_loop=None, prefix='', **kwargs): tasks = Application(FunctionHandler(partial(tasks_doc, scheduler, self.extra))) status = Application(FunctionHandler(partial(status_doc, scheduler, self.extra))) profile = Application(FunctionHandler(partial(profile_doc, scheduler, self.extra))) + profile_server = Application(FunctionHandler(partial(profile_server_doc, scheduler, self.extra))) graph = Application(FunctionHandler(partial(graph_doc, scheduler, self.extra))) self.apps = { @@ -1195,6 +1208,7 @@ def __init__(self, scheduler, io_loop=None, prefix='', **kwargs): '/tasks': tasks, '/status': status, '/profile': profile, + '/profile-server': profile_server, '/graph': graph, } diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index 235d85d0c5c..bcd7eae476b 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -22,7 +22,7 @@ MemoryUse, CurrentLoad, ProcessingHistogram, NBytesHistogram, WorkerTable, - GraphPlot) + GraphPlot, ProfileServer) from distributed.bokeh import scheduler @@ -389,3 +389,14 @@ def test_GraphPlot_order(c, s, a, b): gp.update() assert gp.node_source.data['state'][gp.layout.index[y.key]] == 'erred' + + +@gen_cluster(client=True, + config={'distributed.worker.profile.interval': '10ms', + 'distributed.worker.profile.cycle': '50ms'}) +def test_profile_server(c, s, a, b): + ptp = ProfileServer(s) + ptp.trigger_update() + yield gen.sleep(0.200) + ptp.trigger_update() + assert 2 < len(ptp.ts_source.data['time']) < 20 diff --git a/distributed/bokeh/worker.py b/distributed/bokeh/worker.py index 7e577979a30..687da327d79 100644 --- a/distributed/bokeh/worker.py +++ b/distributed/bokeh/worker.py @@ -15,7 +15,7 @@ from bokeh.palettes import RdBu from toolz import merge, partition_all -from .components import DashboardComponent, ProfileTimePlot +from .components import DashboardComponent, ProfileTimePlot, ProfileServer from .core import BokehServer from .utils import transpose from ..compatibility import WINDOWS @@ -616,6 +616,18 @@ def profile_doc(server, extra, doc): doc.template_variables.update(extra) +def profile_server_doc(server, extra, doc): + with log_errors(): + doc.title = "Dask: Profile of Event Loop" + prof = ProfileServer(server, sizing_mode='scale_width', doc=doc) + doc.add_root(prof.root) + doc.template = template + # doc.template_variables['active_page'] = '' + doc.template_variables.update(extra) + + prof.trigger_update() + + class BokehWorker(BokehServer): def __init__(self, worker, io_loop=None, prefix='', **kwargs): self.worker = worker @@ -636,12 +648,14 @@ def __init__(self, 
worker, io_loop=None, prefix='', **kwargs): systemmonitor = Application(FunctionHandler(partial(systemmonitor_doc, worker, extra))) counters = Application(FunctionHandler(partial(counters_doc, worker, extra))) profile = Application(FunctionHandler(partial(profile_doc, worker, extra))) + profile_server = Application(FunctionHandler(partial(profile_server_doc, worker, extra))) self.apps = {'/main': main, '/counters': counters, '/crossfilter': crossfilter, '/system': systemmonitor, - '/profile': profile} + '/profile': profile, + '/profile-server': profile_server} self.loop = io_loop or worker.loop self.server = None diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 456f44c55c4..ac6934f3dc9 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -50,7 +50,7 @@ def test_hostport(loop): def f(): yield [ # The scheduler's main port can't be contacted from the outside - assert_can_connect_locally_4(8978, 2.0), + assert_can_connect_locally_4(8978, 5.0), ] loop.run_sync(f) diff --git a/distributed/client.py b/distributed/client.py index f136b91bcd6..2ee1ffea537 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1478,7 +1478,10 @@ def wait(k): self._send_to_scheduler({'op': 'report-key', 'key': key}) for key in response['keys']: - self.futures[key].reset() + try: + self.futures[key].reset() + except KeyError: # TODO: verify that this is safe + pass else: break diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index fc7a316c575..9eca015ebee 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -733,8 +733,11 @@ def handle_comm(comm): comm = yield connect(contact_addr) comm.write("foo") - yield gen.sleep(0.01) - assert comm.closed() + + start = time() + while not comm.closed(): + yield gen.sleep(0.01) + assert time() < start + 2 comm.close() comm.close() diff --git a/distributed/core.py b/distributed/core.py index aa8b77984d7..e1543555ce5 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -21,6 +21,7 @@ normalize_address, unparse_host_port, get_address_host_port) from .metrics import time +from . 
import profile from .system_monitor import SystemMonitor from .utils import (get_traceback, truncate_exception, ignoring, shutting_down, PeriodicCallback, parse_timedelta, has_keyword) @@ -115,6 +116,25 @@ def __init__(self, handlers, stream_handlers=None, connection_limit=512, self.io_loop = io_loop or IOLoop.current() self.loop = self.io_loop + if not hasattr(self.io_loop, 'profile'): + ref = weakref.ref(self.io_loop) + + if hasattr(self.io_loop, 'closing'): + def stop(): + loop = ref() + return loop is None or loop.closing + else: + def stop(): + loop = ref() + return loop is None or loop._closing + + self.io_loop.profile = profile.watch( + omit=('profile.py', 'selectors.py'), + interval=dask.config.get('distributed.worker.profile.interval'), + cycle=dask.config.get('distributed.worker.profile.cycle'), + stop=stop, + ) + # Statistics counters for various events with ignoring(ImportError): from .counter import Digest diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index a277f6d53da..b4e837a2a47 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -287,7 +287,10 @@ def close(self, timeout=20): else: sleep(0.01) del self.workers[:] - self._loop_runner.run_sync(self._close, callback_timeout=timeout) + try: + self._loop_runner.run_sync(self._close, callback_timeout=timeout) + except RuntimeError: # IOLoop is closed + pass self._loop_runner.stop() finally: self.status = 'closed' diff --git a/distributed/process.py b/distributed/process.py index e3a1e2ecbb4..38e3af62c3b 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -5,7 +5,6 @@ import logging import os import re -import sys import threading import weakref @@ -34,10 +33,10 @@ def _loop_add_callback(loop, func, *args): def _call_and_set_future(loop, future, func, *args, **kwargs): try: res = func(*args, **kwargs) - except Exception: + except Exception as exc: # Tornado futures are not thread-safe, need to # set_result() / set_exc_info() from the loop's thread - _loop_add_callback(loop, future.set_exc_info, sys.exc_info()) + _loop_add_callback(loop, future.set_exception, exc) else: _loop_add_callback(loop, future.set_result, res) diff --git a/distributed/profile.py b/distributed/profile.py index 46a4e441ebb..71ff6d18205 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -24,12 +24,18 @@ 'children': {...}}} } """ - - -from collections import defaultdict +import bisect +from collections import defaultdict, deque import linecache +import sys +import threading +from time import sleep -from .utils import format_time, color_of +import toolz + +from .metrics import time +from .utils import format_time, color_of, parse_timedelta +from .compatibility import get_thread_identity def identifier(frame): @@ -64,7 +70,7 @@ def info_frame(frame): 'line': line} -def process(frame, child, state, stop=None): +def process(frame, child, state, stop=None, omit=None): """ Add counts from a frame stack onto existing state This recursively adds counts to the existing state dictionary and creates @@ -84,9 +90,14 @@ def process(frame, child, state, stop=None): 'description': 'root', 'children': {'...'}} """ + if omit is not None and any(frame.f_code.co_filename.endswith(o) for o in omit): + return False + prev = frame.f_back if prev is not None and (stop is None or not prev.f_code.co_filename.endswith(stop)): state = process(prev, frame, state, stop=stop) + if state is False: + return False ident = identifier(frame) @@ -214,3 +225,77 @@ def traverse(state, start, stop, height): 'name': 
names, 'time': times, 'percentage': percentages} + + +def _watch(thread_id, log, interval='20ms', cycle='2s', omit=None, + stop=lambda: False): + interval = parse_timedelta(interval) + cycle = parse_timedelta(cycle) + + recent = create() + last = time() + + while not stop(): + if time() > last + cycle: + log.append((time(), recent)) + recent = create() + last = time() + try: + frame = sys._current_frames()[thread_id] + except KeyError: + return + + process(frame, None, recent, omit=omit) + sleep(interval) + + +def watch(thread_id=None, interval='20ms', cycle='2s', maxlen=1000, omit=None, + stop=lambda: False): + if thread_id is None: + thread_id = get_thread_identity() + + log = deque(maxlen=maxlen) + + thread = threading.Thread(target=_watch, + name='Profile', + kwargs={'thread_id': thread_id, + 'interval': interval, + 'cycle': cycle, + 'log': log, + 'omit': omit, + 'stop': stop}) + thread.daemon = True + thread.start() + + return log + + +def get_profile(history, recent=None, start=None, stop=None, key=None): + now = time() + if start is None: + istart = 0 + else: + istart = bisect.bisect_left(history, (start,)) + + if stop is None: + istop = None + else: + istop = bisect.bisect_right(history, (stop,)) + 1 + if istop >= len(history): + istop = None # include end + + if istart == 0 and istop is None: + history = list(history) + else: + iistop = len(history) if istop is None else istop + history = [history[i] for i in range(istart, iistop)] + + prof = merge(*toolz.pluck(1, history)) + + if not history: + return create() + + if recent: + prof = merge(prof, recent) + + return prof diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e5396486761..f3a18612a25 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1376,7 +1376,7 @@ def update_graph(self, client=None, tasks=None, keys=None, stack.append(dep) for d in done: - del tasks[d] + tasks.pop(d, None) del dependencies[d] # Get or create task states @@ -2804,7 +2804,13 @@ def update_data(self, comm=None, who_has=None, nbytes=None, client=None, def report_on_key(self, key=None, ts=None, client=None): assert (key is None) + (ts is None) == 1, (key, ts) if ts is None: - ts = self.tasks[key] + try: + ts = self.tasks[key] + except KeyError: + self.report({'op': 'cancelled-key', + 'key': key}, + client=client) + return else: key = ts.key if ts.state == 'forgotten': diff --git a/distributed/tests/py3_test_pubsub.py b/distributed/tests/py3_test_pubsub.py index 33d9477e92e..ede8023801b 100644 --- a/distributed/tests/py3_test_pubsub.py +++ b/distributed/tests/py3_test_pubsub.py @@ -13,7 +13,7 @@ async def publish(): i = 0 while True: await gen.sleep(0.01) - pub.put(i) + pub._put(i) i += 1 def f(_): @@ -32,4 +32,4 @@ def f(_): # assert r == [x, x + 1, x + 2, x + 3, x + 4] assert len(r) == 5 - assert all(r[i] < r[i + 1] for i in range(0, 4)) + assert all(r[i] < r[i + 1] for i in range(0, 4)), r diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0e099ccb678..dcdfafc884d 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2049,7 +2049,6 @@ def test_waiting_data(c, s, a, b): @gen_cluster() def test_multi_client(s, a, b): c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) assert set(s.client_comms) == {c.id, f.id} @@ -2080,7 +2079,10 @@ def test_multi_client(s, a, b): yield f.close() - assert not s.tasks + start = time() + while s.tasks: + yield gen.sleep(0.01) + assert time() < start + 2, 
s.tasks def long_running_client_connection(address): diff --git a/distributed/tests/test_metrics.py b/distributed/tests/test_metrics.py index 290ac8d8a23..84b7c180993 100644 --- a/distributed/tests/test_metrics.py +++ b/distributed/tests/test_metrics.py @@ -35,7 +35,7 @@ def test_process_time(): t.start() t.join() dt = metrics.process_time() - start - assert dt >= 0.08 + assert dt >= 0.05 if PY3: # Sleep time not counted diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index 48aa5527823..2101d2a1669 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -1,11 +1,12 @@ import sys import time from toolz import first -from threading import Thread +import threading -from distributed.profile import (process, merge, create, call_stack, - identifier) from distributed.compatibility import get_thread_identity +from distributed import metrics +from distributed.profile import (process, merge, create, call_stack, + identifier, watch) def test_basic(): @@ -20,7 +21,7 @@ def test_f(): test_g() test_h() - thread = Thread(target=test_f) + thread = threading.Thread(target=test_f) thread.daemon = True thread.start() @@ -113,3 +114,27 @@ def test_identifier(): frame = sys._current_frames()[get_thread_identity()] assert identifier(frame) == identifier(frame) assert identifier(None) == identifier(None) + + +def test_watch(): + start = metrics.time() + + def stop(): + return metrics.time() > start + 0.500 + + start_threads = threading.active_count() + + log = watch(interval='10ms', cycle='50ms', stop=stop) + + start = metrics.time() # wait until thread starts up + while threading.active_count() <= start_threads: + assert metrics.time() < start + 2 + time.sleep(0.01) + + time.sleep(0.5) + assert 1 < len(log) < 10 + + start = metrics.time() + while threading.active_count() > start_threads: + assert metrics.time() < start + 2 + time.sleep(0.01) diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index a5a8d63d8fa..913434d6909 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -262,14 +262,14 @@ def test_timeout(c, s, a, b): start = time() with pytest.raises(gen.TimeoutError): - yield q.get(timeout=0.1) + yield q.get(timeout=0.3) stop = time() - assert 0.1 < stop - start < 2.0 + assert 0.2 < stop - start < 2.0 yield q.put(1) start = time() with pytest.raises(gen.TimeoutError): - yield q.put(2, timeout=0.1) + yield q.put(2, timeout=0.3) stop = time() - assert 0.05 < stop - start < 2.0 + assert 0.1 < stop - start < 2.0 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 5a80faf8f83..259ca100f19 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -732,26 +732,28 @@ def test_file_descriptors(c, s): assert time() < start + 3 +@slow @nodebug @gen_cluster(client=True) def test_learn_occupancy(c, s, a, b): - futures = c.map(slowinc, range(1000), delay=0.01) + futures = c.map(slowinc, range(1000), delay=0.2) while sum(len(ts.who_has) for ts in s.tasks.values()) < 10: yield gen.sleep(0.01) - assert 1 < s.total_occupancy < 40 + assert 100 < s.total_occupancy < 1000 for w in [a, b]: - assert 1 < s.workers[w.address].occupancy < 20 + assert 50 < s.workers[w.address].occupancy < 700 +@slow @nodebug @gen_cluster(client=True) def test_learn_occupancy_2(c, s, a, b): - future = c.map(slowinc, range(1000), delay=0.1) + future = c.map(slowinc, range(1000), delay=0.2) while not any(ts.who_has for ts in s.tasks.values()): 
yield gen.sleep(0.01) - assert 50 < s.total_occupancy < 200 + assert 100 < s.total_occupancy < 1000 @gen_cluster(client=True) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 60b5105a78d..b0f75e37d5c 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -514,16 +514,16 @@ def test_parse_timedelta(): def test_all_exceptions_logging(): @gen.coroutine def throws(): - raise Exception('foo') + raise Exception('foo1234') with captured_logger('') as sio: try: yield All([throws() for _ in range(5)], - quiet_exceptions=Exception) + quiet_exceptions=Exception) except Exception: pass import gc; gc.collect() yield gen.sleep(0.1) - assert not sio.getvalue() + assert 'foo1234' not in sio.getvalue() diff --git a/distributed/utils.py b/distributed/utils.py index 193005dbc83..1e9b9d69ae0 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -358,12 +358,11 @@ def run_loop(loop=self._loop): finally: done_evt.set() - thread = threading.Thread(target=run_loop, - name="IO loop") + thread = threading.Thread(target=run_loop, name="IO loop") thread.daemon = True thread.start() - loop_evt.wait(timeout=1000) + loop_evt.wait(timeout=10) self._started = True actual_thread = in_thread[0] diff --git a/distributed/utils_test.py b/distributed/utils_test.py index a23f821f929..d65b1abcb80 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -637,7 +637,10 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, else: client.close() - assert not ws + start = time() + while list(ws): + sleep(0.01) + assert time() < start + 1, 'Workers still around after one second' @gen.coroutine From 177dfb891089cc872bab34fb8369d75cae13bc0a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 19 Aug 2018 18:37:43 -0400 Subject: [PATCH 0061/1550] Use dispatch for dask serialization, also add sklearn, pytorch (#2175) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Use dispatch for dask serialization This allows people to register serialization functions using decorator syntax ```python @dask_serialize.register(np.ndarray) def serialize_array(x): ... @dask_deserialize.register(np.ndarray) def serialize_array(header, frames): ... ``` This also means that inheritance turns on by default (which is both good and bad) * add torch serialization ``` In [1]: import torchvision In [2]: from distributed.protocol import serialize, deserialize In [3]: import pickle In [4]: model = torchvision.models.resnet50() In [5]: %time header, frames = serialize(model) CPU times: user 46.9 ms, sys: 0 ns, total: 46.9 ms Wall time: 46.3 ms In [6]: %timeit serialize(model) 19.1 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 100 loops each) In [7]: %timeit deserialize(header, frames) # most of this seems to be torch.Tensor(numpy_array) 64.2 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) In [8]: %time b = pickle.dumps(model) CPU times: user 77.1 ms, sys: 68.2 ms, total: 145 ms Wall time: 142 ms In [9]: %timeit pickle.dumps(model) 108 ms ± 583 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) In [10]: %timeit pickle.loads(b) 111 ms ± 1.9 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each) ``` --- .travis.yml | 2 +- distributed/protocol/__init__.py | 37 +++- distributed/protocol/arrow.py | 36 ++-- distributed/protocol/h5py.py | 22 +-- distributed/protocol/keras.py | 12 +- distributed/protocol/netcdf4.py | 32 ++-- distributed/protocol/numpy.py | 7 +- distributed/protocol/serialize.py | 175 ++++++++++++++----- distributed/protocol/sparse.py | 11 +- distributed/protocol/tests/test_arrow.py | 8 - distributed/protocol/tests/test_serialize.py | 34 ++-- distributed/protocol/tests/test_sklearn.py | 19 ++ distributed/protocol/tests/test_torch.py | 33 ++++ distributed/protocol/torch.py | 56 ++++++ distributed/tests/test_client.py | 6 +- distributed/tests/test_nanny.py | 3 +- distributed/utils.py | 5 + distributed/utils_test.py | 2 +- docs/source/serialization.rst | 37 +++- 19 files changed, 364 insertions(+), 173 deletions(-) create mode 100644 distributed/protocol/tests/test_sklearn.py create mode 100644 distributed/protocol/tests/test_torch.py create mode 100644 distributed/protocol/torch.py diff --git a/.travis.yml b/.travis.yml index def0659e6a7..2dd96098c50 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ env: matrix: - PYTHON=2.7 TESTS=true PACKAGES="python-blosc futures faulthandler" - PYTHON=3.5.4 TESTS=true COVERAGE=true PACKAGES=python-blosc CRICK=true - - PYTHON=3.6 TESTS=true + - PYTHON=3.6 TESTS=true PACKAGES="scikit-learn" matrix: fast_finish: true diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 01ac7e8464a..bd8f7331c8e 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -6,39 +6,60 @@ from .core import (dumps, loads, maybe_compress, decompress, msgpack) from .serialize import ( serialize, deserialize, nested_deserialize, Serialize, Serialized, - to_serialize, register_serialization, register_serialization_lazy, + to_serialize, register_serialization, dask_serialize, dask_deserialize, serialize_bytes, deserialize_bytes, serialize_bytelist, - register_serialization_family, + register_serialization_family, register_generic, ) from ..utils import ignoring -@partial(register_serialization_lazy, "numpy") +@dask_serialize.register_lazy("numpy") +@dask_deserialize.register_lazy("numpy") def _register_numpy(): from . import numpy -@partial(register_serialization_lazy, "h5py") +@dask_serialize.register_lazy("h5py") +@dask_deserialize.register_lazy("h5py") def _register_h5py(): from . import h5py -@partial(register_serialization_lazy, "netCDF4") +@dask_serialize.register_lazy("netCDF4") +@dask_deserialize.register_lazy("netCDF4") def _register_netcdf4(): from . import netcdf4 -@partial(register_serialization_lazy, "keras") +@dask_serialize.register_lazy("keras") +@dask_deserialize.register_lazy("keras") def _register_keras(): from . import keras -@partial(register_serialization_lazy, "sparse") +@dask_serialize.register_lazy("sparse") +@dask_deserialize.register_lazy("sparse") def _register_sparse(): from . import sparse -@partial(register_serialization_lazy, "pyarrow") +@dask_serialize.register_lazy("pyarrow") +@dask_deserialize.register_lazy("pyarrow") def _register_arrow(): from . 
import arrow + + +@dask_serialize.register_lazy("sklearn") +@dask_deserialize.register_lazy("sklearn") +def _register_sklearn(): + import sklearn.base + register_generic(sklearn.base.BaseEstimator) + + +@dask_serialize.register_lazy("torch") +@dask_deserialize.register_lazy("torch") +@dask_serialize.register_lazy("torchvision") +@dask_deserialize.register_lazy("torchvision") +def _register_torch(): + from . import torch diff --git a/distributed/protocol/arrow.py b/distributed/protocol/arrow.py index 87c5d05c99f..c90ba190cfc 100644 --- a/distributed/protocol/arrow.py +++ b/distributed/protocol/arrow.py @@ -1,12 +1,14 @@ from __future__ import print_function, division, absolute_import -from .serialize import register_serialization +from .serialize import dask_serialize, dask_deserialize +import pyarrow + +@dask_serialize.register(pyarrow.RecordBatch) def serialize_batch(batch): - import pyarrow as pa - sink = pa.BufferOutputStream() - writer = pa.RecordBatchStreamWriter(sink, batch.schema) + sink = pyarrow.BufferOutputStream() + writer = pyarrow.RecordBatchStreamWriter(sink, batch.schema) writer.write_batch(batch) writer.close() buf = sink.get_result() @@ -15,17 +17,17 @@ def serialize_batch(batch): return header, frames +@dask_deserialize.register(pyarrow.RecordBatch) def deserialize_batch(header, frames): - import pyarrow as pa blob = frames[0] - reader = pa.RecordBatchStreamReader(pa.BufferReader(blob)) + reader = pyarrow.RecordBatchStreamReader(pyarrow.BufferReader(blob)) return reader.read_next_batch() +@dask_serialize.register(pyarrow.Table) def serialize_table(tbl): - import pyarrow as pa - sink = pa.BufferOutputStream() - writer = pa.RecordBatchStreamWriter(sink, tbl.schema) + sink = pyarrow.BufferOutputStream() + writer = pyarrow.RecordBatchStreamWriter(sink, tbl.schema) writer.write_table(tbl) writer.close() buf = sink.get_result() @@ -34,20 +36,8 @@ def serialize_table(tbl): return header, frames +@dask_deserialize.register(pyarrow.Table) def deserialize_table(header, frames): - import pyarrow as pa blob = frames[0] - reader = pa.RecordBatchStreamReader(pa.BufferReader(blob)) + reader = pyarrow.RecordBatchStreamReader(pyarrow.BufferReader(blob)) return reader.read_all() - - -register_serialization( - 'pyarrow.lib.RecordBatch', - serialize_batch, - deserialize_batch -) -register_serialization( - 'pyarrow.lib.Table', - serialize_table, - deserialize_table -) diff --git a/distributed/protocol/h5py.py b/distributed/protocol/h5py.py index 81a83cffcea..9936920a759 100644 --- a/distributed/protocol/h5py.py +++ b/distributed/protocol/h5py.py @@ -1,39 +1,31 @@ from __future__ import print_function, division, absolute_import -from .serialize import register_serialization +from .serialize import dask_serialize, dask_deserialize +import h5py + +@dask_serialize.register(h5py.File) def serialize_h5py_file(f): if f.mode != 'r': raise ValueError("Can only serialize read-only h5py files") return {'filename': f.filename}, [] +@dask_deserialize.register(h5py.File) def deserialize_h5py_file(header, frames): import h5py return h5py.File(header['filename'], mode='r') -register_serialization('h5py._hl.files.File', - serialize_h5py_file, - deserialize_h5py_file) - - +@dask_serialize.register((h5py.Group, h5py.Dataset)) def serialize_h5py_dataset(x): header, _ = serialize_h5py_file(x.file) header['name'] = x.name return header, [] +@dask_deserialize.register((h5py.Group, h5py.Dataset)) def deserialize_h5py_dataset(header, frames): file = deserialize_h5py_file(header, frames) return file[header['name']] - 
- -register_serialization('h5py._hl.dataset.Dataset', - serialize_h5py_dataset, - deserialize_h5py_dataset) - -register_serialization('h5py._hl.group.Group', - serialize_h5py_dataset, - deserialize_h5py_dataset) diff --git a/distributed/protocol/keras.py b/distributed/protocol/keras.py index 2217380fc80..a5437f60e18 100644 --- a/distributed/protocol/keras.py +++ b/distributed/protocol/keras.py @@ -1,8 +1,11 @@ from __future__ import print_function, division, absolute_import -from .serialize import register_serialization, serialize, deserialize +from .serialize import dask_serialize, dask_deserialize, serialize, deserialize +import keras + +@dask_serialize.register(keras.Model) def serialize_keras_model(model): import keras if keras.__version__ < '1.2.0': @@ -18,6 +21,7 @@ def serialize_keras_model(model): return header, frames +@dask_deserialize.register(keras.Model) def deserialize_keras_model(header, frames): from keras.models import model_from_config n = 0 @@ -29,9 +33,3 @@ def deserialize_keras_model(header, frames): model = model_from_config(header) model.set_weights(weights) return model - - -for module in ['keras', 'tensorflow.contrib.keras.python.keras']: - for name in ['engine.training.Model', 'models.Model', 'models.Sequential']: - register_serialization('.'.join([module, name]), serialize_keras_model, - deserialize_keras_model) diff --git a/distributed/protocol/netcdf4.py b/distributed/protocol/netcdf4.py index 2154358e866..06711ad03cb 100644 --- a/distributed/protocol/netcdf4.py +++ b/distributed/protocol/netcdf4.py @@ -1,47 +1,39 @@ from __future__ import print_function, division, absolute_import -from .serialize import register_serialization, serialize, deserialize +from .serialize import dask_serialize, dask_deserialize, serialize, deserialize -try: - import netCDF4 - HAS_NETCDF4 = True -except ImportError: - HAS_NETCDF4 = False +import netCDF4 +@dask_serialize.register(netCDF4.Dataset) def serialize_netcdf4_dataset(ds): # assume mode is read-only return {'filename': ds.filepath()}, [] +@dask_deserialize.register(netCDF4.Dataset) def deserialize_netcdf4_dataset(header, frames): - import netCDF4 return netCDF4.Dataset(header['filename'], mode='r') -if HAS_NETCDF4: - register_serialization(netCDF4.Dataset, serialize_netcdf4_dataset, - deserialize_netcdf4_dataset) - - +@dask_serialize.register(netCDF4.Variable) def serialize_netcdf4_variable(x): header, _ = serialize(x.group()) header['parent-type'] = header['type'] + header['parent-type-serialized'] = header['type-serialized'] header['name'] = x.name return header, [] +@dask_deserialize.register(netCDF4.Variable) def deserialize_netcdf4_variable(header, frames): header['type'] = header['parent-type'] + header['type-serialized'] = header['parent-type-serialized'] parent = deserialize(header, frames) return parent.variables[header['name']] -if HAS_NETCDF4: - register_serialization(netCDF4.Variable, serialize_netcdf4_variable, - deserialize_netcdf4_variable) - - +@dask_serialize.register(netCDF4.Group) def serialize_netcdf4_group(g): parent = g while parent.parent: @@ -51,11 +43,7 @@ def serialize_netcdf4_group(g): return header, [] +@dask_deserialize.register(netCDF4.Group) def deserialize_netcdf4_group(header, frames): file = deserialize_netcdf4_dataset(header, frames) return file[header['path']] - - -if HAS_NETCDF4: - register_serialization(netCDF4.Group, serialize_netcdf4_group, - deserialize_netcdf4_group) diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index 5998294fd51..d6fc52a4e4d 100644 --- 
a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -11,7 +11,7 @@ blosc = False from .utils import frame_split_size, merge_frames -from .serialize import register_serialization +from .serialize import dask_serialize, dask_deserialize from . import pickle from ..utils import log_errors @@ -28,6 +28,7 @@ def itemsize(dt): return result +@dask_serialize.register(np.ndarray) def serialize_numpy_ndarray(x): if x.dtype.hasobject: header = {'pickle': True} @@ -88,6 +89,7 @@ def serialize_numpy_ndarray(x): return header, frames +@dask_deserialize.register(np.ndarray) def deserialize_numpy_ndarray(header, frames): with log_errors(): if len(frames) > 1: @@ -106,6 +108,3 @@ def deserialize_numpy_ndarray(header, frames): strides=header['strides']) return x - - -register_serialization(np.ndarray, serialize_numpy_ndarray, deserialize_numpy_ndarray) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index cb3b802c504..3f0da622995 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -2,6 +2,7 @@ from functools import partial import traceback +import dask from dask.base import normalize_token try: from cytoolz import valmap, get_in @@ -17,41 +18,35 @@ from .utils import unpack_frames, pack_frames_prelude, frame_split_size -class_serializers = {} - lazy_registrations = {} +dask_serialize = dask.utils.Dispatch('dask_serialize') +dask_deserialize = dask.utils.Dispatch('dask_deserialize') + + def dask_dumps(x, context=None): """Serialise object using the class-based registry""" - typ = typename(type(x)) - if typ in class_serializers: - dumps, loads, has_context = class_serializers[typ] - if has_context: - header, frames = dumps(x, context=context) - else: - header, frames = dumps(x) - header['type'] = typ - header['serializer'] = 'dask' - return header, frames - elif _find_lazy_registration(typ): - return dask_dumps(x) # recurse + type_name = typename(type(x)) + try: + dumps = dask_serialize.dispatch(type(x)) + except TypeError: + raise NotImplementedError(type_name) + if has_keyword(dumps, 'context'): + header, frames = dumps(x, context=context) else: - raise NotImplementedError(typ) - + header, frames = dumps(x) -def dask_loads(header, frames): - typ = header['type'] + header['type'] = type_name + header['type-serialized'] = pickle.dumps(type(x)) + header['serializer'] = 'dask' + return header, frames - if typ not in class_serializers: - _find_lazy_registration(typ) - try: - dumps, loads, _ = class_serializers[typ] - except KeyError: - raise TypeError("Serialization for type %s not found" % typ) - else: - return loads(header, frames) +def dask_loads(header, frames): + typ = pickle.loads(header['type-serialized']) + loads = dask_deserialize.dispatch(typ) + return loads(header, frames) def pickle_dumps(x): @@ -406,20 +401,20 @@ def register_serialization(cls, serialize, deserialize): serialize deserialize """ - if isinstance(cls, type): - name = typename(cls) - elif isinstance(cls, str): - name = cls - class_serializers[name] = (serialize, - deserialize, - has_keyword(serialize, 'context')) + if isinstance(cls, str): + raise TypeError( + "Strings are no longer accepted for type registration. " + "Use dask_serialize.register_lazy instead" + ) + dask_serialize.register(cls)(serialize) + dask_deserialize.register(cls)(deserialize) def register_serialization_lazy(toplevel, func): """Register a registration function to be called if *toplevel* module is ever loaded. 
""" - lazy_registrations[toplevel] = func + raise Exception("Serialization registration has changed. See documentation") def typename(typ): @@ -434,33 +429,117 @@ def typename(typ): return typ.__module__ + '.' + typ.__name__ -def _find_lazy_registration(typename): - toplevel, _, _ = typename.partition('.') - if toplevel in lazy_registrations: - lazy_registrations.pop(toplevel)() - return True - else: - return False - - @partial(normalize_token.register, Serialized) def normalize_Serialized(o): return [o.header] + o.frames # for dask.base.tokenize # Teach serialize how to handle bytestrings +@dask_serialize.register((bytes, bytearray)) def _serialize_bytes(obj): header = {} # no special metadata frames = [obj] return header, frames +@dask_deserialize.register((bytes, bytearray)) def _deserialize_bytes(header, frames): return frames[0] -# NOTE: using the same exact serialization means a bytes object may be -# deserialized as bytearray or vice-versa... Not sure this is a problem -# in practice. -register_serialization(bytes, _serialize_bytes, _deserialize_bytes) -register_serialization(bytearray, _serialize_bytes, _deserialize_bytes) +######################### +# Descend into __dict__ # +######################### + + +def _is_msgpack_serializable(v): + typ = type(v) + return (typ is str or typ is int or typ is float or + isinstance(v, dict) and all(map(_is_msgpack_serializable, v.values())) + and all(typ is str for x in v.keys()) or + isinstance(v, (list, tuple)) and all(map(_is_msgpack_serializable, v))) + + +def serialize_object_with_dict(est): + header = { + 'serializer': 'dask', + 'type-serialized': pickle.dumps(type(est)), + 'simple': {}, + 'complex': {} + } + frames = [] + + if isinstance(est, dict): + d = est + else: + d = est.__dict__ + + for k, v in d.items(): + if _is_msgpack_serializable(v): + header['simple'][k] = v + else: + if isinstance(v, dict): + h, f = serialize_object_with_dict(v) + else: + h, f = serialize(v) + header['complex'][k] = {'header': h, + 'start': len(frames), + 'stop': len(frames) + len(f)} + frames += f + return header, frames + + +def deserialize_object_with_dict(header, frames): + cls = pickle.loads(header['type-serialized']) + if issubclass(cls, dict): + dd = obj = {} + else: + obj = object.__new__(cls) + dd = obj.__dict__ + dd.update(header['simple']) + for k, d in header['complex'].items(): + h = d['header'] + f = frames[d['start']: d['stop']] + v = deserialize(h, f) + dd[k] = v + + return obj + + +dask_deserialize.register(dict)(deserialize_object_with_dict) + + +def register_generic(cls): + """ Register dask_(de)serialize to traverse through __dict__ + + Normally when registering new classes for Dask's custom serialization you + need to manage headers and frames, which can be tedious. If all you want + to do is traverse through your object and apply serialize to all of your + object's attributes then this function may provide an easier path. + + This registers a class for the custom Dask serialization family. It + serializes it by traversing through its __dict__ of attributes and applying + ``serialize`` and ``deserialize`` recursively. It collects a set of frames + and keeps small attributes in the header. Deserialization reverses this + process. + + This is a good idea if the following hold: + + 1. Most of the bytes of your object are composed of data types that Dask's + custom serializtion already handles well, like Numpy arrays. + 2. 
Your object doesn't require any special constructor logic, other than + object.__new__(cls) + + Examples + -------- + >>> import sklearn.base + >>> from distributed.protocol import register_generic + >>> register_generic(sklearn.base.BaseEstimator) + + See Also + -------- + dask_serialize + dask_deserialize + """ + dask_serialize.register(cls)(serialize_object_with_dict) + dask_deserialize.register(cls)(deserialize_object_with_dict) diff --git a/distributed/protocol/sparse.py b/distributed/protocol/sparse.py index d8b7a42c2f8..ca0c6f38a79 100644 --- a/distributed/protocol/sparse.py +++ b/distributed/protocol/sparse.py @@ -1,8 +1,11 @@ from __future__ import print_function, division, absolute_import -from .serialize import register_serialization, serialize, deserialize +from .serialize import dask_serialize, dask_deserialize, serialize, deserialize +import sparse + +@dask_serialize.register(sparse.COO) def serialize_sparse(x): coords_header, coords_frames = serialize(x.coords) data_header, data_frames = serialize(x.data) @@ -14,8 +17,8 @@ def serialize_sparse(x): return header, coords_frames + data_frames +@dask_deserialize.register(sparse.COO) def deserialize_sparse(header, frames): - import sparse coords_frames = frames[:header['nframes'][0]] data_frames = frames[header['nframes'][0]:] @@ -26,7 +29,3 @@ def deserialize_sparse(header, frames): shape = header['shape'] return sparse.COO(coords, data, shape=shape) - - -register_serialization('sparse.core.COO', serialize_sparse, deserialize_sparse) # version 0.1 -register_serialization('sparse.coo.COO', serialize_sparse, deserialize_sparse) # version 0.2 diff --git a/distributed/protocol/tests/test_arrow.py b/distributed/protocol/tests/test_arrow.py index 6f014bae323..eca8de9f1a3 100644 --- a/distributed/protocol/tests/test_arrow.py +++ b/distributed/protocol/tests/test_arrow.py @@ -5,7 +5,6 @@ from distributed.utils_test import gen_cluster from distributed.protocol import deserialize, serialize -from distributed.protocol.serialize import class_serializers, typename df = pd.DataFrame({'A': list('abc'), 'B': [1,2,3]}) @@ -22,13 +21,6 @@ def test_roundtrip(obj): assert obj.equals(new_obj) -@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"]) -def test_typename(obj): - # The typename used to register the custom serialization is hardcoded - # ensure that the typename hasn't changed - assert typename(type(obj)) in class_serializers - - def echo(arg): return arg diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index dc7377385ea..4e9062cd044 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -311,28 +311,24 @@ def check(dask_worker): def test_context_specific_serialization_class(c, s, a, b): register_serialization(MyObject, my_dumps, my_loads) - try: - # Create the object on A, force communication to B - x = c.submit(MyObject, x=1, y=2, workers=a.address) - y = c.submit(lambda x: x, x, workers=b.address) + # Create the object on A, force communication to B + x = c.submit(MyObject, x=1, y=2, workers=a.address) + y = c.submit(lambda x: x, x, workers=b.address) - yield wait(y) + yield wait(y) - key = y.key + key = y.key - def check(dask_worker): - # Get the context from the object stored on B - my_obj = dask_worker.data[key] - return my_obj.context + def check(dask_worker): + # Get the context from the object stored on B + my_obj = dask_worker.data[key] + return my_obj.context - result = yield c.run(check, 
workers=[b.address]) - expected = {'sender': a.address, 'recipient': b.address} - assert result[b.address]['sender'] == a.address # see origin worker + result = yield c.run(check, workers=[b.address]) + expected = {'sender': a.address, 'recipient': b.address} + assert result[b.address]['sender'] == a.address # see origin worker - z = yield y # bring object to local process + z = yield y # bring object to local process - assert z.x == 1 and z.y == 2 - assert z.context['sender'] == b.address - finally: - from distributed.protocol.serialize import class_serializers, typename - del class_serializers[typename(MyObject)] + assert z.x == 1 and z.y == 2 + assert z.context['sender'] == b.address diff --git a/distributed/protocol/tests/test_sklearn.py b/distributed/protocol/tests/test_sklearn.py new file mode 100644 index 00000000000..4fa8aeb5369 --- /dev/null +++ b/distributed/protocol/tests/test_sklearn.py @@ -0,0 +1,19 @@ +import pytest +pytest.importorskip('sklearn') + +import sklearn.linear_model + +from distributed.protocol import serialize, deserialize + + +def test_basic(): + est = sklearn.linear_model.LinearRegression() + est.fit([[0, 0], [1, 1], [2, 2]], [0, 1, 2]) + + header, frames = serialize(est) + assert header['serializer'] == 'dask' + + est2 = deserialize(header, frames) + + inp = [[2, 3], [-1, 3]] + assert (est.predict(inp) == est2.predict(inp)).all() diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py new file mode 100644 index 00000000000..d5d93b16f03 --- /dev/null +++ b/distributed/protocol/tests/test_torch.py @@ -0,0 +1,33 @@ +from distributed.protocol import serialize, deserialize +import pytest + +np = pytest.importorskip('numpy') +torch = pytest.importorskip('torch') + + +def test_tensor(): + x = np.arange(10) + t = torch.Tensor(x) + header, frames = serialize(t) + assert header['serializer'] == 'dask' + t2 = deserialize(header, frames) + assert (x == t2.numpy()).all() + + +def test_grad(): + x = np.arange(10) + t = torch.Tensor(x) + t.grad = torch.zeros_like(t) + 1 + + t2 = deserialize(*serialize(t)) + assert (t2.numpy() == x).all() + assert (t2.grad.numpy() == 1).all() + + +def test_resnet(): + torchvision = pytest.importorskip('torchvision') + model = torchvision.models.resnet.resnet18() + + header, frames = serialize(model) + model2 = deserialize(header, frames) + assert str(model) == str(model2) diff --git a/distributed/protocol/torch.py b/distributed/protocol/torch.py new file mode 100644 index 00000000000..c25b1549004 --- /dev/null +++ b/distributed/protocol/torch.py @@ -0,0 +1,56 @@ +from .serialize import (serialize, dask_serialize, dask_deserialize, + register_generic) + +import torch +import numpy as np + + +@dask_serialize.register(torch.Tensor) +def serialize_torch_Tensor(t): + header, frames = serialize(t.numpy()) + if t.grad is not None: + grad_header, grad_frames = serialize(t.grad.numpy()) + header['grad'] = {'header': grad_header, 'start': len(frames)} + frames += grad_frames + header['requires_grad'] = t.requires_grad + header['device'] = t.device.type + return header, frames + + +@dask_deserialize.register(torch.Tensor) +def deserialize_torch_Tensor(header, frames): + if header.get('grad', False): + i = header['grad']['start'] + frames, grad_frames = frames[:i], frames[i:] + grad = dask_deserialize.dispatch(np.ndarray)(header['grad']['header'], grad_frames) + else: + grad = None + + x = dask_deserialize.dispatch(np.ndarray)(header, frames) + if header['device'] == 'cpu': + t = torch.from_numpy(x) + if 
header['requires_grad']: + t = t.requires_grad_(True) + else: + t = torch.tensor(data=x, + device=header['device'], + requires_grad=header['requires_grad']) + if grad is not None: + t.grad = torch.from_numpy(grad) + return t + + +@dask_serialize.register(torch.nn.Parameter) +def serialize_torch_Parameters(p): + header, frames = serialize(p.detach()) + header['requires_grad'] = p.requires_grad + return header, frames + + +@dask_deserialize.register(torch.nn.Parameter) +def deserialize_torch_Parameters(header, frames): + t = dask_deserialize.dispatch(torch.Tensor)(header, frames) + return torch.nn.Parameter(data=t, requires_grad=header['requires_grad']) + + +register_generic(torch.nn.Module) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index dcdfafc884d..70e93b4534f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3860,13 +3860,13 @@ def test_scatter_compute_lose(c, s, a, b): yield a._close() + with pytest.raises(CancelledError): + yield wait(z) + assert x.status == 'cancelled' assert y.status == 'finished' assert z.status == 'cancelled' - with pytest.raises(CancelledError): - yield wait(z) - @gen_cluster(client=True) def test_scatter_compute_store_lose(c, s, a, b): diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 9a9feabf114..4e192408f7b 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -265,7 +265,8 @@ def test_nanny_timeout(c, s, a): @gen_cluster(ncores=[('127.0.0.1', 1)], client=True, Worker=Nanny, - worker_kwargs={'memory_limit': 1e8}, timeout=20) + worker_kwargs={'memory_limit': 1e8}, timeout=20, + check_new_threads=False) def test_nanny_terminate(c, s, a): from time import sleep diff --git a/distributed/utils.py b/distributed/utils.py index 1e9b9d69ae0..3bebc65b24f 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -548,6 +548,7 @@ def key_split(s): try: from functools import lru_cache except ImportError: + lru_cache = False pass else: key_split = lru_cache(100000)(key_split) @@ -1405,6 +1406,10 @@ def has_keyword(func, keyword): return keyword in inspect.getargspec(func).args +if lru_cache: + has_keyword = lru_cache(1000)(has_keyword) + + # from bokeh.palettes import viridis # palette = viridis(18) palette = ['#440154', '#471669', '#472A79', '#433C84', '#3C4D8A', '#355D8C', diff --git a/distributed/utils_test.py b/distributed/utils_test.py index d65b1abcb80..900abab06ed 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -856,7 +856,7 @@ def coro(): break else: sleep(0.01) - if time() > start + 2: + if time() > start + 5: from distributed import profile tid = bad[0] thread = threading._active[tid] diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index cdb765c380a..5cfd1515684 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -1,6 +1,8 @@ Serialization ============= +.. currentmodule:: distributed.protocol.serialize + When we communicate data between computers we first convert that data into a sequence of bytes that can be communicated across the network. Choices made in serialization can affect performance and security. @@ -138,6 +140,10 @@ pickle internally in some cases. It should not be considered more secure. Extend ++++++ +.. autosummary:: + dask_serialize + dask_deserialize + As with serialization families in general, the Dask family in particular is *also* extensible. 
This is a good way
to support custom serialization of a single type of object.  The method is
similar, you create serialize and
@@ -150,26 +156,43 @@ register them with Dask.
         def __init__(self, name):
             self.name = name
 
+    from distributed.protocol import dask_serialize, dask_deserialize
+
+    @dask_serialize.register(Human)
     def serialize(human: Human) -> Tuple[Dict, List[bytes]]:
         header = {}
         frames = [human.name.encode()]
         return header, frames
 
+    @dask_deserialize.register(Human)
     def deserialize(header: Dict, frames: List[bytes]) -> Human:
         return Human(frames[0].decode())
 
-    from distributed.protocol.serialize import register_serialization
-    register_serialization(Human, serialize, deserialize)
+
+Traverse attributes
++++++++++++++++++++
+
+.. autosummary::
+   register_generic
+
+A common case is that your object just wraps Numpy arrays or other objects that
+Dask already serializes well. For example, Scikit-Learn estimators mostly
+surround Numpy arrays with a bit of extra metadata. In these cases you can
+register your class for custom Dask serialization with the
+``register_generic``
+function.
 
 API
 ---
 
-.. currentmodule:: distributed.protocol.serialize
-
-.. autosummary:: register_serialization
-                 serialize
+.. autosummary:: serialize
                  deserialize
+                 dask_serialize
+                 dask_deserialize
+                 register_generic
 
-.. autofunction:: register_serialization
 .. autofunction:: serialize
 .. autofunction:: deserialize
+.. autofunction:: dask_serialize
+.. autofunction:: dask_deserialize
+.. autofunction:: register_generic

From cdae7ca0130563a0a448e23203d763d12dbf10a5 Mon Sep 17 00:00:00 2001
From: Matthew Rocklin
Date: Mon, 20 Aug 2018 07:55:47 -0400
Subject: [PATCH 0062/1550] Handle corner cases with busy signal (#2182)

There are a couple cases where in-flight dependencies can come in in
atypical ways:

1. A request can be made and in-flight, for some reason the dependency
   state changes (it gets cancelled and then re-requested), then when the
   request comes in we try to transition to a memory state, but may not
   have an established route to do so
2. A request can be made and in-flight, for some reason the dependency
   state changes, then the request comes in with a busy signal and we try
   to transition down to waiting, but may not have an established route
   to do so

This currently doesn't have any tests, and transition_dep_waiting_memory
is entirely uncovered.
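To make the intent of the two guards concrete before the diff, here is a minimal,
hypothetical sketch. The ``MiniWorker`` class and its ``handle_busy`` /
``handle_arrival`` methods are invented for illustration; only the ``dep_state``
mapping and the ``waiting`` / ``flight`` / ``memory`` state names mirror the
patch below. The idea: a busy response only demotes dependencies that are still
in flight, and a value that arrives for a dependency whose state has already
moved on is simply stored if it is not already held.

```python
# Hypothetical, simplified model of the dependency-state guards -- not the
# real distributed Worker class.

class MiniWorker:
    def __init__(self):
        self.dep_state = {}   # dependency key -> 'waiting' | 'flight' | 'memory'
        self.data = {}        # dependency key -> value held in memory

    def handle_busy(self, deps):
        # Corner case 2: the remote worker answered "busy".  Only demote
        # dependencies that are still in flight; anything that was cancelled
        # or has otherwise changed state is left alone.
        for dep in deps:
            if self.dep_state.get(dep) == 'flight':
                self.dep_state[dep] = 'waiting'

    def handle_arrival(self, dep, value):
        # Corner case 1: a value arrives after the dependency's state already
        # changed.  If the data is not yet held, store it rather than assuming
        # a flight -> memory transition route exists.
        if dep not in self.data:
            self.data[dep] = value
        self.dep_state[dep] = 'memory'


w = MiniWorker()
w.dep_state = {'x': 'flight', 'y': 'memory'}
w.handle_busy(['x', 'y'])
assert w.dep_state == {'x': 'waiting', 'y': 'memory'}  # 'y' is untouched
```
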
--- distributed/worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index f614c8e0b34..a554307e42b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1542,6 +1542,8 @@ def transition_dep_waiting_memory(self, dep, value=None): import pdb pdb.set_trace() raise + if value is not no_value and dep not in self.data: + self.put_key_in_memory(dep, value, transition=False) def transition(self, key, finish, **kwargs): start = self.task_state[key] @@ -1890,7 +1892,8 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): if response['status'] == 'busy': self.log.append(('busy-gather', worker, deps)) for dep in deps: - self.transition_dep(dep, 'waiting') + if self.dep_state[dep] == 'flight': + self.transition_dep(dep, 'waiting') return if cause: From 12b8d22bd029cdcebbd25b7caaac61e2a43ee7d9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 20 Aug 2018 15:09:58 -0400 Subject: [PATCH 0063/1550] Check self.dependencies when looking at tasks in memory (#2196) Previously during a check to see if tasks and their dependencies were already in memory we assumed that the dependencies were known to the client when it sent in the computation. However if the inputs were futures then this was not the case, and a KeyError was raised because only the scheduler knew where things were. Now we properly check both the input dependencies and self.dependencies and are robust to the information being in either place. Fixes #2187 --- distributed/scheduler.py | 8 ++++++-- distributed/tests/test_scheduler.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f3a18612a25..afb51ff4312 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1369,7 +1369,11 @@ def update_graph(self, client=None, tasks=None, keys=None, while stack: # remove unnecessary dependencies key = stack.pop() ts = self.tasks[key] - for dep in dependencies[key]: + try: + deps = dependencies[key] + except KeyError: + deps = self.dependencies[key] + for dep in deps: if all(d in done for d in dependents[dep]): if dep in self.tasks: done.add(dep) @@ -1377,7 +1381,7 @@ def update_graph(self, client=None, tasks=None, keys=None, for d in done: tasks.pop(d, None) - del dependencies[d] + dependencies.pop(d, None) # Get or create task states stack = list(keys) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 259ca100f19..816283537e5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1362,3 +1362,25 @@ def test_resources_reset_after_cancelled_task(c, s, w): assert w.available_resources == {'A': 1} yield c.submit(inc, 1, resources={'A': 1}) + + +@gen_cluster(client=True) +def test_gh2187(c, s, a, b): + def foo(): + return 'foo' + + def bar(x): + return x + 'bar' + + def baz(x): + sleep(0.1) + return x + 'baz' + + x = c.submit(foo, key='x') + y = c.submit(bar, x, key='y') + yield y + z = c.submit(baz, y, key='z') + del y + yield gen.sleep(0.1) + f = c.submit(bar, x, key='y') + yield f From 2ec428ae5652f7d068baeb18223fb8a04ab8804e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Tue, 21 Aug 2018 13:52:13 +0200 Subject: [PATCH 0064/1550] Add ability to log additional custom metrics from each worker (#2169) * Add ability to log additional custom metrics from each worker * Refactor ws.info to ws.metrics This removes old state from the info 
attribute of a WorkerState and renames it to metrics. We also remove the legacy worker_info dictionary from code and tests --- distributed/bokeh/scheduler.py | 45 ++++-- distributed/bokeh/templates/worker-table.html | 9 +- .../bokeh/tests/test_scheduler_bokeh.py | 142 ++++++++++++++++++ distributed/bokeh/tests/test_worker_bokeh.py | 6 +- distributed/deploy/adaptive.py | 4 +- distributed/deploy/local.py | 2 +- distributed/scheduler.py | 58 ++++--- distributed/tests/test_client.py | 4 +- distributed/tests/test_nanny.py | 8 +- distributed/tests/test_scheduler.py | 4 +- distributed/tests/test_worker.py | 11 +- distributed/worker.py | 54 ++++--- docs/source/web.rst | 24 +++ 13 files changed, 295 insertions(+), 76 deletions(-) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index 27c39674871..a282e012574 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -328,13 +328,13 @@ def update(self): else: processing_color.append('blue') - nbytes = [ws.info['memory'] for ws in workers] + nbytes = [ws.metrics['memory'] for ws in workers] nbytes_text = [format_bytes(nb) for nb in nbytes] nbytes_color = [] max_limit = 0 for ws, nb in zip(workers, nbytes): try: - limit = self.scheduler.worker_info[ws.address]['memory_limit'] + limit = self.scheduler.workers[ws.address].memory_limit except KeyError: limit = 16e9 if limit > max_limit: @@ -927,12 +927,17 @@ class WorkerTable(DashboardComponent): This is two plots, a text-based table for each host and a thin horizontal plot laying out hosts by their current memory use. """ + excluded_names = {'executing', 'in_flight', 'in_memory', 'ready', 'time'} def __init__(self, scheduler, width=800, **kwargs): self.scheduler = scheduler self.names = ['worker', 'ncores', 'cpu', 'memory', 'memory_limit', 'memory_percent', 'num_fds', 'read_bytes', 'write_bytes', 'cpu_fraction'] + workers = self.scheduler.workers.values() + self.extra_names = sorted({m for ws in workers + for m in ws.metrics + if m not in self.names} - self.excluded_names) table_names = ['worker', 'ncores', 'cpu', 'memory', 'memory_limit', 'memory_percent', 'num_fds', 'read_bytes', @@ -967,6 +972,17 @@ def __init__(self, scheduler, width=800, **kwargs): if name in formatters: table.columns[table_names.index(name)].formatter = formatters[name] + extra_names = ['worker'] + self.extra_names + extra_columns = {name: TableColumn(field=name, + title=name.replace('_percent', '%')) + for name in extra_names} + + extra_table = DataTable( + source=self.source, + columns=[extra_columns[n] for n in extra_names], + reorderable=True, sortable=True, width=width, **dt_kwargs + ) + hover = HoverTool( point_policy="follow_mouse", tooltips=""" @@ -1015,20 +1031,25 @@ def __init__(self, scheduler, width=800, **kwargs): else: sizing_mode = {} - self.root = column(cpu_plot, mem_plot, table, id='bk-worker-table', **sizing_mode) + components = [cpu_plot, mem_plot, table] + if self.extra_names: + components.append(extra_table) + + self.root = column(*components, id='bk-worker-table', **sizing_mode) def update(self): - data = {name: [] for name in self.names} - for worker, info in sorted(self.scheduler.worker_info.items()): - for name in self.names: - data[name].append(info.get(name, None)) - data['worker'][-1] = worker - if info['memory_limit']: - data['memory_percent'][-1] = info['memory'] / info['memory_limit'] + data = {name: [] for name in self.names + self.extra_names} + for addr, ws in sorted(self.scheduler.workers.items()): + for name in self.names + self.extra_names: 
+ data[name].append(ws.metrics.get(name, None)) + data['worker'][-1] = ws.address + if ws.memory_limit: + data['memory_percent'][-1] = ws.metrics['memory'] / ws.memory_limit else: data['memory_percent'][-1] = '' - data['cpu'][-1] = info['cpu'] / 100.0 - data['cpu_fraction'][-1] = info['cpu'] / 100.0 / info['ncores'] + data['memory_limit'][-1] = ws.memory_limit + data['cpu'][-1] = ws.metrics['cpu'] / 100.0 + data['cpu_fraction'][-1] = ws.metrics['cpu'] / 100.0 / ws.ncores self.source.data.update(data) diff --git a/distributed/bokeh/templates/worker-table.html b/distributed/bokeh/templates/worker-table.html index 1d1768cec11..90b59c08c54 100644 --- a/distributed/bokeh/templates/worker-table.html +++ b/distributed/bokeh/templates/worker-table.html @@ -11,17 +11,16 @@ Logs {% for ws in worker_list %} - {% set wi = worker_info[ws.address] %} {{ws.address}} {{ ws.ncores }} - {{ format_bytes(wi['memory_limit']) }} - + {{ format_bytes(ws.memory_limit) }} + {{ format_time(ws.occupancy) }} {{ len(ws.processing) }} {{ len(ws.has_what) }} - {% if 'bokeh' in wi['services'] %} - bokeh + {% if 'bokeh' in ws.services %} + bokeh {% else %} {% end %} diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index bcd7eae476b..262a5e8a67f 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -286,6 +286,148 @@ def test_WorkerTable(c, s, a, b): assert all(len(v) == 2 for v in wt.source.data.values()) +@gen_cluster(client=True) +def test_WorkerTable_custom_metrics(c, s, a, b): + def metric_port(worker): + return worker.port + + def metric_address(worker): + return worker.address + + metrics = {'metric_port': metric_port, + 'metric_address': metric_address} + + for w in [a, b]: + for name, func in metrics.items(): + w.metrics[name] = func + + while not all('metric_port' in s.workers[w.address].metrics for w in [a, b]): + yield gen.sleep(0.01) + + for w in [a, b]: + assert s.workers[w.address].metrics['metric_port'] == w.port + assert s.workers[w.address].metrics['metric_address'] == w.address + + wt = WorkerTable(s) + wt.update() + data = wt.source.data + + for name in metrics: + assert name in data + + assert all(data.values()) + assert all(len(v) == 2 for v in data.values()) + my_index = data['worker'].index(a.address), data['worker'].index(b.address) + assert [data['metric_port'][i] for i in my_index] == [a.port, b.port] + assert [data['metric_address'][i] for i in my_index] == [a.address, b.address] + + +@gen_cluster(client=True) +def test_WorkerTable_different_metrics(c, s, a, b): + def metric_port(worker): + return worker.port + + a.metrics['metric_a'] = metric_port + b.metrics['metric_b'] = metric_port + + while not ('metric_a' in s.workers[a.address].metrics and + 'metric_b' in s.workers[b.address].metrics): + yield gen.sleep(0.01) + + assert s.workers[a.address].metrics['metric_a'] == a.port + assert s.workers[b.address].metrics['metric_b'] == b.port + + wt = WorkerTable(s) + wt.update() + data = wt.source.data + + assert 'metric_a' in data + assert 'metric_b' in data + assert all(data.values()) + assert all(len(v) == 2 for v in data.values()) + my_index = data['worker'].index(a.address), data['worker'].index(b.address) + assert [data['metric_a'][i] for i in my_index] == [a.port, None] + assert [data['metric_b'][i] for i in my_index] == [None, b.port] + + +@gen_cluster(client=True) +def test_WorkerTable_metrics_with_different_metric_2(c, s, a, b): + def metric_port(worker): + return 
worker.port + + a.metrics['metric_a'] = metric_port + + while 'metric_a' not in s.workers[a.address].metrics: + yield gen.sleep(0.01) + + wt = WorkerTable(s) + wt.update() + data = wt.source.data + + assert 'metric_a' in data + assert all(data.values()) + assert all(len(v) == 2 for v in data.values()) + my_index = data['worker'].index(a.address), data['worker'].index(b.address) + assert [data['metric_a'][i] for i in my_index] == [a.port, None] + + +@gen_cluster(client=True, worker_kwargs={'metrics': {'my_port': lambda w: w.port}}) +def test_WorkerTable_add_and_remove_metrics(c, s, a, b): + def metric_port(worker): + return worker.port + + a.metrics['metric_a'] = metric_port + a.metrics['metric_b'] = metric_port + + while not ('metric_a' in s.workers[a.address].metrics and + 'metric_b' in s.workers[b.address].metrics): + yield gen.sleep(0.01) + + assert s.workers[a.address].metrics['metric_a'] == a.port + assert s.workers[b.address].metrics['metric_b'] == b.port + + wt = WorkerTable(s) + wt.update() + assert 'metric_a' in wt.source.data + assert 'metric_b' in wt.source.data + + # Remove 'metric_b' from worker b + del b.metrics['metric_b'] + + while 'metric_b' in s.workers[b.address].metrics: + yield gen.sleep(0.01) + + wt = WorkerTable(s) + wt.update() + assert 'metric_a' in wt.source.data + + del a.metrics['metric_a'] + + while 'metric_a' in s.workers[a.address].metrics: + yield gen.sleep(0.01) + + wt = WorkerTable(s) + wt.update() + assert 'metric_a' not in wt.source.data + + +@gen_cluster(client=True) +def test_WorkerTable_custom_metric_overlap_with_core_metric(c, s, a, b): + def metric(worker): + return -999 + + a.metrics['executing'] = metric + a.metrics['cpu'] = metric + a.metrics['metric'] = metric + + while 'metric' not in s.workers[a.address].metrics: + yield gen.sleep(0.01) + + assert s.workers[a.address].metrics['executing'] != -999 + assert s.workers[a.address].metrics['cpu'] != -999 + assert s.workers[a.address].metrics['metric'] == -999 + + @gen_cluster(client=True) def test_GraphPlot(c, s, a, b): gp = GraphPlot(s) diff --git a/distributed/bokeh/tests/test_worker_bokeh.py b/distributed/bokeh/tests/test_worker_bokeh.py index 01242b1c0b3..e991e53e369 100644 --- a/distributed/bokeh/tests/test_worker_bokeh.py +++ b/distributed/bokeh/tests/test_worker_bokeh.py @@ -23,8 +23,8 @@ @gen_cluster(client=True, worker_kwargs={'services': {('bokeh', 0): BokehWorker}}) def test_simple(c, s, a, b): - assert s.worker_info[a.address]['services'] == {'bokeh': a.services['bokeh'].port} - assert s.worker_info[b.address]['services'] == {'bokeh': b.services['bokeh'].port} + assert s.workers[a.address].services == {'bokeh': a.services['bokeh'].port} + assert s.workers[b.address].services == {'bokeh': b.services['bokeh'].port} future = c.submit(sleep, 1) yield gen.sleep(0.1) @@ -39,7 +39,7 @@ def test_simple(c, s, a, b): @gen_cluster(client=True, worker_kwargs={'services': {('bokeh', 0): (BokehWorker, {})}}) def test_services_kwargs(c, s, a, b): - assert s.worker_info[a.address]['services'] == {'bokeh': a.services['bokeh'].port} + assert s.workers[a.address].services == {'bokeh': a.services['bokeh'].port} assert isinstance(a.services['bokeh'], BokehWorker) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 014373ac7a8..62d308c6e22 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -144,8 +144,8 @@ def needs_memory(self): Returns ``True`` if the required bytes in distributed memory is some factor larger than the actual distributed memory 
available. """ - limit_bytes = {w: self.scheduler.worker_info[w]['memory_limit'] - for w in self.scheduler.worker_info} + limit_bytes = {addr: ws.memory_limit + for addr, ws in self.scheduler.workers.items()} worker_bytes = [ws.nbytes for ws in self.scheduler.workers.values()] limit = sum(limit_bytes.values()) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index b4e837a2a47..d3feedb578f 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -209,7 +209,7 @@ def _start_worker(self, death_timeout=60, **kwargs): self.workers.append(w) - while w.status != 'closed' and w.worker_address not in self.scheduler.worker_info: + while w.status != 'closed' and w.worker_address not in self.scheduler.workers: yield gen.sleep(0.01) if w.status == 'closed' and self.scheduler.status == 'running': diff --git a/distributed/scheduler.py b/distributed/scheduler.py index afb51ff4312..9f2d961b2ff 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -196,15 +196,19 @@ class WorkerState(object): 'address', 'has_what', 'info', + 'local_directory', 'memory_limit', + 'metrics', 'name', 'nbytes', 'ncores', 'occupancy', + 'pid', 'processing', 'resources', 'time_delay', 'used_resources', + 'services', 'status', 'last_seen', 'actors', @@ -218,17 +222,14 @@ def __init__(self, worker, ncores, memory_limit, name=None): self.nbytes = 0 self.ncores = ncores self.occupancy = 0 + self.pid = 0 self.processing = dict() self.resources = {} self.used_resources = {} self.last_seen = 0 + self.services = {} self.actors = set() - - self.info = {'name': name, - 'memory_limit': memory_limit, - 'host': self.host, - 'resources': self.resources, - 'ncores': self.ncores} # for backwards compatibility + self.metrics = {} @property def host(self): @@ -241,6 +242,21 @@ def __repr__(self): def __str__(self): return self.address + def identity(self): + return { + 'type': 'Worker', + 'id': self.name, + 'host': self.host, + 'resources': self.resources, + 'local_directory': self.local_directory, + 'name': self.name, + 'ncores': self.ncores, + 'memory_limit': self.memory_limit, + 'last_seen': self.last_seen, + 'services': self.services, + 'metrics': self.metrics + } + class TaskState(object): """ @@ -856,7 +872,7 @@ def __init__( ('worker_resources', 'resources', None), ('used_resources', 'used_resources', None), ('occupancy', 'occupancy', None), - ('worker_info', 'info', None), + ('worker_info', 'metrics', None), ('processing', 'processing', _legacy_task_key_dict), ('has_what', 'has_what', _legacy_task_key_set)]: func = operator.attrgetter(new_attr) @@ -996,7 +1012,8 @@ def identity(self, comm=None): 'id': str(self.id), 'address': self.address, 'services': {key: v.port for (key, v) in self.services.items()}, - 'workers': dict(self.worker_info)} + 'workers': {worker.address: worker.identity() + for worker in self.workers.values()}} return d def get_worker_service_addr(self, worker, service_name): @@ -1005,11 +1022,11 @@ def get_worker_service_addr(self, worker, service_name): Returns None if the service doesn't exist. 
""" ws = self.workers[worker] - port = ws.info['services'].get(service_name) + port = ws.services.get(service_name) if port is None: return None else: - return ws.info['host'], port + return ws.host, port def start_services(self, listen_ip): for k, v in self.service_specs.items(): @@ -1190,14 +1207,14 @@ def _setup_logging(self): @gen.coroutine def heartbeat_worker(self, comm=None, address=None, resolve_address=True, - now=None, resources=None, host_info=None, **info): + now=None, resources=None, host_info=None, metrics=None): address = self.coerce_address(address, resolve_address) address = normalize_address(address) host = get_address_host(address) local_now = time() now = now or time() - info = info or {} + metrics = metrics or {} host_info = host_info or {} self.host_info[host]['last-seen'] = local_now @@ -1208,8 +1225,8 @@ def heartbeat_worker(self, comm=None, address=None, resolve_address=True, ws.last_seen = time() - if info: - ws.info.update(info) + if metrics: + ws.metrics = metrics if host_info: self.host_info[host].update(host_info) @@ -1220,7 +1237,7 @@ def heartbeat_worker(self, comm=None, address=None, resolve_address=True, if resources: self.add_resources(worker=address, resources=resources) - self.log_event(address, merge({'action': 'heartbeat'}, info)) + self.log_event(address, merge({'action': 'heartbeat'}, metrics)) return {'status': 'OK', 'time': time(), @@ -1229,7 +1246,8 @@ def heartbeat_worker(self, comm=None, address=None, resolve_address=True, @gen.coroutine def add_worker(self, comm=None, address=None, keys=(), ncores=None, name=None, resolve_address=True, nbytes=None, now=None, - resources=None, host_info=None, memory_limit=None, **info): + resources=None, host_info=None, memory_limit=None, + metrics=None, pid=0, services=None, local_directory=None): """ Add a new worker to the cluster """ with log_errors(): address = self.coerce_address(address, resolve_address) @@ -1260,11 +1278,15 @@ def add_worker(self, comm=None, address=None, keys=(), ncores=None, self.total_ncores += ncores self.aliases[name] = address ws.name = name + ws.pid = pid + ws.services = services + ws.local_directory = local_directory response = self.heartbeat_worker(address=address, resolve_address=resolve_address, now=now, resources=resources, - host_info=host_info, **info) + host_info=host_info, + metrics=metrics) # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot exist before this. 
self.check_idle_saturated(ws) @@ -2732,7 +2754,7 @@ def retire_workers(self, comm=None, workers=None, remove=True, else: raise gen.Return([]) - worker_keys = {ws.address: ws.info for ws in workers} + worker_keys = {ws.address: ws.identity() for ws in workers} if close_workers and worker_keys: yield [self.close_worker(worker=w, safe=True) for w in worker_keys] diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 70e93b4534f..af64e27dca5 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -897,7 +897,7 @@ def test_remove_worker(c, s, a, b): yield b._close() - assert b.address not in s.worker_info + assert b.address not in s.workers result = yield c.gather(L) assert result == list(map(inc, range(20))) @@ -4076,7 +4076,7 @@ def test_retire_workers_2(c, s, a, b): assert s.who_has == {x.key: {b.address}} assert s.has_what == {b.address: {x.key}} - assert a.address not in s.worker_info + assert a.address not in s.workers @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 4e192408f7b..03f29f9b884 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -30,22 +30,22 @@ def test_nanny(s): with rpc(n.address) as nn: assert n.is_alive() assert s.ncores[n.worker_address] == 2 - assert s.worker_info[n.worker_address]['services']['nanny'] > 1024 + assert s.workers[n.worker_address].services['nanny'] > 1024 yield nn.kill() assert not n.is_alive() assert n.worker_address not in s.ncores - assert n.worker_address not in s.worker_info + assert n.worker_address not in s.workers yield nn.kill() assert not n.is_alive() assert n.worker_address not in s.ncores - assert n.worker_address not in s.worker_info + assert n.worker_address not in s.workers yield nn.instantiate() assert n.is_alive() assert s.ncores[n.worker_address] == 2 - assert s.worker_info[n.worker_address]['services']['nanny'] > 1024 + assert s.workers[n.worker_address].services['nanny'] > 1024 yield nn.terminate() assert not n.is_alive() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 816283537e5..542e4d91cd2 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -442,7 +442,7 @@ def test_worker_name(): s.start(0) w = Worker(s.ip, s.port, name='alice') yield w._start() - assert s.worker_info[w.address]['name'] == 'alice' + assert s.workers[w.address].name == 'alice' assert s.aliases['alice'] == w.address with pytest.raises(ValueError): @@ -575,7 +575,7 @@ def test_scheduler_sees_memory_limits(s): w = Worker(s.ip, s.port, ncores=3, memory_limit=12345) yield w._start(0) - assert s.worker_info[w.address]['memory_limit'] == 12345 + assert s.workers[w.address].memory_limit == 12345 yield w._close() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 4ebfd9869c9..d8a3d5f2371 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -740,8 +740,7 @@ def test_worker_dir(worker): with tmpfile() as fn: @gen_cluster(client=True, worker_kwargs={'local_dir': fn}) def test_worker_dir(c, s, a, b): - directories = [info['local_directory'] - for info in s.worker_info.values()] + directories = [w.local_directory for w in s.workers.values()] assert all(d.startswith(fn) for d in directories) assert len(set(directories)) == 2 # distinct @@ -817,7 +816,7 @@ def __sizeof__(self): @gen_cluster() def test_pid(s, a, b): - assert 
s.worker_info[a.address]['pid'] == os.getpid() + assert s.workers[a.address].pid == os.getpid() @gen_cluster(client=True) @@ -1198,3 +1197,9 @@ def test_avoid_oversubscription(c, s, *workers): # Some other workers did some work assert len([w for w in workers if len(w.outgoing_transfer_log) > 0]) >= 3 + + +@gen_cluster(client=True, worker_kwargs={'metrics': {'my_port': lambda w: w.port}}) +def test_custom_metrics(c, s, a, b): + assert s.workers[a.address].metrics['my_port'] == a.port + assert s.workers[b.address].metrics['my_port'] == b.port diff --git a/distributed/worker.py b/distributed/worker.py index a554307e42b..eae46e8ef43 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -17,9 +17,9 @@ from dask.core import istask from dask.compatibility import apply try: - from cytoolz import pluck, partial + from cytoolz import pluck, partial, merge except ImportError: - from toolz import pluck, partial + from toolz import pluck, partial, merge from tornado.gen import Return from tornado import gen from tornado.ioloop import IOLoop @@ -89,7 +89,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, executor=None, resources=None, silence_logs=None, death_timeout=None, preload=(), preload_argv=[], security=None, contact_address=None, memory_monitor_interval='200ms', - extensions=None, **kwargs): + extensions=None, metrics=None, **kwargs): self._setup_logging() @@ -179,6 +179,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.services = {} self.service_ports = service_ports or {} self.service_specs = services or {} + self.metrics = metrics or {} handlers = { 'gather': self.gather, @@ -261,16 +262,12 @@ def heartbeat(self): start = time() response = yield self.scheduler.heartbeat_worker( address=self.contact_address, - name=self.name, now=time(), - memory_limit=self.memory_limit, - executing=len(self.executing), - in_memory=len(self.data), - ready=len(self.ready), - in_flight=len(self.in_flight_tasks), - **self.monitor.recent()) + metrics=self.get_metrics() + ) end = time() middle = (start + end) / 2 + if response['status'] == 'missing': yield self._register_with_scheduler() return @@ -283,6 +280,15 @@ def heartbeat(self): else: logger.debug("Heartbeat skipped: channel busy") + def get_metrics(self): + core = dict(executing=len(self.executing), + in_memory=len(self.data), + ready=len(self.ready), + in_flight=len(self.in_flight_tasks)) + custom = {k: metric(self) for k, metric in self.metrics.items()} + + return merge(custom, self.monitor.recent(), core) + @gen.coroutine def _register_with_scheduler(self): self.periodic_callbacks['heartbeat'].stop() @@ -301,20 +307,20 @@ def _register_with_scheduler(self): comm = yield connect(self.scheduler.address, connection_args=self.connection_args) yield comm.write(dict(op='register-worker', - ncores=self.ncores, - address=self.contact_address, - keys=list(self.data), - name=self.name, - nbytes=self.nbytes, - now=time(), - services=self.service_ports, - memory_limit=self.memory_limit, - local_directory=self.local_dir, - resources=self.total_resources, - pid=os.getpid(), - reply=False, - **self.monitor.recent()), - serializers=['msgpack']) + reply=False, + address=self.contact_address, + keys=list(self.data), + ncores=self.ncores, + name=self.name, + nbytes=self.nbytes, + now=time(), + resources=self.total_resources, + memory_limit=self.memory_limit, + local_directory=self.local_dir, + services=self.service_ports, + pid=os.getpid(), + metrics=self.get_metrics()), + serializers=['msgpack']) future = 
comm.read(deserializers=['msgpack']) if self.death_timeout: diff = self.death_timeout - (time() - start) diff --git a/docs/source/web.rst b/docs/source/web.rst index 97dca02cb29..7d81fe7d935 100644 --- a/docs/source/web.rst +++ b/docs/source/web.rst @@ -161,6 +161,30 @@ available in the ``workers/`` page. .. image:: https://raw.githubusercontent.com/dask/dask-org/master/images/bokeh-resources.gif :alt: Resources view of Dask web interface +Per-worker resources +~~~~~~~~~~~~~~~~~~~~ + +The ``workers/`` page shows per-worker resources, the main ones being CPU and +memory use. Custom metrics can be registered and displayed in this page. Here +is an example showing how to display GPU utilization and GPU memory use: + +.. code-block:: python + + import subprocess + + def nvidia_data(name): + def dask_function(dask_worker): + cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format(name) + result = subprocess.check_output(cmd.split()) + return result.strip().decode() + return dask_function + + def register_metrics(dask_worker): + for name in ['utilization.gpu', 'utilization.memory']: + dask_worker.metrics[name] = nvidia_data(name) + + client.run(register_metrics) + Connecting to Web Interface --------------------------- From 1b3bf8c47e26b0d4756ea233afd1640068906e9f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 22 Aug 2018 09:37:52 -0400 Subject: [PATCH 0065/1550] Clean up metrics tests (#2205) We were getting intermittent testing failures with the WorkerTable custom metrics tests. This makes two changes that will hopefully resolve the situation: 1. It explicitly calls heartbeat to accelerate the tests 2. We make a copy of the metrics dict in order to avoid sharing the same metrics across multiple workers in tests --- .../bokeh/tests/test_scheduler_bokeh.py | 31 +++++-------------- distributed/worker.py | 2 +- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index 262a5e8a67f..1c07b437f21 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -301,8 +301,7 @@ def metric_address(worker): for name, func in metrics.items(): w.metrics[name] = func - while not all('metric_port' in s.workers[w.address].metrics for w in [a, b]): - yield gen.sleep(0.01) + yield [a.heartbeat(), b.heartbeat()] for w in [a, b]: assert s.workers[w.address].metrics['metric_port'] == w.port @@ -329,10 +328,7 @@ def metric_port(worker): a.metrics['metric_a'] = metric_port b.metrics['metric_b'] = metric_port - - while not ('metric_a' in s.workers[a.address].metrics and - 'metric_b' in s.workers[b.address].metrics): - yield gen.sleep(0.01) + yield [a.heartbeat(), b.heartbeat()] assert s.workers[a.address].metrics['metric_a'] == a.port assert s.workers[b.address].metrics['metric_b'] == b.port @@ -356,9 +352,7 @@ def metric_port(worker): return worker.port a.metrics['metric_a'] = metric_port - - while 'metric_a' not in s.workers[a.address].metrics: - yield gen.sleep(0.01) + yield [a.heartbeat(), b.heartbeat()] wt = WorkerTable(s) wt.update() @@ -377,11 +371,8 @@ def metric_port(worker): return worker.port a.metrics['metric_a'] = metric_port - a.metrics['metric_b'] = metric_port - - while not ('metric_a' in s.workers[a.address].metrics and - 'metric_b' in s.workers[b.address].metrics): - yield gen.sleep(0.01) + b.metrics['metric_b'] = metric_port + yield [a.heartbeat(), b.heartbeat()] assert s.workers[a.address].metrics['metric_a'] == a.port 
assert s.workers[b.address].metrics['metric_b'] == b.port @@ -393,18 +384,14 @@ def metric_port(worker): # Remove 'metric_b' from worker b del b.metrics['metric_b'] - - while 'metric_b' in s.workers[b.address].metrics: - yield gen.sleep(0.01) + yield [a.heartbeat(), b.heartbeat()] wt = WorkerTable(s) wt.update() assert 'metric_a' in wt.source.data del a.metrics['metric_a'] - - while 'metric_a' in s.workers[a.address].metrics: - yield gen.sleep(0.01) + yield [a.heartbeat(), b.heartbeat()] wt = WorkerTable(s) wt.update() @@ -419,9 +406,7 @@ def metric(worker): a.metrics['executing'] = metric a.metrics['cpu'] = metric a.metrics['metric'] = metric - - while 'metric' not in s.workers[a.address].metrics: - yield gen.sleep(0.01) + yield [a.heartbeat(), b.heartbeat()] assert s.workers[a.address].metrics['executing'] != -999 assert s.workers[a.address].metrics['cpu'] != -999 diff --git a/distributed/worker.py b/distributed/worker.py index eae46e8ef43..bd9f6040e5f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -179,7 +179,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.services = {} self.service_ports = service_ports or {} self.service_specs = services or {} - self.metrics = metrics or {} + self.metrics = dict(metrics) if metrics else {} handlers = { 'gather': self.gather, From 5b7658c1fe05b9dad405e2c56a9587727b9513c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 22 Aug 2018 15:38:22 +0200 Subject: [PATCH 0066/1550] Fix formatting when port is a tuple (#2204) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9f2d961b2ff..86c75d1872d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1051,7 +1051,7 @@ def start_services(self, listen_ip): service.listen((listen_ip, port)) self.services[k] = service except Exception as e: - warnings.warn("\nCould not launch service '%s' on port %d. " % (k, port) + + warnings.warn("\nCould not launch service '%s' on port %s. " % (k, port) + "Got the following message:\n\n" + str(e), stacklevel=3) From 58de8b6d21d285591a5a2a63eb218df843a70d07 Mon Sep 17 00:00:00 2001 From: Mike DePalatis Date: Sat, 25 Aug 2018 16:01:05 -0400 Subject: [PATCH 0067/1550] Describe what ZeroMQ is (#2211) --- docs/source/related-work.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/related-work.rst b/docs/source/related-work.rst index 413f8cd86b9..e2b13e458d2 100644 --- a/docs/source/related-work.rst +++ b/docs/source/related-work.rst @@ -61,7 +61,7 @@ Direct Communication * MPI4Py_: Wraps the Message Passing Interface popular in high performance computing. -* PyZMQ_: Wraps ZeroMQ, the gentleman's socket. +* PyZMQ_: Wraps ZeroMQ, the high-performance asynchronous messaging library. Venerable ~~~~~~~~~ From f701a538eb8c24b5e6234db32b90de1de4193a31 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Tue, 28 Aug 2018 09:25:48 -0400 Subject: [PATCH 0068/1550] Tiny typo fix (#2214) --- docs/source/work-stealing.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/work-stealing.rst b/docs/source/work-stealing.rst index acb2d699423..cf5a4bc48c1 100644 --- a/docs/source/work-stealing.rst +++ b/docs/source/work-stealing.rst @@ -118,7 +118,7 @@ sends a request to the busy worker. The worker inspects its current state of the task and sends a response to the scheduler: 1. 
If the task is not yet running, then the worker cancels the task and - informs the scheduler that it can reroute the ask elsewhere. + informs the scheduler that it can reroute the task elsewhere. 2. If the task is already running or complete then the worker tells the scheduler that it should not replicate the task elsewhere. From de340007509b019b3a94b67d2379985f47bf4a69 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 28 Aug 2018 10:22:35 -0400 Subject: [PATCH 0069/1550] Add Python 3.7 to travis.yml (#2203) * Add Python 3.7 to travis.yml * bump pytest to 3.7 * skip test that uses yield/StopIteration This is no longer valid for Python 3.7 The solution of using yield then return works just fine, but raises a SyntaxError in Python 2 * add compression libraries --- .travis.yml | 7 ++++--- continuous_integration/travis/install.sh | 3 +-- distributed/tests/test_client.py | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2dd96098c50..5963855c59b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,9 +6,10 @@ dist: trusty env: matrix: - - PYTHON=2.7 TESTS=true PACKAGES="python-blosc futures faulthandler" - - PYTHON=3.5.4 TESTS=true COVERAGE=true PACKAGES=python-blosc CRICK=true - - PYTHON=3.6 TESTS=true PACKAGES="scikit-learn" + - PYTHON=2.7 TESTS=true PACKAGES="python-blosc futures faulthandler lz4" + - PYTHON=3.5.4 TESTS=true COVERAGE=true PACKAGES="python-blosc lz4" CRICK=true + - PYTHON=3.6 TESTS=true PACKAGES="scikit-learn lz4" + - PYTHON=3.7 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" matrix: fast_finish: true diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 72f13e20b07..4954f582f1c 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -39,12 +39,11 @@ conda install -q \ ipywidgets \ joblib \ jupyter_client \ - lz4 \ mock \ netcdf4 \ paramiko \ psutil \ - pytest=3.1 \ + pytest=3.7 \ pytest-timeout \ python=$PYTHON \ requests \ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index af64e27dca5..6ed9e0faf8c 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2347,6 +2347,8 @@ def test_map_queue(c, s, a, b): assert result == (1 + 1) * 2 +@pytest.mark.skipif(sys.version_info >= (3, 7), + reason="replace StopIteration with return") @gen_cluster(client=True) def test_map_iterator_with_return(c, s, a, b): def g(): From a1aeff71096a08906ce073c710f80a3a9656c8c1 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 28 Aug 2018 13:27:31 -0400 Subject: [PATCH 0070/1550] Downgrade exception to warning when reusing port (#2199) * Downgrade exception to warning when reusing port * change warning message --- distributed/bokeh/core.py | 7 +++++-- distributed/deploy/tests/test_local.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/distributed/bokeh/core.py b/distributed/bokeh/core.py index 351901a3386..f7409c043a0 100644 --- a/distributed/bokeh/core.py +++ b/distributed/bokeh/core.py @@ -2,6 +2,7 @@ from distutils.version import LooseVersion import os +import warnings import bokeh from bokeh.server.server import Server @@ -44,11 +45,13 @@ def listen(self, addr): if ("already in use" in str(exc) or # Unix/Mac "Only one usage of" in str(exc)): # Windows msg = ("Port %d is already in use. " - "Perhaps you already have a cluster running?" + "\nPerhaps you already have a cluster running?" 
+ "\nHosting the diagnostics dashboard on a random port instead." % port) else: msg = "Failed to start diagnostics server on port %d. " % port + str(exc) - raise type(exc)(msg) + warnings.warn('\n' + msg) + port = 0 if i == 4: raise diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 25536222ce0..b2c226c565c 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -147,7 +147,10 @@ def test_Client_solo(loop): def test_duplicate_clients(): c1 = yield Client(processes=False, silence_logs=False, diagnostics_port=9876) with pytest.warns(Exception) as info: - yield Client(processes=False, silence_logs=False, diagnostics_port=9876) + c2 = yield Client(processes=False, silence_logs=False, diagnostics_port=9876) + + assert 'bokeh' in c1.cluster.scheduler.services + assert 'bokeh' in c2.cluster.scheduler.services assert any(all(word in str(msg.message).lower() for word in ['9876', 'running', 'already in use']) From 15743f6505e12785f5f151789f04fa11f9ab4589 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 28 Aug 2018 15:48:52 -0400 Subject: [PATCH 0071/1550] Add plot= keyword to get_task_stream (#2198) * Add plot= keyword to get_task_stream ```python In [1]: from dask.distributed import Client, get_task_stream In [2]: import dask.array as da In [3]: client = Client() In [4]: with get_task_stream(plot='save', filename='foo.html') as ts: ...: da.random.random((10000, 10000), chunks=(1000, 1000)).std().compute() In [5]: !ls foo.html foo.html In [6]: type(ts.figure) Out[6]: bokeh.plotting.figure.Figure ``` * add default filename --- distributed/bokeh/components.py | 85 +++++++------ distributed/client.py | 65 +++++++++- distributed/diagnostics/task_stream.py | 114 ++++++++++-------- .../diagnostics/tests/test_task_stream.py | 30 +++++ 4 files changed, 199 insertions(+), 95 deletions(-) diff --git a/distributed/bokeh/components.py b/distributed/bokeh/components.py index c2f37844e6e..f3d1ca4c3d4 100644 --- a/distributed/bokeh/components.py +++ b/distributed/bokeh/components.py @@ -59,12 +59,48 @@ def __init__(self, n_rectangles=1000, clear_interval='20s', **kwargs): """ kwargs are applied to the bokeh.models.plots.Plot constructor """ - clear_interval = parse_timedelta(clear_interval, default='ms') self.n_rectangles = n_rectangles + clear_interval = parse_timedelta(clear_interval, default='ms') self.clear_interval = clear_interval self.last = 0 - self.source = ColumnDataSource(data=dict( + self.source, self.root = task_stream_figure(clear_interval, **kwargs) + + # Required for update callback + self.task_stream_index = [0] + + def update(self, messages): + with log_errors(): + index = messages['task-events']['index'] + rectangles = messages['task-events']['rectangles'] + + if not index or index[-1] == self.task_stream_index[0]: + return + + ind = bisect(index, self.task_stream_index[0]) + rectangles = {k: [v[i] for i in range(ind, len(index))] + for k, v in rectangles.items()} + self.task_stream_index[0] = index[-1] + + # If there has been a significant delay then clear old rectangles + if rectangles['start']: + m = min(map(add, rectangles['start'], rectangles['duration'])) + if m > self.last: + self.last, last = m, self.last + if m > last + self.clear_interval: + self.source.data.update(rectangles) + return + + self.source.stream(rectangles, self.n_rectangles) + + +def task_stream_figure(clear_interval='20s', **kwargs): + """ + kwargs are applied to the bokeh.models.plots.Plot constructor + """ + 
clear_interval = parse_timedelta(clear_interval, default='ms') + + source = ColumnDataSource(data=dict( start=[time() - clear_interval], duration=[0.1], key=['start'], name=['start'], color=['white'], duration_text=['100 ms'], worker=['foo'], y=[0], worker_thread=[1], alpha=[0.0]) @@ -73,21 +109,21 @@ def __init__(self, n_rectangles=1000, clear_interval='20s', **kwargs): x_range = DataRange1d(range_padding=0) y_range = DataRange1d(range_padding=0) - self.root = figure( + root = figure( title="Task Stream", id='bk-task-stream-plot', x_range=x_range, y_range=y_range, toolbar_location="above", x_axis_type='datetime', min_border_right=35, tools='', **kwargs) - self.root.yaxis.axis_label = 'Worker Core' + root.yaxis.axis_label = 'Worker Core' - rect = self.root.rect(source=self.source, x="start", y="y", + rect = root.rect(source=source, x="start", y="y", width="duration", height=0.4, fill_color="color", line_color="color", line_alpha=0.6, fill_alpha="alpha", line_width=3) rect.nonselection_glyph = None - self.root.yaxis.major_label_text_alpha = 0 - self.root.yaxis.minor_tick_line_alpha = 0 - self.root.xgrid.visible = False + root.yaxis.major_label_text_alpha = 0 + root.yaxis.minor_tick_line_alpha = 0 + root.xgrid.visible = False hover = HoverTool( point_policy="follow_mouse", @@ -101,7 +137,7 @@ def __init__(self, n_rectangles=1000, clear_interval='20s', **kwargs): tap = TapTool(callback=OpenURL(url='/profile?key=@name')) - self.root.add_tools( + root.add_tools( hover, tap, BoxZoomTool(), ResetTool(), @@ -110,35 +146,10 @@ def __init__(self, n_rectangles=1000, clear_interval='20s', **kwargs): ) if ExportTool: export = ExportTool() - export.register_plot(self.root) - self.root.add_tools(export) - - # Required for update callback - self.task_stream_index = [0] - - def update(self, messages): - with log_errors(): - index = messages['task-events']['index'] - rectangles = messages['task-events']['rectangles'] - - if not index or index[-1] == self.task_stream_index[0]: - return + export.register_plot(root) + root.add_tools(export) - ind = bisect(index, self.task_stream_index[0]) - rectangles = {k: [v[i] for i in range(ind, len(index))] - for k, v in rectangles.items()} - self.task_stream_index[0] = index[-1] - - # If there has been a significant delay then clear old rectangles - if rectangles['start']: - m = min(map(add, rectangles['start'], rectangles['duration'])) - if m > self.last: - self.last, last = m, self.last - if m > last + self.clear_interval: - self.source.data.update(rectangles) - return - - self.source.stream(rectangles, self.n_rectangles) + return source, root class MemoryUsage(DashboardComponent): diff --git a/distributed/client.py b/distributed/client.py index 2ee1ffea537..7e97df7dbdf 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3337,7 +3337,8 @@ def get_restrictions(cls, collections, workers, allow_other_workers): def collections_to_dsk(collections, *args, **kwargs): return collections_to_dsk(collections, *args, **kwargs) - def get_task_stream(self, start=None, stop=None, count=None): + def get_task_stream(self, start=None, stop=None, count=None, plot=False, + filename='task-stream.html'): """ Get task stream data from scheduler This collects the data present in the diagnostic "Task Stream" plot on @@ -3360,6 +3361,11 @@ def get_task_stream(self, start=None, stop=None, count=None): count: int The number of desired records, ignored if both start and stop are specified + plot: boolean, str + If true then also return a Bokeh figure + If plot == 'save' then save 
the figure to a file + filename: str (optional) + The filename to save to if you set ``plot='save'`` Examples -------- @@ -3371,6 +3377,11 @@ def get_task_stream(self, start=None, stop=None, count=None): 'thread': ..., ...}] + Pass the ``plot=True`` or ``plot='save'`` keywords to get back a Bokeh + figure + + >>> data, figure = client.get_task_stream(plot='save', filename='myfile.html') + Alternatively consider the context manager >>> from dask.distributed import get_task_stream @@ -3385,10 +3396,28 @@ def get_task_stream(self, start=None, stop=None, count=None): See Also -------- - get_task_stream: a dontext manager version of this method + get_task_stream: a context manager version of this method """ - return self.sync(self.scheduler.get_task_stream, start=start, + return self.sync(self._get_task_stream, start=start, stop=stop, + count=count, plot=plot, filename=filename) + + @gen.coroutine + def _get_task_stream(self, start=None, stop=None, count=None, plot=False, + filename='task-stream.html'): + msgs = yield self.scheduler.get_task_stream(start=start, stop=stop, count=count) + if plot: + from .diagnostics.task_stream import rectangles + rects = rectangles(msgs) + from .bokeh.components import task_stream_figure + source, figure = task_stream_figure(sizing_mode='stretch_both') + source.data.update(rects) + if plot == 'save': + from bokeh.plotting import save + save(figure, title='Dask Task Stream', filename=filename) + raise gen.Return((msgs, figure)) + else: + raise gen.Return(msgs) class Executor(Client): @@ -3768,6 +3797,14 @@ class get_task_stream(object): This must be used as a context manager. + Parameters + ---------- + plot: boolean, str + If true then also return a Bokeh figure + If plot == 'save' then save the figure to a file + filename: str (optional) + The filename to save to if you set ``plot='save'`` + Examples -------- >>> with get_task_stream() as ts: @@ -3775,12 +3812,22 @@ class get_task_stream(object): >>> ts.data [...] + Get back a Bokeh figure and optionally save to a file + + >>> with get_task_stream(plot='save', filename='myfile.html') as ts: + ... 
x.compute() + >>> ts.figure + + See Also -------- Client.get_task_stream: Function version of this context manager """ - def __init__(self, client=None): + def __init__(self, client=None, plot=False, filename='task-stream.html'): self.data = [] + self._plot = plot + self._filename = filename + self.figure = None self.client = client or default_client() self.client.get_task_stream(start=0, stop=0) # ensure plugin @@ -3789,7 +3836,10 @@ def __enter__(self): return self def __exit__(self, typ, value, traceback): - L = self.client.get_task_stream(start=self.start) + L = self.client.get_task_stream(start=self.start, plot=self._plot, + filename=self._filename) + if self._plot: + L, self.figure = L self.data.extend(L) @gen.coroutine @@ -3798,7 +3848,10 @@ def __aenter__(self): @gen.coroutine def __aexit__(self, typ, value, traceback): - L = yield self.client.get_task_stream(start=self.start) + L = yield self.client.get_task_stream(start=self.start, plot=self._plot, + filename=self._filename) + if self._plot: + L, self.figure = L self.data.extend(L) diff --git a/distributed/diagnostics/task_stream.py b/distributed/diagnostics/task_stream.py index 7cabcf96311..f3fd169d8d6 100644 --- a/distributed/diagnostics/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -71,64 +71,74 @@ def bisect(target, left, right): return [self.buffer[i] for i in range(start, stop)] def rectangles(self, istart, istop=None, workers=None, start_boundary=0): - L_start = [] - L_duration = [] - L_duration_text = [] - L_key = [] - L_name = [] - L_color = [] - L_alpha = [] - L_worker = [] - L_worker_thread = [] - L_y = [] - + msgs = [] diff = self.index - len(self.buffer) if istop is None: istop = len(self.buffer) for i in range((istart or 0) - diff, istop - diff if istop else istop): msg = self.buffer[i] - key = msg['key'] - name = key_split(key) - startstops = msg.get('startstops', []) - try: - worker_thread = '%s-%d' % (msg['worker'], msg['thread']) - except Exception: + msgs.append(msg) + + return rectangles(msgs, workers=workers, start_boundary=start_boundary) + + +def rectangles(msgs, workers=None, start_boundary=0): + workers = workers or {} + + L_start = [] + L_duration = [] + L_duration_text = [] + L_key = [] + L_name = [] + L_color = [] + L_alpha = [] + L_worker = [] + L_worker_thread = [] + L_y = [] + + for msg in msgs: + key = msg['key'] + name = key_split(key) + startstops = msg.get('startstops', []) + try: + worker_thread = '%s-%d' % (msg['worker'], msg['thread']) + except Exception: + continue + logger.warning("Message contained bad information: %s", msg, + exc_info=True) + worker_thread = '' + + if worker_thread not in workers: + workers[worker_thread] = len(workers) / 2 + + for action, start, stop in startstops: + if start < start_boundary: continue - logger.warning("Message contained bad information: %s", msg, - exc_info=True) - worker_thread = '' - - if worker_thread not in workers: - workers[worker_thread] = len(workers) / 2 - - for action, start, stop in startstops: - if start < start_boundary: - continue - color = colors[action] - if type(color) is not str: - color = color(msg) - - L_start.append((start + stop) / 2 * 1000) - L_duration.append(1000 * (stop - start)) - L_duration_text.append(format_time(stop - start)) - L_key.append(key) - L_name.append(prefix[action] + name) - L_color.append(color) - L_alpha.append(alphas[action]) - L_worker.append(msg['worker']) - L_worker_thread.append(worker_thread) - L_y.append(workers[worker_thread]) - - return {'start': L_start, - 'duration': L_duration, - 
'duration_text': L_duration_text, - 'key': L_key, - 'name': L_name, - 'color': L_color, - 'alpha': L_alpha, - 'worker': L_worker, - 'worker_thread': L_worker_thread, - 'y': L_y} + color = colors[action] + if type(color) is not str: + color = color(msg) + + L_start.append((start + stop) / 2 * 1000) + L_duration.append(1000 * (stop - start)) + L_duration_text.append(format_time(stop - start)) + L_key.append(key) + L_name.append(prefix[action] + name) + L_color.append(color) + L_alpha.append(alphas[action]) + L_worker.append(msg['worker']) + L_worker_thread.append(worker_thread) + L_y.append(workers[worker_thread]) + + return {'start': L_start, + 'duration': L_duration, + 'duration_text': L_duration_text, + 'key': L_key, + 'name': L_name, + 'color': L_color, + 'alpha': L_alpha, + 'worker': L_worker, + 'worker_thread': L_worker_thread, + 'y': L_y} def color_of_message(msg): diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index eccb0a9db8e..f4354b74e6b 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -1,7 +1,9 @@ from __future__ import absolute_import, division, print_function +import os from time import sleep +import pytest from toolz import frequencies from distributed import Client, get_task_stream @@ -102,3 +104,31 @@ def test_client_sync(loop): wait(futures) assert len(ts.data) == 10 + + +@gen_cluster(client=True) +def test_get_task_stream_plot(c, s, a, b): + bokeh = pytest.importorskip('bokeh') + yield c.get_task_stream() + + futures = c.map(slowinc, range(10), delay=0.1) + yield wait(futures) + + data, figure = yield c.get_task_stream(plot=True) + assert isinstance(figure, bokeh.plotting.Figure) + + +def test_get_task_stream_save(loop, tmpdir): + bokeh = pytest.importorskip('bokeh') + tmpdir = str(tmpdir) + fn = os.path.join(tmpdir, 'foo.html') + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + with get_task_stream(plot='save', filename=fn) as ts: + wait(c.map(inc, range(10))) + with open(fn) as f: + data = f.read() + assert 'inc' in data + assert 'bokeh' in data + + assert isinstance(ts.figure, bokeh.plotting.Figure) From 1b9f9f4a6ec0d101869d9a92983ad1c7fe7c0241 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 29 Aug 2018 09:28:36 -0400 Subject: [PATCH 0072/1550] Add support for optional versions in Client.get_versions (#2216) * add dask-ml to get_versions * Add support for optional packages in get_versions --- distributed/client.py | 14 +++++++++---- distributed/node.py | 4 ++-- distributed/tests/test_client.py | 4 ++++ distributed/versions.py | 36 +++++++++++++++++++++++++++----- 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 7e97df7dbdf..6461065dfe9 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3039,7 +3039,7 @@ def set_metadata(self, key, value): key = (key,) return self.sync(self.scheduler.set_metadata, keys=key, value=value) - def get_versions(self, check=False): + def get_versions(self, check=False, packages=[]): """ Return version info for the scheduler, all workers and myself Parameters @@ -3047,18 +3047,24 @@ def get_versions(self, check=False): check : boolean, default False raise ValueError if all required & optional packages do not match + packages : List[str] + Extra package names to check Examples -------- >>> c.get_versions() # doctest: +SKIP + + >>> c.get_versions(packages=['sklearn', 'geopandas']) # 
doctest: +SKIP """ - client = get_versions() + client = get_versions(packages=packages) try: - scheduler = sync(self.loop, self.scheduler.versions) + scheduler = sync(self.loop, self.scheduler.versions, + packages=packages) except KeyError: scheduler = None - workers = sync(self.loop, self.scheduler.broadcast, msg={'op': 'versions'}) + workers = sync(self.loop, self.scheduler.broadcast, + msg={'op': 'versions', 'packages': packages}) result = {'scheduler': scheduler, 'workers': workers, 'client': client} if check: diff --git a/distributed/node.py b/distributed/node.py index 8373c07709c..654d67f376c 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -46,5 +46,5 @@ def __init__(self, handlers=None, stream_handlers=None, connection_limit=connection_limit, deserialize=deserialize, io_loop=self.io_loop) - def versions(self, comm=None): - return get_versions() + def versions(self, comm=None, packages=None): + return get_versions(packages=packages) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 6ed9e0faf8c..7f09a2efce7 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3814,6 +3814,10 @@ def test_get_versions(loop): # smoke test for versions # that this does not raise + v = c.get_versions(packages=['requests']) + import requests + assert dict(v['client']['packages']['optional'])['requests'] == requests.__version__ + def test_threaded_get_within_distributed(loop): with cluster() as (s, [a, b]): diff --git a/distributed/versions.py b/distributed/versions.py index 9f95591c880..fa7bbc0835a 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -9,6 +9,8 @@ import locale import importlib +from .utils import ignoring + required_packages = [('dask', lambda p: p.__version__), ('distributed', lambda p: p.__version__), @@ -21,16 +23,20 @@ ('pandas', lambda p: p.__version__), ('bokeh', lambda p: p.__version__), ('lz4', lambda p: p.__version__), + ('dask_ml', lambda p: p.__version__), ('blosc', lambda p: p.__version__)] -def get_versions(): - """ Return basic information on our software installation, - and out installed versions of packages. """ +def get_versions(packages=None): + """ + Return basic information on our software installation, and out installed versions of packages. 
+ """ + if packages is None: + packages = [] d = {'host': get_system_info(), 'packages': {'required': get_package_info(required_packages), - 'optional': get_package_info(optional_packages)} + 'optional': get_package_info(optional_packages + list(packages))} } return d @@ -53,11 +59,31 @@ def get_system_info(): return host +def version_of_package(pkg): + """ Try a variety of common ways to get the version of a package """ + with ignoring(AttributeError): + return pkg.__version__ + with ignoring(AttributeError): + return str(pkg.version) + with ignoring(AttributeError): + return '.'.join(map(str, pkg.version_info)) + return None + + def get_package_info(pkgs): """ get package versions for the passed required & optional packages """ pversions = [] - for (modname, ver_f) in pkgs: + for pkg in pkgs: + if isinstance(pkg, (tuple, list)): + modname, ver_f = pkg + else: + modname = pkg + ver_f = version_of_package + + if ver_f is None: + ver_f = version_of_package + try: mod = importlib.import_module(modname) ver = ver_f(mod) From ee86eef4dfa292edafe193248f789dc30c7da414 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 29 Aug 2018 11:39:41 -0400 Subject: [PATCH 0073/1550] Add routes for solo bokeh figures in dashboard (#2185) This adds new routes for individual plots. This is to help some JLab extension work. --- distributed/bokeh/scheduler.py | 70 +++++++++++++++++++ distributed/bokeh/scheduler_html.py | 11 +++ .../bokeh/tests/test_scheduler_bokeh.py | 15 ++-- .../bokeh/tests/test_scheduler_bokeh_html.py | 1 + 4 files changed, 93 insertions(+), 4 deletions(-) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index a282e012574..508d35eb037 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -1173,6 +1173,58 @@ def status_doc(scheduler, extra, doc): doc.template_variables.update(extra) +def individual_task_stream_doc(scheduler, extra, doc): + task_stream = TaskStream(scheduler, n_rectangles=1000, + clear_interval='10s', sizing_mode='stretch_both') + task_stream.update() + doc.add_periodic_callback(task_stream.update, 100) + doc.add_root(task_stream.root) + + +def individual_load_doc(scheduler, extra, doc): + current_load = CurrentLoad(scheduler, height=160, sizing_mode='stretch_both') + current_load.update() + doc.add_periodic_callback(current_load.update, 100) + doc.add_root(current_load.root) + + +def individual_progress_doc(scheduler, extra, doc): + task_progress = TaskProgress(scheduler, height=160, sizing_mode='stretch_both') + task_progress.update() + doc.add_periodic_callback(task_progress.update, 100) + doc.add_root(task_progress.root) + + +def individual_graph_doc(scheduler, extra, doc): + with log_errors(): + graph = GraphPlot(scheduler, sizing_mode='stretch_both') + graph.update() + doc.add_periodic_callback(graph.update, 200) + doc.add_root(graph.root) + + +def individual_profile_doc(scheduler, extra, doc): + with log_errors(): + prof = ProfileTimePlot(scheduler, sizing_mode='scale_width', doc=doc) + doc.add_root(prof.root) + prof.trigger_update() + + +def individual_profile_server_doc(scheduler, extra, doc): + with log_errors(): + prof = ProfileServer(scheduler, sizing_mode='scale_width', doc=doc) + doc.add_root(prof.root) + prof.trigger_update() + + +def individual_workers_doc(scheduler, extra, doc): + with log_errors(): + table = WorkerTable(scheduler) + table.update() + doc.add_periodic_callback(table.update, 500) + doc.add_root(table.root) + + def profile_doc(scheduler, extra, doc): with log_errors(): doc.title = 
"Dask: Profile" @@ -1220,6 +1272,16 @@ def __init__(self, scheduler, io_loop=None, prefix='', **kwargs): profile_server = Application(FunctionHandler(partial(profile_server_doc, scheduler, self.extra))) graph = Application(FunctionHandler(partial(graph_doc, scheduler, self.extra))) + individual_task_stream = Application(FunctionHandler(partial( + individual_task_stream_doc, scheduler, self.extra))) + individual_progress = Application(FunctionHandler(partial(individual_progress_doc, scheduler, self.extra))) + individual_graph = Application(FunctionHandler(partial(individual_graph_doc, scheduler, self.extra))) + individual_profile = Application(FunctionHandler(partial(individual_profile_doc, scheduler, self.extra))) + individual_profile_server = Application(FunctionHandler(partial( + individual_profile_server_doc, scheduler, self.extra))) + individual_load = Application(FunctionHandler(partial(individual_load_doc, scheduler, self.extra))) + individual_workers = Application(FunctionHandler(partial(individual_workers_doc, scheduler, self.extra))) + self.apps = { '/system': systemmonitor, '/stealing': stealing, @@ -1231,6 +1293,14 @@ def __init__(self, scheduler, io_loop=None, prefix='', **kwargs): '/profile': profile, '/profile-server': profile_server, '/graph': graph, + + '/individual-task-stream': individual_task_stream, + '/individual-progress': individual_progress, + '/individual-graph': individual_graph, + '/individual-profile': individual_profile, + '/individual-profile-server': individual_profile_server, + '/individual-load': individual_load, + '/individual-workers': individual_workers, } self.loop = io_loop or scheduler.loop diff --git a/distributed/bokeh/scheduler_html.py b/distributed/bokeh/scheduler_html.py index bab032c7be7..086155be957 100644 --- a/distributed/bokeh/scheduler_html.py +++ b/distributed/bokeh/scheduler_html.py @@ -152,6 +152,16 @@ def get(self): self.render('json-index.html', routes=r, title='Index of JSON routes', **self.extra) +class IndividualPlots(RequestHandler): + def get(self): + bokeh_server = self.server.services['bokeh'] + result = {uri.strip('/').replace('-', ' ').title(): uri + for uri in bokeh_server.apps + if uri.lstrip('/').startswith('individual-') + and not uri.endswith('.json')} + self.write(result) + + routes = [ (r'info/main/workers.html', Workers), (r'info/worker/(.*).html', Worker), @@ -163,6 +173,7 @@ def get(self): (r'json/counts.json', CountsJSON), (r'json/identity.json', IdentityJSON), (r'json/index.html', IndexJSON), + (r'individual-plots.json', IndividualPlots), ] diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index 1c07b437f21..ef8599bb6bd 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -1,10 +1,11 @@ from __future__ import print_function, division, absolute_import +import json +import sys from time import sleep import pytest pytest.importorskip('bokeh') -import sys from toolz import first from tornado import gen from tornado.httpclient import AsyncHTTPClient @@ -35,17 +36,23 @@ scheduler_kwargs={'services': {('bokeh', 0): BokehScheduler}}) def test_simple(c, s, a, b): assert isinstance(s.services['bokeh'], BokehScheduler) + port = s.services['bokeh'].port future = c.submit(sleep, 1) yield gen.sleep(0.1) http_client = AsyncHTTPClient() for suffix in ['system', 'counters', 'workers', 'status', 'tasks', - 'stealing', 'graph']: - response = yield http_client.fetch('http://localhost:%d/%s' - % 
(s.services['bokeh'].port, suffix)) + 'stealing', 'graph', 'individual-task-stream', 'individual-progress', + 'individual-graph', 'individual-load', + 'individual-profile']: + response = yield http_client.fetch('http://localhost:%d/%s' % (port, suffix)) assert 'bokeh' in response.body.decode().lower() + response = yield http_client.fetch('http://localhost:%d/individual-plots.json' % port) + response = json.loads(response.body.decode()) + assert response + @gen_cluster(client=True, worker_kwargs=dict(services={'bokeh': BokehWorker})) def test_basic(c, s, a, b): diff --git a/distributed/bokeh/tests/test_scheduler_bokeh_html.py b/distributed/bokeh/tests/test_scheduler_bokeh_html.py index 89fe9c47ca6..52eb65c803f 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh_html.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh_html.py @@ -30,6 +30,7 @@ def test_connect(c, s, a, b): 'json/counts.json', 'json/identity.json', 'json/index.html', + 'individual-plots.json', ]: response = yield http_client.fetch('http://localhost:%d/%s' % (s.services['bokeh'].port, suffix)) From 32c79ecba86ca1673a4d40a395d429d92e643202 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 29 Aug 2018 12:09:14 -0400 Subject: [PATCH 0074/1550] Fix intermittent failure for test_dont_steal_long_running_tasks (#2218) --- distributed/tests/test_steal.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 01193b61e20..35315d531ca 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -587,10 +587,12 @@ def long(delay): yield gen.sleep(0.2) - assert sum(1 for k in s.processing[b.address] if k.startswith('long')) <= nb - yield wait(long_tasks) + for t in long_tasks: + assert (sum(log[1] == 'executing' for log in a.story(t)) + + sum(log[1] == 'executing' for log in b.story(t))) <= 1 + @gen_cluster(client=True, ncores=[('127.0.0.1', 5)] * 2) def test_cleanup_repeated_tasks(c, s, a, b): From b391a4f09a860fa5e4584ff1b7a008179e141225 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 29 Aug 2018 12:12:11 -0400 Subject: [PATCH 0075/1550] Be resilient to missing dep after busy signal (#2217) --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index bd9f6040e5f..66535e4517d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1898,7 +1898,7 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): if response['status'] == 'busy': self.log.append(('busy-gather', worker, deps)) for dep in deps: - if self.dep_state[dep] == 'flight': + if self.dep_state.get(dep, None) == 'flight': self.transition_dep(dep, 'waiting') return From 182a60ef404e7e73a1860f641737e65c9751fd78 Mon Sep 17 00:00:00 2001 From: Derek Ludwig Date: Wed, 29 Aug 2018 13:29:39 -0700 Subject: [PATCH 0076/1550] Use CSS Grid to layout status page on the dashboard (#2213) This uses CSS Grid and the new Bokeh templates to layout the status page in a way that is more responsive to wide and narrow screens. 
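For a sense of the approach, here is a minimal CSS Grid sketch in the spirit of the new status layout; the class names and grid areas below are illustrative assumptions, not the exact markup or style rules added in templates/status.html:

```css
/* Illustrative sketch only -- class names and grid areas are assumptions,
   not the exact template shipped in this change. */
.container {
  display: grid;
  grid-gap: 10px;
  grid-template-columns: 1fr 1fr;
  grid-template-areas:
    "nbytes        processing"
    "task-stream   task-stream"
    "task-progress task-progress";
}
.nbytes        { grid-area: nbytes; }
.processing    { grid-area: processing; }
.task-stream   { grid-area: task-stream; }
.task-progress { grid-area: task-progress; }
```
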
This removes the div around single-plot pages, which resolves a resizing issue that arose in Bokeh 0.13.0 This sets the minimum bokeh version to 0.13 Fixes https://github.com/dask/distributed/issues/2194 --- distributed/bokeh/components.py | 34 +++-- distributed/bokeh/core.py | 6 +- distributed/bokeh/scheduler.py | 137 ++++++++--------- distributed/bokeh/static/dask_horizontal.svg | 28 ---- distributed/bokeh/static/dask_icon_no_pad.svg | 35 +++++ distributed/bokeh/template.html | 123 --------------- distributed/bokeh/templates/base.html | 141 ++++++++++++++++++ distributed/bokeh/templates/simple.html | 6 + distributed/bokeh/templates/status.html | 79 ++++++++++ .../bokeh/tests/test_scheduler_bokeh.py | 3 +- distributed/bokeh/worker.py | 20 +-- 11 files changed, 364 insertions(+), 248 deletions(-) delete mode 100644 distributed/bokeh/static/dask_horizontal.svg create mode 100644 distributed/bokeh/static/dask_icon_no_pad.svg delete mode 100644 distributed/bokeh/template.html create mode 100644 distributed/bokeh/templates/base.html create mode 100644 distributed/bokeh/templates/simple.html create mode 100644 distributed/bokeh/templates/status.html diff --git a/distributed/bokeh/components.py b/distributed/bokeh/components.py index f3d1ca4c3d4..5626977483f 100644 --- a/distributed/bokeh/components.py +++ b/distributed/bokeh/components.py @@ -110,19 +110,35 @@ def task_stream_figure(clear_interval='20s', **kwargs): y_range = DataRange1d(range_padding=0) root = figure( - title="Task Stream", id='bk-task-stream-plot', - x_range=x_range, y_range=y_range, toolbar_location="above", - x_axis_type='datetime', min_border_right=35, tools='', **kwargs) - root.yaxis.axis_label = 'Worker Core' - - rect = root.rect(source=source, x="start", y="y", - width="duration", height=0.4, fill_color="color", - line_color="color", line_alpha=0.6, fill_alpha="alpha", - line_width=3) + name='task_stream', + title="Task Stream", + id='bk-task-stream-plot', + x_range=x_range, + y_range=y_range, + toolbar_location="above", + x_axis_type='datetime', + min_border_right=35, + tools='', + **kwargs + ) + + rect = root.rect( + source=source, + x="start", + y="y", + width="duration", + height=0.4, + fill_color="color", + line_color="color", + line_alpha=0.6, + fill_alpha="alpha", + line_width=3 + ) rect.nonselection_glyph = None root.yaxis.major_label_text_alpha = 0 root.yaxis.minor_tick_line_alpha = 0 + root.yaxis.major_tick_line_alpha = 0 root.xgrid.visible = False hover = HoverTool( diff --git a/distributed/bokeh/core.py b/distributed/bokeh/core.py index f7409c043a0..9e8540f0037 100644 --- a/distributed/bokeh/core.py +++ b/distributed/bokeh/core.py @@ -9,8 +9,10 @@ from tornado import web -if LooseVersion(bokeh.__version__) < LooseVersion('0.12.6'): - raise ImportError("Dask needs bokeh >= 0.12.6") +if LooseVersion(bokeh.__version__) < LooseVersion('0.13.0'): + warnings.warn("\nDask needs bokeh >= 0.13.0 for the dashboard." 
+ "\nContinuing without the dashboard.") + raise ImportError("Dask needs bokeh >= 0.13.0") class BokehServer(object): diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index 508d35eb037..dbef4064905 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -10,8 +10,6 @@ import os import bokeh -from bokeh.application import Application -from bokeh.application.handlers.function import FunctionHandler from bokeh.layouts import column, row from bokeh.models import (ColumnDataSource, DataRange1d, HoverTool, ResetTool, PanTool, WheelZoomTool, TapTool, OpenURL, Range1d, Plot, Quad, @@ -51,12 +49,8 @@ PROFILING = False -import jinja2 - -with open(os.path.join(os.path.dirname(__file__), 'template.html')) as f: - template_source = f.read() - -template = jinja2.Template(template_source) +from jinja2 import Environment, FileSystemLoader +env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), 'templates'))) template_variables = {'pages': ['status', 'workers', 'tasks', 'system', 'profile', 'graph']} @@ -183,6 +177,7 @@ def __init__(self, scheduler, **kwargs): self.root = figure(title='Tasks Processing', id='bk-nprocessing-histogram-plot', + name='processing_hist', **kwargs) self.root.xaxis.minor_tick_line_alpha = 0 @@ -215,6 +210,7 @@ def __init__(self, scheduler, **kwargs): 'top': [0, 0]}) self.root = figure(title='Bytes Stored', + name='nbytes_hist', id='bk-nbytes-histogram-plot', **kwargs) self.root.xaxis[0].formatter = NumeralTickFormatter(format='0.0 b') @@ -258,7 +254,7 @@ def __init__(self, scheduler, width=600, **kwargs): 'bokeh_address': ['', '']}) processing = figure(title='Tasks Processing', tools='', id='bk-nprocessing-plot', - width=int(width / 2), **kwargs) + name='processing_hist', width=int(width / 2), **kwargs) rect = processing.rect(source=self.source, x='nprocessing-half', y='y', width='nprocessing', height=1, @@ -268,7 +264,7 @@ def __init__(self, scheduler, width=600, **kwargs): nbytes = figure(title='Bytes stored', tools='', id='bk-nbytes-worker-plot', width=int(width / 2), - **kwargs) + name='nbytes_hist', **kwargs) rect = nbytes.rect(source=self.source, x='nbytes-half', y='y', width='nbytes', height=1, @@ -306,7 +302,6 @@ def __init__(self, scheduler, width=600, **kwargs): self.nbytes_figure = nbytes processing.y_range = nbytes.y_range - self.root = row(nbytes, processing, sizing_mode='scale_width') def update(self): with log_errors(): @@ -763,7 +758,7 @@ def __init__(self, scheduler, **kwargs): y_range = Range1d(-8, 0) self.root = figure( - id='bk-task-progress-plot', title='Progress', + id='bk-task-progress-plot', title='Progress', name='task_progress', x_range=x_range, y_range=y_range, toolbar_location=None, **kwargs ) self.root.line( # just to define early ranges @@ -1061,7 +1056,7 @@ def systemmonitor_doc(scheduler, extra, doc): doc.add_periodic_callback(sysmon.update, 500) doc.add_root(column(sysmon.root, sizing_mode='scale_width')) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'system' doc.template_variables.update(extra) @@ -1081,7 +1076,7 @@ def stealing_doc(scheduler, extra, doc): stealing_events.root, sizing_mode='scale_width')) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'stealing' doc.template_variables.update(extra) @@ -1093,7 +1088,7 @@ def events_doc(scheduler, extra, doc): doc.add_periodic_callback(events.update, 500) doc.title = "Dask: Scheduler 
Events" doc.add_root(column(events.root, sizing_mode='scale_width')) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'events' doc.template_variables.update(extra) @@ -1105,7 +1100,7 @@ def workers_doc(scheduler, extra, doc): doc.add_periodic_callback(table.update, 500) doc.title = "Dask: Workers" doc.add_root(table.root) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'workers' doc.template_variables.update(extra) @@ -1118,7 +1113,7 @@ def tasks_doc(scheduler, extra, doc): doc.add_periodic_callback(ts.update, 5000) doc.title = "Dask: Task Stream" doc.add_root(ts.root) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'tasks' doc.template_variables.update(extra) @@ -1131,7 +1126,7 @@ def graph_doc(scheduler, extra, doc): doc.add_periodic_callback(graph.update, 200) doc.add_root(graph.root) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'graph' doc.template_variables.update(extra) @@ -1139,36 +1134,38 @@ def graph_doc(scheduler, extra, doc): def status_doc(scheduler, extra, doc): with log_errors(): task_stream = TaskStream(scheduler, n_rectangles=1000, - clear_interval='10s', height=350) + clear_interval='10s', sizing_mode='stretch_both') task_stream.update() doc.add_periodic_callback(task_stream.update, 100) - task_progress = TaskProgress(scheduler, height=160) + task_progress = TaskProgress(scheduler, sizing_mode='stretch_both') task_progress.update() doc.add_periodic_callback(task_progress.update, 100) if len(scheduler.workers) < 50: - current_load = CurrentLoad(scheduler, height=160) + current_load = CurrentLoad(scheduler, sizing_mode='stretch_both') current_load.update() doc.add_periodic_callback(current_load.update, 100) - current_load_fig = current_load.root + doc.add_root(current_load.nbytes_figure) + doc.add_root(current_load.processing_figure) else: - nbytes_hist = NBytesHistogram(scheduler, width=300, height=160) + nbytes_hist = NBytesHistogram(scheduler, sizing_mode='stretch_both') nbytes_hist.update() - processing_hist = ProcessingHistogram(scheduler, width=300, - height=160) + processing_hist = ProcessingHistogram(scheduler, sizing_mode='stretch_both') processing_hist.update() doc.add_periodic_callback(nbytes_hist.update, 100) doc.add_periodic_callback(processing_hist.update, 100) current_load_fig = row(nbytes_hist.root, processing_hist.root, - sizing_mode='scale_width') + sizing_mode='stretch_both') + + doc.add_root(nbytes_hist.root) + doc.add_root(processing_hist.root) doc.title = "Dask: Status" - doc.add_root(column(current_load_fig, - task_stream.root, - task_progress.root, - sizing_mode='scale_width')) - doc.template = template + doc.add_root(task_progress.root) + doc.add_root(task_stream.root) + + doc.template = env.get_template('status.html') doc.template_variables['active_page'] = 'status' doc.template_variables.update(extra) @@ -1181,11 +1178,18 @@ def individual_task_stream_doc(scheduler, extra, doc): doc.add_root(task_stream.root) -def individual_load_doc(scheduler, extra, doc): - current_load = CurrentLoad(scheduler, height=160, sizing_mode='stretch_both') +def individual_nbytes_doc(scheduler, extra, doc): + current_load = CurrentLoad(scheduler, sizing_mode='stretch_both') current_load.update() doc.add_periodic_callback(current_load.update, 100) - doc.add_root(current_load.root) + 
doc.add_root(current_load.nbytes_figure) + + +def individual_nprocessing_doc(scheduler, extra, doc): + current_load = CurrentLoad(scheduler, sizing_mode='stretch_both') + current_load.update() + doc.add_periodic_callback(current_load.update, 100) + doc.add_root(current_load.processing_figure) def individual_progress_doc(scheduler, extra, doc): @@ -1230,7 +1234,7 @@ def profile_doc(scheduler, extra, doc): doc.title = "Dask: Profile" prof = ProfileTimePlot(scheduler, sizing_mode='scale_width', doc=doc) doc.add_root(prof.root) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'profile' doc.template_variables.update(extra) @@ -1242,7 +1246,7 @@ def profile_server_doc(scheduler, extra, doc): doc.title = "Dask: Profile of Event Loop" prof = ProfileServer(scheduler, sizing_mode='scale_width', doc=doc) doc.add_root(prof.root) - doc.template = template + doc.template = env.get_template('simple.html') # doc.template_variables['active_page'] = 'profile' doc.template_variables.update(extra) @@ -1261,48 +1265,31 @@ def __init__(self, scheduler, io_loop=None, prefix='', **kwargs): self.server_kwargs = kwargs self.server_kwargs['prefix'] = prefix or None - systemmonitor = Application(FunctionHandler(partial(systemmonitor_doc, scheduler, self.extra))) - workers = Application(FunctionHandler(partial(workers_doc, scheduler, self.extra))) - stealing = Application(FunctionHandler(partial(stealing_doc, scheduler, self.extra))) - counters = Application(FunctionHandler(partial(counters_doc, scheduler, self.extra))) - events = Application(FunctionHandler(partial(events_doc, scheduler, self.extra))) - tasks = Application(FunctionHandler(partial(tasks_doc, scheduler, self.extra))) - status = Application(FunctionHandler(partial(status_doc, scheduler, self.extra))) - profile = Application(FunctionHandler(partial(profile_doc, scheduler, self.extra))) - profile_server = Application(FunctionHandler(partial(profile_server_doc, scheduler, self.extra))) - graph = Application(FunctionHandler(partial(graph_doc, scheduler, self.extra))) - - individual_task_stream = Application(FunctionHandler(partial( - individual_task_stream_doc, scheduler, self.extra))) - individual_progress = Application(FunctionHandler(partial(individual_progress_doc, scheduler, self.extra))) - individual_graph = Application(FunctionHandler(partial(individual_graph_doc, scheduler, self.extra))) - individual_profile = Application(FunctionHandler(partial(individual_profile_doc, scheduler, self.extra))) - individual_profile_server = Application(FunctionHandler(partial( - individual_profile_server_doc, scheduler, self.extra))) - individual_load = Application(FunctionHandler(partial(individual_load_doc, scheduler, self.extra))) - individual_workers = Application(FunctionHandler(partial(individual_workers_doc, scheduler, self.extra))) - self.apps = { - '/system': systemmonitor, - '/stealing': stealing, - '/workers': workers, - '/events': events, - '/counters': counters, - '/tasks': tasks, - '/status': status, - '/profile': profile, - '/profile-server': profile_server, - '/graph': graph, - - '/individual-task-stream': individual_task_stream, - '/individual-progress': individual_progress, - '/individual-graph': individual_graph, - '/individual-profile': individual_profile, - '/individual-profile-server': individual_profile_server, - '/individual-load': individual_load, - '/individual-workers': individual_workers, + '/system': systemmonitor_doc, + '/stealing': stealing_doc, + '/workers': workers_doc, 
+ '/events': events_doc, + '/counters': counters_doc, + '/tasks': tasks_doc, + '/status': status_doc, + '/profile': profile_doc, + '/profile-server': profile_server_doc, + '/graph': graph_doc, + + '/individual-task-stream': individual_task_stream_doc, + '/individual-progress': individual_progress_doc, + '/individual-graph': individual_graph_doc, + '/individual-profile': individual_profile_doc, + '/individual-profile-server': individual_profile_server_doc, + '/individual-nbytes': individual_nbytes_doc, + '/individual-nprocessing': individual_nprocessing_doc, + '/individual-workers': individual_workers_doc, } + self.apps = {k: partial(v, scheduler, self.extra) + for k, v in self.apps.items()} + self.loop = io_loop or scheduler.loop self.server = None diff --git a/distributed/bokeh/static/dask_horizontal.svg b/distributed/bokeh/static/dask_horizontal.svg deleted file mode 100644 index bfce8ca6b67..00000000000 --- a/distributed/bokeh/static/dask_horizontal.svg +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - - - Dask - - - - - - - - - diff --git a/distributed/bokeh/static/dask_icon_no_pad.svg b/distributed/bokeh/static/dask_icon_no_pad.svg new file mode 100644 index 00000000000..8999ed4a720 --- /dev/null +++ b/distributed/bokeh/static/dask_icon_no_pad.svg @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + Dask + + + + + diff --git a/distributed/bokeh/template.html b/distributed/bokeh/template.html deleted file mode 100644 index e0b34e00448..00000000000 --- a/distributed/bokeh/template.html +++ /dev/null @@ -1,123 +0,0 @@ - - - - - Dask - Status dashboard - {{ bokeh_css }} - {{ bokeh_js }} - - - - -
- {{ plot_div }} -
- {{ plot_script }} - - diff --git a/distributed/bokeh/templates/base.html b/distributed/bokeh/templates/base.html new file mode 100644 index 00000000000..ad9ffd152b4 --- /dev/null +++ b/distributed/bokeh/templates/base.html @@ -0,0 +1,141 @@ + + + + + Dask Diagnostic UI + + + {% block resources %} + {% block js_resources %} + {{ bokeh_css | indent(8) if bokeh_css }} + {% endblock %} + {% block css_resources%} + {{ bokeh_js | indent(8) if bokeh_js }} + {% endblock %} + {% endblock %} + + + +
+ {% block content %} + {% endblock %} +
+ + diff --git a/distributed/bokeh/templates/simple.html b/distributed/bokeh/templates/simple.html new file mode 100644 index 00000000000..6f982b44f1c --- /dev/null +++ b/distributed/bokeh/templates/simple.html @@ -0,0 +1,6 @@ +{% extends "base.html" %} + +{% block content %} +{{ plot_div }} +{{ plot_script }} +{% endblock %} diff --git a/distributed/bokeh/templates/status.html b/distributed/bokeh/templates/status.html new file mode 100644 index 00000000000..face484386f --- /dev/null +++ b/distributed/bokeh/templates/status.html @@ -0,0 +1,79 @@ +{% extends "base.html" %} + +{% block content %} +{% from macros import embed %} + +
+ +
+ {{ embed(roots.nbytes_hist) }} +
+ +
+ {{ embed(roots.processing_hist) }} +
+ +
+ {{ embed(roots.task_stream) }} +
+ +
+ {{ embed(roots.task_progress) }} +
+ +
+{{ plot_script }} + +{% endblock %} diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index ef8599bb6bd..bf0eca90221 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -44,7 +44,8 @@ def test_simple(c, s, a, b): http_client = AsyncHTTPClient() for suffix in ['system', 'counters', 'workers', 'status', 'tasks', 'stealing', 'graph', 'individual-task-stream', 'individual-progress', - 'individual-graph', 'individual-load', + 'individual-graph', 'individual-nbytes', + 'individual-nprocessing', 'individual-profile']: response = yield http_client.fetch('http://localhost:%d/%s' % (port, suffix)) assert 'bokeh' in response.body.decode().lower() diff --git a/distributed/bokeh/worker.py b/distributed/bokeh/worker.py index 687da327d79..257eab7284c 100644 --- a/distributed/bokeh/worker.py +++ b/distributed/bokeh/worker.py @@ -26,12 +26,12 @@ logger = logging.getLogger(__name__) -import jinja2 - -with open(os.path.join(os.path.dirname(__file__), 'template.html')) as f: +with open(os.path.join(os.path.dirname(__file__), 'templates', 'base.html')) as f: template_source = f.read() -template = jinja2.Template(template_source) +from jinja2 import Environment, FileSystemLoader +env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), 'templates'))) + template_variables = {'pages': ['main', 'system', 'profile', 'crossfilter']} @@ -560,7 +560,7 @@ def main_doc(worker, extra, doc): communicating_ts.root, communicating_stream.root, sizing_mode='scale_width')) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'main' doc.template_variables.update(extra) @@ -575,7 +575,7 @@ def crossfilter_doc(worker, extra, doc): doc.add_periodic_callback(crossfilter.update, 500) doc.add_root(column(statetable.root, crossfilter.root)) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'crossfilter' doc.template_variables.update(extra) @@ -587,7 +587,7 @@ def systemmonitor_doc(worker, extra, doc): doc.add_periodic_callback(sysmon.update, 500) doc.add_root(sysmon.root) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'system' doc.template_variables.update(extra) @@ -599,7 +599,7 @@ def counters_doc(server, extra, doc): doc.add_periodic_callback(counter.update, 500) doc.add_root(counter.root) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'counters' doc.template_variables.update(extra) @@ -611,7 +611,7 @@ def profile_doc(server, extra, doc): profile.trigger_update() doc.add_root(profile.root) - doc.template = template + doc.template = env.get_template('simple.html') doc.template_variables['active_page'] = 'profile' doc.template_variables.update(extra) @@ -621,7 +621,7 @@ def profile_server_doc(server, extra, doc): doc.title = "Dask: Profile of Event Loop" prof = ProfileServer(server, sizing_mode='scale_width', doc=doc) doc.add_root(prof.root) - doc.template = template + doc.template = env.get_template('simple.html') # doc.template_variables['active_page'] = '' doc.template_variables.update(extra) From 5c023259ee9f91d2ba00f5448c8a73f72f91b250 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 29 Aug 2018 21:08:25 -0400 Subject: [PATCH 0077/1550] Fix deserialization of queues on main ioloop thread (#2221) * Fix 
deserialization of queues on main ioloop thread Previously we had a difficult time determining that we were on the IOLoop thread and should act asynchronously. This adds a new thread local, `on_event_loop_thread` to verify this explicitly Fixes https://github.com/dask/distributed/issues/2220 * clear thread_state.on_event_loop_thread in gen_cluster --- distributed/queues.py | 6 +++--- distributed/tests/test_queues.py | 16 ++++++++++++++++ distributed/utils_test.py | 5 ++++- distributed/variable.py | 2 +- distributed/worker.py | 1 + 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/distributed/queues.py b/distributed/queues.py index 803985a30be..fda6daaae5c 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -9,7 +9,7 @@ import tornado.queues from .client import Future, _get_global_client, Client -from .utils import tokey, sync +from .utils import tokey, sync, thread_state from .worker import get_client logger = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class Queue(object): def __init__(self, name=None, client=None, maxsize=0): self.client = client or _get_global_client() self.name = name or 'queue-' + uuid.uuid4().hex - if self.client.asynchronous: + if self.client.asynchronous or getattr(thread_state, 'on_event_loop_thread', False): self._started = self.client.scheduler.queue_create(name=self.name, maxsize=maxsize) else: @@ -258,7 +258,7 @@ def __setstate__(self, state): name, address = state try: client = get_client(address) - assert client.address == address + assert client.scheduler.address == address except (AttributeError, AssertionError): client = Client(address, set_as_default=False) self.__init__(name=name, client=client) diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 913434d6909..faa8707cdf0 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -273,3 +273,19 @@ def test_timeout(c, s, a, b): yield q.put(2, timeout=0.3) stop = time() assert 0.1 < stop - start < 2.0 + + +@gen_cluster(client=True) +def test_2220(c, s, a, b): + q = Queue() + + def put(): + q.put(55) + + def get(): + print(q.get()) + + fut = c.submit(put) + res = c.submit(get) + + yield [res, fut] diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 900abab06ed..0898be211b2 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -49,7 +49,8 @@ from .proctitle import enable_proctitle_on_children from .security import Security from .utils import (ignoring, log_errors, mp_context, get_ip, get_ipv6, - DequeHandler, reset_logger_locks, sync, iscoroutinefunction) + DequeHandler, reset_logger_locks, sync, + iscoroutinefunction, thread_state) from .worker import Worker, TOTAL_MEMORY, _global_workers try: @@ -863,6 +864,8 @@ def coro(): call_stacks = profile.call_stack(sys._current_frames()[tid]) assert False, (thread, call_stacks) _cleanup_dangling() + with ignoring(AttributeError): + del thread_state.on_event_loop_thread return result return test_func diff --git a/distributed/variable.py b/distributed/variable.py index b21a047ce8c..5d905358a9e 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -207,7 +207,7 @@ def __setstate__(self, state): name, address = state try: client = get_client(address) - assert client.address == address + assert client.scheduler.address == address except (AttributeError, AssertionError): client = Client(address, set_as_default=False) self.__init__(name=name, client=client) diff --git a/distributed/worker.py b/distributed/worker.py index 
66535e4517d..45ef87d739c 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -402,6 +402,7 @@ def _start(self, addr_or_port=0): assert self.status is None enable_gc_diagnosis() + thread_state.on_event_loop_thread = True # XXX Factor this out if not addr_or_port: From ab83d8566cbab8b745ae6c62c45300c0927debde Mon Sep 17 00:00:00 2001 From: Guillaume EB Date: Thu, 30 Aug 2018 09:10:07 -0700 Subject: [PATCH 0078/1550] Add a worker initialization function (#2201) --- distributed/client.py | 31 ++++++++++++++ distributed/scheduler.py | 16 ++++++- distributed/tests/test_worker.py | 73 ++++++++++++++++++++++++++++++++ distributed/worker.py | 14 +++++- 4 files changed, 131 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 6461065dfe9..079b3994087 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3425,6 +3425,37 @@ def _get_task_stream(self, start=None, stop=None, count=None, plot=False, else: raise gen.Return(msgs) + @gen.coroutine + def _register_worker_callbacks(self, setup=None): + responses = yield self.scheduler.register_worker_callbacks(setup=dumps(setup)) + results = {} + for key, resp in responses.items(): + if resp['status'] == 'OK': + results[key] = resp['result'] + elif resp['status'] == 'error': + six.reraise(*clean_exception(**resp)) + raise gen.Return(results) + + def register_worker_callbacks(self, setup=None): + """ + Registers a setup callback function for all current and future workers. + + This registers a new setup function for workers in this cluster. The + function will run immediately on all currently connected workers. It + will also be run upon connection by any workers that are added in the + future. Multiple setup functions can be registered - these will be + called in the order they were added. + + If the function takes an input argument named ``dask_worker`` then + that variable will be populated with the worker itself. 
+ + Parameters + ---------- + setup : callable(dask_worker: Worker) -> None + Function to register and run on all workers + """ + return self.sync(self._register_worker_callbacks, setup=setup) + class Executor(Client): """ Deprecated: see Client """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 86c75d1872d..dd72a03ca7c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -899,6 +899,7 @@ def __init__( self.plugins = [] self.transition_log = deque(maxlen=dask.config.get('distributed.scheduler.transition-log-length')) self.log = deque(maxlen=dask.config.get('distributed.scheduler.transition-log-length')) + self.worker_setups = [] worker_handlers = { 'task-finished': self.handle_task_finished, @@ -956,6 +957,7 @@ def __init__( 'heartbeat_worker': self.heartbeat_worker, 'get_task_status': self.get_task_status, 'get_task_stream': self.get_task_stream, + 'register_worker_callbacks': self.register_worker_callbacks } self._transitions = { @@ -1330,7 +1332,8 @@ def add_worker(self, comm=None, address=None, keys=(), ncores=None, yield comm.write({'status': 'OK', 'time': time(), - 'heartbeat-interval': heartbeat_interval(len(self.workers))}) + 'heartbeat-interval': heartbeat_interval(len(self.workers)), + 'worker-setups': self.worker_setups}) yield self.handle_worker(comm=comm, worker=address) def update_graph(self, client=None, tasks=None, keys=None, @@ -3035,6 +3038,17 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None): ts = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] return ts.collect(start=start, stop=stop, count=count) + @gen.coroutine + def register_worker_callbacks(self, comm, setup=None): + """ Registers a setup function, and call it on every worker """ + if setup is None: + raise gen.Return({}) + + self.worker_setups.append(setup) + + responses = yield self.broadcast(msg=dict(op='run', function=setup)) + raise gen.Return(responses) + ##################### # State Transitions # ##################### diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d8a3d5f2371..f85638d3656 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1203,3 +1203,76 @@ def test_avoid_oversubscription(c, s, *workers): def test_custom_metrics(c, s, a, b): assert s.workers[a.address].metrics['my_port'] == a.port assert s.workers[b.address].metrics['my_port'] == b.port + + +@gen_cluster(client=True) +def test_register_worker_callbacks(c, s, a, b): + #preload function to run + def mystartup(dask_worker): + dask_worker.init_variable = 1 + + def mystartup2(): + import os + os.environ['MY_ENV_VALUE'] = 'WORKER_ENV_VALUE' + return "Env set." 
+ + #Check that preload function has been run + def test_import(dask_worker): + return hasattr(dask_worker, 'init_variable') + # and dask_worker.init_variable == 1 + + def test_startup2(): + import os + return os.getenv('MY_ENV_VALUE', None) == 'WORKER_ENV_VALUE' + + # Nothing has been run yet + assert len(s.worker_setups) == 0 + result = yield c.run(test_import) + assert list(result.values()) == [False] * 2 + result = yield c.run(test_startup2) + assert list(result.values()) == [False] * 2 + + # Start a worker and check that startup is not run + worker = Worker(s.address, loop=s.loop) + yield worker._start() + result = yield c.run(test_import, workers=[worker.address]) + assert list(result.values()) == [False] + yield worker._close() + + # Add a preload function + response = yield c.register_worker_callbacks(setup=mystartup) + assert len(response) == 2 + assert len(s.worker_setups) == 1 + + # Check it has been ran on existing worker + result = yield c.run(test_import) + assert list(result.values()) == [True] * 2 + + # Start a worker and check it is ran on it + worker = Worker(s.address, loop=s.loop) + yield worker._start() + result = yield c.run(test_import, workers=[worker.address]) + assert list(result.values()) == [True] + yield worker._close() + + # Register another preload function + response = yield c.register_worker_callbacks(setup=mystartup2) + assert len(response) == 2 + assert len(s.worker_setups) == 2 + + # Check it has been run + result = yield c.run(test_startup2) + assert list(result.values()) == [True] * 2 + + # Start a worker and check it is ran on it + worker = Worker(s.address, loop=s.loop) + yield worker._start() + result = yield c.run(test_import, workers=[worker.address]) + assert list(result.values()) == [True] + result = yield c.run(test_startup2, workers=[worker.address]) + assert list(result.values()) == [True] + yield worker._close() + + # Final exception test + with pytest.raises(ZeroDivisionError): + yield c.register_worker_callbacks(setup=lambda: 1 / 0) diff --git a/distributed/worker.py b/distributed/worker.py index 45ef87d739c..c0ce3259276 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -342,6 +342,15 @@ def _register_with_scheduler(self): raise ValueError("Unexpected response from register: %r" % (response,)) else: + # Retrieve eventual init functions and run them + for function_bytes in response['worker-setups']: + setup_function = pickle.loads(function_bytes) + if has_arg(setup_function, 'dask_worker'): + result = setup_function(dask_worker=self) + else: + result = setup_function() + logger.info('Init function %s ran: output=%s' % (setup_function, result)) + logger.info(' Registered to: %26s', self.scheduler.address) logger.info('-' * 49) @@ -567,10 +576,11 @@ def executor_submit(self, key, function, args=(), kwargs=None, # logger.info("Finish job %d, %s", i, key) raise gen.Return(result) - def run(self, comm, function, args=(), kwargs={}): + def run(self, comm, function, args=(), kwargs=None): + kwargs = kwargs or {} return run(self, comm, function=function, args=args, kwargs=kwargs) - def run_coroutine(self, comm, function, args=(), kwargs={}, wait=True): + def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): return run(self, comm, function=function, args=args, kwargs=kwargs, is_coro=True, wait=wait) From 6f583f440e8a6c1635779d0618fcc60f568bdec4 Mon Sep 17 00:00:00 2001 From: Luke Canavan Date: Thu, 30 Aug 2018 11:27:00 -0500 Subject: [PATCH 0079/1550] Canavandl/collapse navbar (#2223) * Move CSS into css file and 
add collapsing navbar * Move navbar JS into js file * Add bokeh theme to make status doc backgrounds clear * Move active page highlight logic to JS * Move status css to standalone file * use self-closing tag * [skip ci] add css to MANIFEST.in * add bokeh theme to all plot --- MANIFEST.in | 1 + distributed/bokeh/scheduler.py | 30 ++-- distributed/bokeh/static/css/base.css | 108 +++++++++++++ distributed/bokeh/static/css/status.css | 50 ++++++ .../dask-logo.svg} | 0 distributed/bokeh/static/images/fa-bars.svg | 1 + distributed/bokeh/templates/base.html | 151 +++++------------- distributed/bokeh/templates/status.html | 57 +------ distributed/bokeh/theme.yaml | 5 + distributed/bokeh/worker.py | 9 ++ 10 files changed, 235 insertions(+), 177 deletions(-) create mode 100644 distributed/bokeh/static/css/base.css create mode 100644 distributed/bokeh/static/css/status.css rename distributed/bokeh/static/{dask_icon_no_pad.svg => images/dask-logo.svg} (100%) create mode 100644 distributed/bokeh/static/images/fa-bars.svg create mode 100644 distributed/bokeh/theme.yaml diff --git a/MANIFEST.in b/MANIFEST.in index ac05efbf329..a6c03274f24 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,6 +2,7 @@ recursive-include distributed *.py recursive-include distributed *.js recursive-include distributed *.coffee recursive-include distributed *.html +recursive-include distributed *.css recursive-include distributed *.svg recursive-include distributed *.yaml recursive-include docs *.rst diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index dbef4064905..0c34f68d651 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -18,6 +18,7 @@ from bokeh.models.widgets import DataTable, TableColumn from bokeh.plotting import figure from bokeh.palettes import Viridis11 +from bokeh.themes import Theme from bokeh.transform import factor_cmap from bokeh.io import curdoc from toolz import pipe, merge @@ -54,6 +55,7 @@ template_variables = {'pages': ['status', 'workers', 'tasks', 'system', 'profile', 'graph']} +BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), 'theme.yaml')) nan = float('nan') @@ -1057,8 +1059,8 @@ def systemmonitor_doc(scheduler, extra, doc): doc.add_root(column(sysmon.root, sizing_mode='scale_width')) doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'system' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME def stealing_doc(scheduler, extra, doc): @@ -1077,8 +1079,8 @@ def stealing_doc(scheduler, extra, doc): sizing_mode='scale_width')) doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'stealing' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME def events_doc(scheduler, extra, doc): @@ -1089,8 +1091,8 @@ def events_doc(scheduler, extra, doc): doc.title = "Dask: Scheduler Events" doc.add_root(column(events.root, sizing_mode='scale_width')) doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'events' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME def workers_doc(scheduler, extra, doc): @@ -1101,8 +1103,8 @@ def workers_doc(scheduler, extra, doc): doc.title = "Dask: Workers" doc.add_root(table.root) doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'workers' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME def tasks_doc(scheduler, extra, doc): @@ -1114,8 +1116,8 @@ def tasks_doc(scheduler, extra, doc): doc.title = "Dask: Task Stream" 
doc.add_root(ts.root) doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'tasks' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME def graph_doc(scheduler, extra, doc): @@ -1127,8 +1129,8 @@ def graph_doc(scheduler, extra, doc): doc.add_root(graph.root) doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'graph' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME def status_doc(scheduler, extra, doc): @@ -1164,10 +1166,10 @@ def status_doc(scheduler, extra, doc): doc.title = "Dask: Status" doc.add_root(task_progress.root) doc.add_root(task_stream.root) - + doc.theme = BOKEH_THEME doc.template = env.get_template('status.html') - doc.template_variables['active_page'] = 'status' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME def individual_task_stream_doc(scheduler, extra, doc): @@ -1176,6 +1178,7 @@ def individual_task_stream_doc(scheduler, extra, doc): task_stream.update() doc.add_periodic_callback(task_stream.update, 100) doc.add_root(task_stream.root) + doc.theme = BOKEH_THEME def individual_nbytes_doc(scheduler, extra, doc): @@ -1183,6 +1186,7 @@ def individual_nbytes_doc(scheduler, extra, doc): current_load.update() doc.add_periodic_callback(current_load.update, 100) doc.add_root(current_load.nbytes_figure) + doc.theme = BOKEH_THEME def individual_nprocessing_doc(scheduler, extra, doc): @@ -1190,6 +1194,7 @@ def individual_nprocessing_doc(scheduler, extra, doc): current_load.update() doc.add_periodic_callback(current_load.update, 100) doc.add_root(current_load.processing_figure) + doc.theme = BOKEH_THEME def individual_progress_doc(scheduler, extra, doc): @@ -1197,6 +1202,7 @@ def individual_progress_doc(scheduler, extra, doc): task_progress.update() doc.add_periodic_callback(task_progress.update, 100) doc.add_root(task_progress.root) + doc.theme = BOKEH_THEME def individual_graph_doc(scheduler, extra, doc): @@ -1205,6 +1211,7 @@ def individual_graph_doc(scheduler, extra, doc): graph.update() doc.add_periodic_callback(graph.update, 200) doc.add_root(graph.root) + doc.theme = BOKEH_THEME def individual_profile_doc(scheduler, extra, doc): @@ -1212,6 +1219,7 @@ def individual_profile_doc(scheduler, extra, doc): prof = ProfileTimePlot(scheduler, sizing_mode='scale_width', doc=doc) doc.add_root(prof.root) prof.trigger_update() + doc.theme = BOKEH_THEME def individual_profile_server_doc(scheduler, extra, doc): @@ -1219,6 +1227,7 @@ def individual_profile_server_doc(scheduler, extra, doc): prof = ProfileServer(scheduler, sizing_mode='scale_width', doc=doc) doc.add_root(prof.root) prof.trigger_update() + doc.theme = BOKEH_THEME def individual_workers_doc(scheduler, extra, doc): @@ -1227,6 +1236,7 @@ def individual_workers_doc(scheduler, extra, doc): table.update() doc.add_periodic_callback(table.update, 500) doc.add_root(table.root) + doc.theme = BOKEH_THEME def profile_doc(scheduler, extra, doc): @@ -1235,8 +1245,8 @@ def profile_doc(scheduler, extra, doc): prof = ProfileTimePlot(scheduler, sizing_mode='scale_width', doc=doc) doc.add_root(prof.root) doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'profile' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME prof.trigger_update() @@ -1247,8 +1257,8 @@ def profile_server_doc(scheduler, extra, doc): prof = ProfileServer(scheduler, sizing_mode='scale_width', doc=doc) doc.add_root(prof.root) doc.template = env.get_template('simple.html') - # doc.template_variables['active_page'] = 
'profile' doc.template_variables.update(extra) + doc.theme = BOKEH_THEME prof.trigger_update() diff --git a/distributed/bokeh/static/css/base.css b/distributed/bokeh/static/css/base.css new file mode 100644 index 00000000000..4731f93973b --- /dev/null +++ b/distributed/bokeh/static/css/base.css @@ -0,0 +1,108 @@ +html { + width: 100%; + height: 100%; + background: #FAFAFA; +} + +body { + height: 100%; + width: 100%; + margin: 0; + box-sizing: border-box; + font-family: Helvetica, Arial, sans-serif; + padding: 0px 10px; + padding-top: 3rem; + padding-bottom: 1rem; +} + +.navbar { + position: fixed; + top: 0; + left: 0; + right: 0; + z-index: 1; +} + +.navbar ul { + list-style-type: none; + margin: 0; + padding: 0; + overflow: hidden; + background-color: #EEE; +} + +.navbar li { + float: left; + font-size: 17px; + transition: .3s background-color; +} + +.navbar li.active { + background-color: rgba(234, 170, 109, 0.7); +} + +.navbar li a { + display: block; + color: black; + padding: 11px 16px; + text-decoration: none; +} + +.navbar li:hover { + background-color: #eaaa6d; +} + +#dask-logo img { + height: 28px; + padding: 5px 15px; +} + +#dask-logo a { + padding: 0px; +} + +#navbar-toggle-icon { + float: right; +} + +#navbar-toggle-icon a { + display: none; +} + +#navbar-toggle-icon img { + height: 22px; +} + +@media screen and (max-width: 650px) { + .navbar li:not(#dask-logo):not(#navbar-toggle-icon) a { + display: none; + } + #navbar-toggle-icon a { + display: block; + } +} + +@media screen and (max-width: 650px) { + .navbar.responsive li:not(#navbar-toggle-icon) { + float: none; + } + .navbar.responsive li:not(#dask-logo):not(#navbar-toggle-icon) a { + display: block; + text-align: left; + } + .navbar.responsive #navbar-toggle-icon a { + position: absolute; + right: 0; + top: 0; + } +} + +.bk-root .bk-toolbar-box .bk-toolbar-right { + top: 4px; + right: 4px; +} + +.content { + width: 100%; + height: 100%; +} diff --git a/distributed/bokeh/static/css/status.css b/distributed/bokeh/static/css/status.css new file mode 100644 index 00000000000..9de0d01b353 --- /dev/null +++ b/distributed/bokeh/static/css/status.css @@ -0,0 +1,50 @@ +#status-fluid { + display: grid; + height: 100%; +} + +@media (min-width: 0px) { + #status-fluid { + grid-template-columns: 1fr 1fr; + grid-template-rows: 1fr 3fr 1fr; + } + #status-history { + grid-column: 1; + grid-row: 1; + } + #status-processing { + grid-column: 2; + grid-row: 1; + } + #status-tasks { + grid-column: 1 / span 2; + grid-row: 2; + } + #status-progress { + grid-column: 1 / span 2; + grid-row: 3; + } +} + +@media (min-width: 992px) { + #status-fluid { + grid-template-columns: 1fr 3fr; + grid-template-rows: 1fr 1fr 1fr 1fr 1fr 1fr; + } + #status-history { + grid-column: 1; + grid-row: 1 / span 3; + } + #status-processing { + grid-column: 1; + grid-row: 4 / span 3; + } + #status-tasks { + grid-column: 2; + grid-row: 1 / span 4; + } + #status-progress { + grid-column: 2; + grid-row: 5 / span 2; + } +} diff --git a/distributed/bokeh/static/dask_icon_no_pad.svg b/distributed/bokeh/static/images/dask-logo.svg similarity index 100% rename from distributed/bokeh/static/dask_icon_no_pad.svg rename to distributed/bokeh/static/images/dask-logo.svg diff --git a/distributed/bokeh/static/images/fa-bars.svg b/distributed/bokeh/static/images/fa-bars.svg new file mode 100644 index 00000000000..06e78c1c3a5 --- /dev/null +++ b/distributed/bokeh/static/images/fa-bars.svg @@ -0,0 +1 @@ + diff --git a/distributed/bokeh/templates/base.html 
b/distributed/bokeh/templates/base.html index ad9ffd152b4..ac8a855d1fd 100644 --- a/distributed/bokeh/templates/base.html +++ b/distributed/bokeh/templates/base.html @@ -4,138 +4,61 @@ Dask Diagnostic UI - + {% block resources %} - {% block js_resources %} + {% block css_resources %} {{ bokeh_css | indent(8) if bokeh_css }} {% endblock %} - {% block css_resources%} + {% block js_resources%} {{ bokeh_js | indent(8) if bokeh_js }} {% endblock %} + {% block extra_resources %} + {% endblock %} {% endblock %} - - """ + """, ) self.root.add_tools(hover) @@ -925,8 +1145,9 @@ def update(self): with log_errors(): nb = nbytes_bar(self.plugin.nbytes) update(self.source, nb) - self.root.title.text = \ - "Memory Use: %0.2f MB" % (sum(self.plugin.nbytes.values()) / 1e6) + self.root.title.text = "Memory Use: %0.2f MB" % ( + sum(self.plugin.nbytes.values()) / 1e6 + ) class WorkerTable(DashboardComponent): @@ -935,60 +1156,90 @@ class WorkerTable(DashboardComponent): This is two plots, a text-based table for each host and a thin horizontal plot laying out hosts by their current memory use. """ - excluded_names = {'executing', 'in_flight', 'in_memory', 'ready', 'time'} + + excluded_names = {"executing", "in_flight", "in_memory", "ready", "time"} def __init__(self, scheduler, width=800, **kwargs): self.scheduler = scheduler - self.names = ['worker', 'ncores', 'cpu', 'memory', 'memory_limit', - 'memory_percent', 'num_fds', 'read_bytes', 'write_bytes', - 'cpu_fraction'] + self.names = [ + "worker", + "ncores", + "cpu", + "memory", + "memory_limit", + "memory_percent", + "num_fds", + "read_bytes", + "write_bytes", + "cpu_fraction", + ] workers = self.scheduler.workers.values() - self.extra_names = sorted({m for ws in workers - for m in ws.metrics - if m not in self.names} - self.excluded_names) + self.extra_names = sorted( + {m for ws in workers for m in ws.metrics if m not in self.names} + - self.excluded_names + ) - table_names = ['worker', 'ncores', 'cpu', 'memory', 'memory_limit', - 'memory_percent', 'num_fds', 'read_bytes', - 'write_bytes'] + table_names = [ + "worker", + "ncores", + "cpu", + "memory", + "memory_limit", + "memory_percent", + "num_fds", + "read_bytes", + "write_bytes", + ] self.source = ColumnDataSource({k: [] for k in self.names}) - columns = {name: TableColumn(field=name, - title=name.replace('_percent', ' %')) - for name in table_names} - - formatters = {'cpu': NumberFormatter(format='0.0 %'), - 'memory_percent': NumberFormatter(format='0.0 %'), - 'memory': NumberFormatter(format='0 b'), - 'memory_limit': NumberFormatter(format='0 b'), - 'read_bytes': NumberFormatter(format='0 b'), - 'write_bytes': NumberFormatter(format='0 b'), - 'num_fds': NumberFormatter(format='0'), - 'ncores': NumberFormatter(format='0')} - - if BOKEH_VERSION < '0.12.15': - dt_kwargs = {'row_headers': False} + columns = { + name: TableColumn(field=name, title=name.replace("_percent", " %")) + for name in table_names + } + + formatters = { + "cpu": NumberFormatter(format="0.0 %"), + "memory_percent": NumberFormatter(format="0.0 %"), + "memory": NumberFormatter(format="0 b"), + "memory_limit": NumberFormatter(format="0 b"), + "read_bytes": NumberFormatter(format="0 b"), + "write_bytes": NumberFormatter(format="0 b"), + "num_fds": NumberFormatter(format="0"), + "ncores": NumberFormatter(format="0"), + } + + if BOKEH_VERSION < "0.12.15": + dt_kwargs = {"row_headers": False} else: - dt_kwargs = {'index_position': None} + dt_kwargs = {"index_position": None} table = DataTable( - source=self.source, columns=[columns[n] for 
n in table_names], - reorderable=True, sortable=True, width=width, **dt_kwargs + source=self.source, + columns=[columns[n] for n in table_names], + reorderable=True, + sortable=True, + width=width, + **dt_kwargs ) for name in table_names: if name in formatters: table.columns[table_names.index(name)].formatter = formatters[name] - extra_names = ['worker'] + self.extra_names - extra_columns = {name: TableColumn(field=name, - title=name.replace('_percent', '%')) - for name in extra_names} + extra_names = ["worker"] + self.extra_names + extra_columns = { + name: TableColumn(field=name, title=name.replace("_percent", "%")) + for name in extra_names + } extra_table = DataTable( source=self.source, columns=[extra_columns[n] for n in extra_names], - reorderable=True, sortable=True, width=width, **dt_kwargs + reorderable=True, + sortable=True, + width=width, + **dt_kwargs ) hover = HoverTool( @@ -998,14 +1249,22 @@ def __init__(self, scheduler, width=800, **kwargs): @worker: @memory_percent - """ + """, ) - mem_plot = figure(title='Memory Use (%)', toolbar_location=None, - x_range=(0, 1), y_range=(-0.1, 0.1), height=60, - width=width, tools='', **kwargs) - mem_plot.circle(source=self.source, x='memory_percent', y=0, - size=10, fill_alpha=0.5) + mem_plot = figure( + title="Memory Use (%)", + toolbar_location=None, + x_range=(0, 1), + y_range=(-0.1, 0.1), + height=60, + width=width, + tools="", + **kwargs + ) + mem_plot.circle( + source=self.source, x="memory_percent", y=0, size=10, fill_alpha=0.5 + ) mem_plot.ygrid.visible = False mem_plot.yaxis.minor_tick_line_alpha = 0 mem_plot.xaxis.visible = False @@ -1019,14 +1278,22 @@ def __init__(self, scheduler, width=800, **kwargs): @worker: @cpu - """ + """, ) - cpu_plot = figure(title='CPU Use (%)', toolbar_location=None, - x_range=(0, 1), y_range=(-0.1, 0.1), height=60, - width=width, tools='', **kwargs) - cpu_plot.circle(source=self.source, x='cpu_fraction', y=0, - size=10, fill_alpha=0.5) + cpu_plot = figure( + title="CPU Use (%)", + toolbar_location=None, + x_range=(0, 1), + y_range=(-0.1, 0.1), + height=60, + width=width, + tools="", + **kwargs + ) + cpu_plot.circle( + source=self.source, x="cpu_fraction", y=0, size=10, fill_alpha=0.5 + ) cpu_plot.ygrid.visible = False cpu_plot.yaxis.minor_tick_line_alpha = 0 cpu_plot.xaxis.visible = False @@ -1034,8 +1301,8 @@ def __init__(self, scheduler, width=800, **kwargs): cpu_plot.add_tools(hover, BoxSelectTool()) self.cpu_plot = cpu_plot - if 'sizing_mode' in kwargs: - sizing_mode = {'sizing_mode': kwargs['sizing_mode']} + if "sizing_mode" in kwargs: + sizing_mode = {"sizing_mode": kwargs["sizing_mode"]} else: sizing_mode = {} @@ -1043,7 +1310,7 @@ def __init__(self, scheduler, width=800, **kwargs): if self.extra_names: components.append(extra_table) - self.root = column(*components, id='bk-worker-table', **sizing_mode) + self.root = column(*components, id="bk-worker-table", **sizing_mode) @without_property_validation def update(self): @@ -1051,60 +1318,65 @@ def update(self): for addr, ws in sorted(self.scheduler.workers.items()): for name in self.names + self.extra_names: data[name].append(ws.metrics.get(name, None)) - data['worker'][-1] = ws.address + data["worker"][-1] = ws.address if ws.memory_limit: - data['memory_percent'][-1] = ws.metrics['memory'] / ws.memory_limit + data["memory_percent"][-1] = ws.metrics["memory"] / ws.memory_limit else: - data['memory_percent'][-1] = '' - data['memory_limit'][-1] = ws.memory_limit - data['cpu'][-1] = ws.metrics['cpu'] / 100.0 - data['cpu_fraction'][-1] = 
ws.metrics['cpu'] / 100.0 / ws.ncores - data['ncores'][-1] = ws.ncores + data["memory_percent"][-1] = "" + data["memory_limit"][-1] = ws.memory_limit + data["cpu"][-1] = ws.metrics["cpu"] / 100.0 + data["cpu_fraction"][-1] = ws.metrics["cpu"] / 100.0 / ws.ncores + data["ncores"][-1] = ws.ncores self.source.data.update(data) def systemmonitor_doc(scheduler, extra, doc): with log_errors(): - sysmon = SystemMonitor(scheduler, sizing_mode='stretch_both') + sysmon = SystemMonitor(scheduler, sizing_mode="stretch_both") doc.title = "Dask: Scheduler System Monitor" add_periodic_callback(doc, sysmon, 500) for subdoc in sysmon.root.children: doc.add_root(subdoc) - doc.template = env.get_template('system.html') + doc.template = env.get_template("system.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME def stealing_doc(scheduler, extra, doc): with log_errors(): - occupancy = Occupancy(scheduler, height=200, sizing_mode='scale_width') - stealing_ts = StealingTimeSeries(scheduler, sizing_mode='scale_width') - stealing_events = StealingEvents(scheduler, sizing_mode='scale_width') + occupancy = Occupancy(scheduler, height=200, sizing_mode="scale_width") + stealing_ts = StealingTimeSeries(scheduler, sizing_mode="scale_width") + stealing_events = StealingEvents(scheduler, sizing_mode="scale_width") stealing_events.root.x_range = stealing_ts.root.x_range doc.title = "Dask: Work Stealing" add_periodic_callback(doc, occupancy, 500) add_periodic_callback(doc, stealing_ts, 500) add_periodic_callback(doc, stealing_events, 500) - doc.add_root(column(occupancy.root, stealing_ts.root, - stealing_events.root, - sizing_mode='scale_width')) + doc.add_root( + column( + occupancy.root, + stealing_ts.root, + stealing_events.root, + sizing_mode="scale_width", + ) + ) - doc.template = env.get_template('simple.html') + doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME def events_doc(scheduler, extra, doc): with log_errors(): - events = Events(scheduler, 'all', height=250) + events = Events(scheduler, "all", height=250) events.update() add_periodic_callback(doc, events, 500) doc.title = "Dask: Scheduler Events" - doc.add_root(column(events.root, sizing_mode='scale_width')) - doc.template = env.get_template('simple.html') + doc.add_root(column(events.root, sizing_mode="scale_width")) + doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -1116,63 +1388,72 @@ def workers_doc(scheduler, extra, doc): add_periodic_callback(doc, table, 500) doc.title = "Dask: Workers" doc.add_root(table.root) - doc.template = env.get_template('simple.html') + doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME def tasks_doc(scheduler, extra, doc): with log_errors(): - ts = TaskStream(scheduler, n_rectangles=100000, clear_interval='60s', - sizing_mode='stretch_both') + ts = TaskStream( + scheduler, + n_rectangles=100000, + clear_interval="60s", + sizing_mode="stretch_both", + ) ts.update() add_periodic_callback(doc, ts, 5000) doc.title = "Dask: Task Stream" doc.add_root(ts.root) - doc.template = env.get_template('simple.html') + doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME def graph_doc(scheduler, extra, doc): with log_errors(): - graph = GraphPlot(scheduler, sizing_mode='stretch_both') + graph = GraphPlot(scheduler, sizing_mode="stretch_both") doc.title = "Dask: Task Graph" graph.update() 
add_periodic_callback(doc, graph, 200) doc.add_root(graph.root) - doc.template = env.get_template('simple.html') + doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME def status_doc(scheduler, extra, doc): with log_errors(): - task_stream = TaskStream(scheduler, n_rectangles=1000, - clear_interval='10s', sizing_mode='stretch_both') + task_stream = TaskStream( + scheduler, + n_rectangles=1000, + clear_interval="10s", + sizing_mode="stretch_both", + ) task_stream.update() add_periodic_callback(doc, task_stream, 100) - task_progress = TaskProgress(scheduler, sizing_mode='stretch_both') + task_progress = TaskProgress(scheduler, sizing_mode="stretch_both") task_progress.update() add_periodic_callback(doc, task_progress, 100) if len(scheduler.workers) < 50: - current_load = CurrentLoad(scheduler, sizing_mode='stretch_both') + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") current_load.update() add_periodic_callback(doc, current_load, 100) doc.add_root(current_load.nbytes_figure) doc.add_root(current_load.processing_figure) else: - nbytes_hist = NBytesHistogram(scheduler, sizing_mode='stretch_both') + nbytes_hist = NBytesHistogram(scheduler, sizing_mode="stretch_both") nbytes_hist.update() - processing_hist = ProcessingHistogram(scheduler, sizing_mode='stretch_both') + processing_hist = ProcessingHistogram(scheduler, sizing_mode="stretch_both") processing_hist.update() add_periodic_callback(doc, nbytes_hist, 100) add_periodic_callback(doc, processing_hist, 100) - current_load_fig = row(nbytes_hist.root, processing_hist.root, - sizing_mode='stretch_both') + current_load_fig = row( + nbytes_hist.root, processing_hist.root, sizing_mode="stretch_both" + ) doc.add_root(nbytes_hist.root) doc.add_root(processing_hist.root) @@ -1181,14 +1462,15 @@ def status_doc(scheduler, extra, doc): doc.add_root(task_progress.root) doc.add_root(task_stream.root) doc.theme = BOKEH_THEME - doc.template = env.get_template('status.html') + doc.template = env.get_template("status.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME def individual_task_stream_doc(scheduler, extra, doc): - task_stream = TaskStream(scheduler, n_rectangles=1000, - clear_interval='10s', sizing_mode='stretch_both') + task_stream = TaskStream( + scheduler, n_rectangles=1000, clear_interval="10s", sizing_mode="stretch_both" + ) task_stream.update() add_periodic_callback(doc, task_stream, 100) doc.add_root(task_stream.root) @@ -1196,7 +1478,7 @@ def individual_task_stream_doc(scheduler, extra, doc): def individual_nbytes_doc(scheduler, extra, doc): - current_load = CurrentLoad(scheduler, sizing_mode='stretch_both') + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") current_load.update() add_periodic_callback(doc, current_load, 100) doc.add_root(current_load.nbytes_figure) @@ -1204,7 +1486,7 @@ def individual_nbytes_doc(scheduler, extra, doc): def individual_nprocessing_doc(scheduler, extra, doc): - current_load = CurrentLoad(scheduler, sizing_mode='stretch_both') + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") current_load.update() add_periodic_callback(doc, current_load, 100) doc.add_root(current_load.processing_figure) @@ -1212,7 +1494,7 @@ def individual_nprocessing_doc(scheduler, extra, doc): def individual_progress_doc(scheduler, extra, doc): - task_progress = TaskProgress(scheduler, height=160, sizing_mode='stretch_both') + task_progress = TaskProgress(scheduler, height=160, sizing_mode="stretch_both") 
task_progress.update() add_periodic_callback(doc, task_progress, 100) doc.add_root(task_progress.root) @@ -1221,7 +1503,7 @@ def individual_progress_doc(scheduler, extra, doc): def individual_graph_doc(scheduler, extra, doc): with log_errors(): - graph = GraphPlot(scheduler, sizing_mode='stretch_both') + graph = GraphPlot(scheduler, sizing_mode="stretch_both") graph.update() add_periodic_callback(doc, graph, 200) @@ -1231,7 +1513,7 @@ def individual_graph_doc(scheduler, extra, doc): def individual_profile_doc(scheduler, extra, doc): with log_errors(): - prof = ProfileTimePlot(scheduler, sizing_mode='scale_width', doc=doc) + prof = ProfileTimePlot(scheduler, sizing_mode="scale_width", doc=doc) doc.add_root(prof.root) prof.trigger_update() doc.theme = BOKEH_THEME @@ -1239,7 +1521,7 @@ def individual_profile_doc(scheduler, extra, doc): def individual_profile_server_doc(scheduler, extra, doc): with log_errors(): - prof = ProfileServer(scheduler, sizing_mode='scale_width', doc=doc) + prof = ProfileServer(scheduler, sizing_mode="scale_width", doc=doc) doc.add_root(prof.root) prof.trigger_update() doc.theme = BOKEH_THEME @@ -1257,9 +1539,9 @@ def individual_workers_doc(scheduler, extra, doc): def profile_doc(scheduler, extra, doc): with log_errors(): doc.title = "Dask: Profile" - prof = ProfileTimePlot(scheduler, sizing_mode='scale_width', doc=doc) + prof = ProfileTimePlot(scheduler, sizing_mode="scale_width", doc=doc) doc.add_root(prof.root) - doc.template = env.get_template('simple.html') + doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -1269,9 +1551,9 @@ def profile_doc(scheduler, extra, doc): def profile_server_doc(scheduler, extra, doc): with log_errors(): doc.title = "Dask: Profile of Event Loop" - prof = ProfileServer(scheduler, sizing_mode='scale_width', doc=doc) + prof = ProfileServer(scheduler, sizing_mode="scale_width", doc=doc) doc.add_root(prof.root) - doc.template = env.get_template('simple.html') + doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -1279,48 +1561,46 @@ def profile_server_doc(scheduler, extra, doc): class BokehScheduler(BokehServer): - def __init__(self, scheduler, io_loop=None, prefix='', **kwargs): + def __init__(self, scheduler, io_loop=None, prefix="", **kwargs): self.scheduler = scheduler - prefix = prefix or '' - prefix = prefix.rstrip('/') - if prefix and not prefix.startswith('/'): - prefix = '/' + prefix + prefix = prefix or "" + prefix = prefix.rstrip("/") + if prefix and not prefix.startswith("/"): + prefix = "/" + prefix self.prefix = prefix self.server_kwargs = kwargs - self.server_kwargs['prefix'] = prefix or None + self.server_kwargs["prefix"] = prefix or None self.apps = { - '/system': systemmonitor_doc, - '/stealing': stealing_doc, - '/workers': workers_doc, - '/events': events_doc, - '/counters': counters_doc, - '/tasks': tasks_doc, - '/status': status_doc, - '/profile': profile_doc, - '/profile-server': profile_server_doc, - '/graph': graph_doc, - - '/individual-task-stream': individual_task_stream_doc, - '/individual-progress': individual_progress_doc, - '/individual-graph': individual_graph_doc, - '/individual-profile': individual_profile_doc, - '/individual-profile-server': individual_profile_server_doc, - '/individual-nbytes': individual_nbytes_doc, - '/individual-nprocessing': individual_nprocessing_doc, - '/individual-workers': individual_workers_doc, + "/system": systemmonitor_doc, + "/stealing": stealing_doc, + 
"/workers": workers_doc, + "/events": events_doc, + "/counters": counters_doc, + "/tasks": tasks_doc, + "/status": status_doc, + "/profile": profile_doc, + "/profile-server": profile_server_doc, + "/graph": graph_doc, + "/individual-task-stream": individual_task_stream_doc, + "/individual-progress": individual_progress_doc, + "/individual-graph": individual_graph_doc, + "/individual-profile": individual_profile_doc, + "/individual-profile-server": individual_profile_server_doc, + "/individual-nbytes": individual_nbytes_doc, + "/individual-nprocessing": individual_nprocessing_doc, + "/individual-workers": individual_workers_doc, } - self.apps = {k: partial(v, scheduler, self.extra) - for k, v in self.apps.items()} + self.apps = {k: partial(v, scheduler, self.extra) for k, v in self.apps.items()} self.loop = io_loop or scheduler.loop self.server = None @property def extra(self): - return merge({'prefix': self.prefix}, template_variables) + return merge({"prefix": self.prefix}, template_variables) @property def my_server(self): @@ -1330,7 +1610,14 @@ def listen(self, *args, **kwargs): super(BokehScheduler, self).listen(*args, **kwargs) from .scheduler_html import routes - handlers = [(self.prefix + '/' + url, cls, {'server': self.my_server, 'extra': self.extra}) - for url, cls in routes] - self.server._tornado.add_handlers(r'.*', handlers) + handlers = [ + ( + self.prefix + "/" + url, + cls, + {"server": self.my_server, "extra": self.extra}, + ) + for url, cls in routes + ] + + self.server._tornado.add_handlers(r".*", handlers) diff --git a/distributed/bokeh/scheduler_html.py b/distributed/bokeh/scheduler_html.py index e8050e4a9fa..d1ba2646ed6 100644 --- a/distributed/bokeh/scheduler_html.py +++ b/distributed/bokeh/scheduler_html.py @@ -10,7 +10,9 @@ dirname = os.path.dirname(__file__) -ns = {func.__name__: func for func in [format_bytes, format_time, datetime.fromtimestamp]} +ns = { + func.__name__: func for func in [format_bytes, format_time, datetime.fromtimestamp] +} class RequestHandler(web.RequestHandler): @@ -19,43 +21,50 @@ def initialize(self, server=None, extra=None): self.extra = extra or {} def get_template_path(self): - return os.path.join(dirname, 'templates') + return os.path.join(dirname, "templates") class Workers(RequestHandler): def get(self): with log_errors(): - self.render('workers.html', - title='Workers', - scheduler=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra)) + self.render( + "workers.html", + title="Workers", + scheduler=self.server, + **toolz.merge(self.server.__dict__, ns, self.extra) + ) class Worker(RequestHandler): def get(self, worker): worker = escape.url_unescape(worker) with log_errors(): - self.render('worker.html', - title='Worker: ' + worker, Worker=worker, - **toolz.merge(self.server.__dict__, ns, self.extra)) + self.render( + "worker.html", + title="Worker: " + worker, + Worker=worker, + **toolz.merge(self.server.__dict__, ns, self.extra) + ) class Task(RequestHandler): def get(self, task): task = escape.url_unescape(task) with log_errors(): - self.render('task.html', - title='Task: ' + task, - Task=task, - server=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra)) + self.render( + "task.html", + title="Task: " + task, + Task=task, + server=self.server, + **toolz.merge(self.server.__dict__, ns, self.extra) + ) class Logs(RequestHandler): def get(self): with log_errors(): logs = self.server.get_logs() - self.render('logs.html', title="Logs", logs=logs, **self.extra) + self.render("logs.html", title="Logs", 
logs=logs, **self.extra) class WorkerLogs(RequestHandler): @@ -65,8 +74,7 @@ def get(self, worker): worker = escape.url_unescape(worker) logs = yield self.server.get_worker_logs(workers=[worker]) logs = logs[worker] - self.render('logs.html', title="Logs: " + worker, logs=logs, - **self.extra) + self.render("logs.html", title="Logs: " + worker, logs=logs, **self.extra) class WorkerCallStacks(RequestHandler): @@ -76,8 +84,12 @@ def get(self, worker): worker = escape.url_unescape(worker) keys = self.server.processing[worker] call_stack = yield self.server.get_call_stack(keys=keys) - self.render('call-stack.html', title="Call Stacks: " + worker, - call_stack=call_stack, **self.extra) + self.render( + "call-stack.html", + title="Call Stacks: " + worker, + call_stack=call_stack, + **self.extra + ) class TaskCallStack(RequestHandler): @@ -87,11 +99,17 @@ def get(self, key): key = escape.url_unescape(key) call_stack = yield self.server.get_call_stack(keys=[key]) if not call_stack: - self.write("
<p>Task not actively running. " - "It may be finished or not yet started</p>") + self.write( + "<p>Task not actively running. " + "It may be finished or not yet started</p>
" + ) else: - self.render('call-stack.html', title="Call Stack: " + key, - call_stack=call_stack, **self.extra) + self.render( + "call-stack.html", + title="Call Stack: " + key, + call_stack=call_stack, + **self.extra + ) class CountsJSON(RequestHandler): @@ -109,7 +127,7 @@ def get(self): for ts in scheduler.tasks.values(): if ts.exception_blame is not None: erred += 1 - elif ts.state == 'released': + elif ts.state == "released": released += 1 if ts.waiting_on: waiting += 1 @@ -122,21 +140,21 @@ def get(self): processing += len(ws.processing) response = { - 'bytes': nbytes, - 'clients': len(scheduler.clients), - 'cores': ncores, - 'erred': erred, - 'hosts': len(scheduler.host_info), - 'idle': len(scheduler.idle), - 'memory': memory, - 'processing': processing, - 'released': released, - 'saturated': len(scheduler.saturated), - 'tasks': len(scheduler.tasks), - 'unrunnable': len(scheduler.unrunnable), - 'waiting': waiting, - 'waiting_data': waiting_data, - 'workers': len(scheduler.workers), + "bytes": nbytes, + "clients": len(scheduler.clients), + "cores": ncores, + "erred": erred, + "hosts": len(scheduler.host_info), + "idle": len(scheduler.idle), + "memory": memory, + "processing": processing, + "released": released, + "saturated": len(scheduler.saturated), + "tasks": len(scheduler.tasks), + "unrunnable": len(scheduler.unrunnable), + "waiting": waiting, + "waiting_data": waiting_data, + "workers": len(scheduler.workers), } self.write(response) @@ -149,17 +167,20 @@ def get(self): class IndexJSON(RequestHandler): def get(self): with log_errors(): - r = [url for url, _ in routes if url.endswith('.json')] - self.render('json-index.html', routes=r, title='Index of JSON routes', **self.extra) + r = [url for url, _ in routes if url.endswith(".json")] + self.render( + "json-index.html", routes=r, title="Index of JSON routes", **self.extra + ) class IndividualPlots(RequestHandler): def get(self): - bokeh_server = self.server.services['bokeh'] - result = {uri.strip('/').replace('-', ' ').title(): uri - for uri in bokeh_server.apps - if uri.lstrip('/').startswith('individual-') - and not uri.endswith('.json')} + bokeh_server = self.server.services["bokeh"] + result = { + uri.strip("/").replace("-", " ").title(): uri + for uri in bokeh_server.apps + if uri.lstrip("/").startswith("individual-") and not uri.endswith(".json") + } self.write(result) @@ -170,13 +191,13 @@ def __init__(self, server, prometheus_client): def collect(self): yield self.prometheus_client.core.GaugeMetricFamily( - 'dask_scheduler_workers', - 'Number of workers.', + "dask_scheduler_workers", + "Number of workers.", value=len(self.server.workers), ) yield self.prometheus_client.core.GaugeMetricFamily( - 'dask_scheduler_clients', - 'Number of clients.', + "dask_scheduler_clients", + "Number of clients.", value=len(self.server.clients), ) @@ -186,6 +207,7 @@ class PrometheusHandler(RequestHandler): def __init__(self, *args, **kwargs): import prometheus_client # keep out of global namespace + self.prometheus_client = prometheus_client super(PrometheusHandler, self).__init__(*args, **kwargs) @@ -197,41 +219,38 @@ def _init(self): return self.prometheus_client.REGISTRY.register( - _PrometheusCollector( - self.server, - self.prometheus_client, - ) + _PrometheusCollector(self.server, self.prometheus_client) ) PrometheusHandler._initialized = True def get(self): self.write(self.prometheus_client.generate_latest()) - self.set_header('Content-Type', 'text/plain; version=0.0.4') + self.set_header("Content-Type", "text/plain; version=0.0.4") 
class HealthHandler(RequestHandler): def get(self): - self.write('ok') - self.set_header('Content-Type', 'text/plain') + self.write("ok") + self.set_header("Content-Type", "text/plain") routes = [ - (r'info/main/workers.html', Workers), - (r'info/worker/(.*).html', Worker), - (r'info/task/(.*).html', Task), - (r'info/main/logs.html', Logs), - (r'info/call-stacks/(.*).html', WorkerCallStacks), - (r'info/call-stack/(.*).html', TaskCallStack), - (r'info/logs/(.*).html', WorkerLogs), - (r'json/counts.json', CountsJSON), - (r'json/identity.json', IdentityJSON), - (r'json/index.html', IndexJSON), - (r'individual-plots.json', IndividualPlots), - (r'metrics', PrometheusHandler), - (r'health', HealthHandler), + (r"info/main/workers.html", Workers), + (r"info/worker/(.*).html", Worker), + (r"info/task/(.*).html", Task), + (r"info/main/logs.html", Logs), + (r"info/call-stacks/(.*).html", WorkerCallStacks), + (r"info/call-stack/(.*).html", TaskCallStack), + (r"info/logs/(.*).html", WorkerLogs), + (r"json/counts.json", CountsJSON), + (r"json/identity.json", IdentityJSON), + (r"json/index.html", IndexJSON), + (r"individual-plots.json", IndividualPlots), + (r"metrics", PrometheusHandler), + (r"health", HealthHandler), ] def get_handlers(server): - return [(url, cls, {'server': server}) for url, cls in routes] + return [(url, cls, {"server": server}) for url, cls in routes] diff --git a/distributed/bokeh/tests/test_components.py b/distributed/bokeh/tests/test_components.py index 741c90c8d49..4f4df92f6cd 100644 --- a/distributed/bokeh/tests/test_components.py +++ b/distributed/bokeh/tests/test_components.py @@ -1,7 +1,8 @@ from __future__ import print_function, division, absolute_import import pytest -pytest.importorskip('bokeh') + +pytest.importorskip("bokeh") from bokeh.models import ColumnDataSource, Model from tornado import gen @@ -10,14 +11,15 @@ from distributed.utils_test import slowinc, gen_cluster from distributed.bokeh.components import ( - TaskStream, MemoryUsage, - Processing, ProfilePlot, ProfileTimePlot + TaskStream, + MemoryUsage, + Processing, + ProfilePlot, + ProfileTimePlot, ) -@pytest.mark.parametrize('Component', [TaskStream, - MemoryUsage, - Processing]) +@pytest.mark.parametrize("Component", [TaskStream, MemoryUsage, Processing]) def test_basic(Component): c = Component() assert isinstance(c.source, ColumnDataSource) @@ -28,23 +30,24 @@ def test_basic(Component): @gen_cluster(client=True, check_new_threads=False) def test_profile_plot(c, s, a, b): p = ProfilePlot() - assert len(p.source.data['left']) <= 1 + assert len(p.source.data["left"]) <= 1 yield c.map(slowinc, range(10), delay=0.05) p.update(a.profile_recent) - assert len(p.source.data['left']) > 1 + assert len(p.source.data["left"]) > 1 @gen_cluster(client=True, check_new_threads=False) def test_profile_time_plot(c, s, a, b): from bokeh.io import curdoc + sp = ProfileTimePlot(s, doc=curdoc()) sp.trigger_update() ap = ProfileTimePlot(a, doc=curdoc()) ap.trigger_update() - assert len(sp.source.data['left']) <= 1 - assert len(ap.source.data['left']) <= 1 + assert len(sp.source.data["left"]) <= 1 + assert len(ap.source.data["left"]) <= 1 yield c.map(slowinc, range(10), delay=0.05) ap.trigger_update() diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index acbf54bc102..380dff104e2 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -6,7 +6,8 @@ from time import sleep import pytest 
-pytest.importorskip('bokeh') + +pytest.importorskip("bokeh") from toolz import first from tornado import gen from tornado.httpclient import AsyncHTTPClient @@ -17,48 +18,69 @@ from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec, slowinc, div from distributed.bokeh.worker import Counters, BokehWorker -from distributed.bokeh.scheduler import (BokehScheduler, SystemMonitor, - Occupancy, StealingTimeSeries, - StealingEvents, Events, - TaskStream, TaskProgress, - MemoryUse, CurrentLoad, - ProcessingHistogram, - NBytesHistogram, WorkerTable, - GraphPlot, ProfileServer) +from distributed.bokeh.scheduler import ( + BokehScheduler, + SystemMonitor, + Occupancy, + StealingTimeSeries, + StealingEvents, + Events, + TaskStream, + TaskProgress, + MemoryUse, + CurrentLoad, + ProcessingHistogram, + NBytesHistogram, + WorkerTable, + GraphPlot, + ProfileServer, +) from distributed.bokeh import scheduler scheduler.PROFILING = False -@pytest.mark.skipif(sys.version_info[0] == 2, - reason='https://github.com/bokeh/bokeh/issues/5494') -@gen_cluster(client=True, - scheduler_kwargs={'services': {('bokeh', 0): BokehScheduler}}) +@pytest.mark.skipif( + sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" +) +@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) def test_simple(c, s, a, b): - assert isinstance(s.services['bokeh'], BokehScheduler) - port = s.services['bokeh'].port + assert isinstance(s.services["bokeh"], BokehScheduler) + port = s.services["bokeh"].port future = c.submit(sleep, 1) yield gen.sleep(0.1) http_client = AsyncHTTPClient() - for suffix in ['system', 'counters', 'workers', 'status', 'tasks', - 'stealing', 'graph', 'individual-task-stream', 'individual-progress', - 'individual-graph', 'individual-nbytes', - 'individual-nprocessing', - 'individual-profile']: - response = yield http_client.fetch('http://localhost:%d/%s' % (port, suffix)) + for suffix in [ + "system", + "counters", + "workers", + "status", + "tasks", + "stealing", + "graph", + "individual-task-stream", + "individual-progress", + "individual-graph", + "individual-nbytes", + "individual-nprocessing", + "individual-profile", + ]: + response = yield http_client.fetch("http://localhost:%d/%s" % (port, suffix)) body = response.body.decode() - assert 'bokeh' in body.lower() + assert "bokeh" in body.lower() assert not re.search("href=./", body) # no absolute links - response = yield http_client.fetch('http://localhost:%d/individual-plots.json' % port) + response = yield http_client.fetch( + "http://localhost:%d/individual-plots.json" % port + ) response = json.loads(response.body.decode()) assert response -@gen_cluster(client=True, worker_kwargs=dict(services={'bokeh': BokehWorker})) +@gen_cluster(client=True, worker_kwargs=dict(services={"bokeh": BokehWorker})) def test_basic(c, s, a, b): for component in [SystemMonitor, Occupancy, StealingTimeSeries]: ss = component(s) @@ -67,14 +89,13 @@ def test_basic(c, s, a, b): data = ss.source.data assert len(first(data.values())) if component is Occupancy: - assert all(addr.startswith('127.0.0.1:') - for addr in data['bokeh_address']) + assert all(addr.startswith("127.0.0.1:") for addr in data["bokeh_address"]) @gen_cluster(client=True) def test_counters(c, s, a, b): - pytest.importorskip('crick') - while 'tick-duration' not in s.digests: + pytest.importorskip("crick") + while "tick-duration" not in s.digests: yield gen.sleep(0.01) ss = Counters(s) @@ -83,7 +104,7 @@ def test_counters(c, s, a, 
b): ss.update() start = time() - while not len(ss.digest_sources['tick-duration'][0].data['x']): + while not len(ss.digest_sources["tick-duration"][0].data["x"]): yield gen.sleep(1) assert time() < start + 5 @@ -92,8 +113,9 @@ def test_counters(c, s, a, b): def test_stealing_events(c, s, a, b): se = StealingEvents(s) - futures = c.map(slowinc, range(100), delay=0.1, workers=a.address, - allow_other_workers=True) + futures = c.map( + slowinc, range(100), delay=0.1, workers=a.address, allow_other_workers=True + ) while not b.task_state: # will steal soon yield gen.sleep(0.01) @@ -105,17 +127,18 @@ def test_stealing_events(c, s, a, b): @gen_cluster(client=True) def test_events(c, s, a, b): - e = Events(s, 'all') + e = Events(s, "all") - futures = c.map(slowinc, range(100), delay=0.1, workers=a.address, - allow_other_workers=True) + futures = c.map( + slowinc, range(100), delay=0.1, workers=a.address, allow_other_workers=True + ) while not b.task_state: yield gen.sleep(0.01) e.update() d = dict(e.source.data) - assert sum(a == 'add-worker' for a in d['action']) == 2 + assert sum(a == "add-worker" for a in d["action"]) == 2 @gen_cluster(client=True) @@ -130,7 +153,7 @@ def test_task_stream(c, s, a, b): d = dict(ts.source.data) assert all(len(L) == 10 for L in d.values()) - assert min(d['start']) == 0 # zero based + assert min(d["start"]) == 0 # zero based ts.update() d = dict(ts.source.data) @@ -151,7 +174,7 @@ def test_task_stream_n_rectangles(c, s, a, b): yield wait(futures) ts.update() - assert len(ts.source.data['start']) == 10 + assert len(ts.source.data["start"]) == 10 @gen_cluster(client=True) @@ -177,16 +200,16 @@ def test_task_stream_clear_interval(c, s, a, b): ts.update() assert len(set(map(len, ts.source.data.values()))) == 1 - assert ts.source.data['name'].count('inc') == 10 - assert ts.source.data['name'].count('dec') == 10 + assert ts.source.data["name"].count("inc") == 10 + assert ts.source.data["name"].count("dec") == 10 yield gen.sleep(0.300) yield wait(c.map(inc, range(10, 20))) ts.update() assert len(set(map(len, ts.source.data.values()))) == 1 - assert ts.source.data['name'].count('inc') == 10 - assert ts.source.data['name'].count('dec') == 0 + assert ts.source.data["name"].count("inc") == 10 + assert ts.source.data["name"].count("dec") == 0 @gen_cluster(client=True) @@ -199,7 +222,7 @@ def test_TaskProgress(c, s, a, b): tp.update() d = dict(tp.source.data) assert all(len(L) == 1 for L in d.values()) - assert d['name'] == ['slowinc'] + assert d["name"] == ["slowinc"] futures2 = c.map(dec, range(5)) yield wait(futures2) @@ -207,7 +230,7 @@ def test_TaskProgress(c, s, a, b): tp.update() d = dict(tp.source.data) assert all(len(L) == 2 for L in d.values()) - assert d['name'] == ['slowinc', 'dec'] + assert d["name"] == ["slowinc", "dec"] del futures, futures2 @@ -215,7 +238,7 @@ def test_TaskProgress(c, s, a, b): yield gen.sleep(0.01) tp.update() - assert not tp.source.data['all'] + assert not tp.source.data["all"] @gen_cluster(client=True) @@ -223,7 +246,7 @@ def test_TaskProgress_empty(c, s, a, b): tp = TaskProgress(s) tp.update() - futures = [c.submit(inc, i, key='f-' + 'a' * i) for i in range(20)] + futures = [c.submit(inc, i, key="f-" + "a" * i) for i in range(20)] yield wait(futures) tp.update() @@ -245,7 +268,7 @@ def test_MemoryUse(c, s, a, b): mu.update() d = dict(mu.source.data) assert all(len(L) == 1 for L in d.values()) - assert d['name'] == ['slowinc'] + assert d["name"] == ["slowinc"] @gen_cluster(client=True) @@ -259,34 +282,34 @@ def test_CurrentLoad(c, s, a, b): 
d = dict(cl.source.data) assert all(len(L) == 2 for L in d.values()) - assert all(d['nbytes']) + assert all(d["nbytes"]) @gen_cluster(client=True) def test_ProcessingHistogram(c, s, a, b): ph = ProcessingHistogram(s) ph.update() - assert (ph.source.data['top'] != 0).sum() == 1 + assert (ph.source.data["top"] != 0).sum() == 1 futures = c.map(slowinc, range(10), delay=0.050) while not s.tasks: yield gen.sleep(0.01) ph.update() - assert ph.source.data['right'][-1] > 2 + assert ph.source.data["right"][-1] > 2 @gen_cluster(client=True) def test_NBytesHistogram(c, s, a, b): nh = NBytesHistogram(s) nh.update() - assert (nh.source.data['top'] != 0).sum() == 1 + assert (nh.source.data["top"] != 0).sum() == 1 futures = c.map(inc, range(10)) yield wait(futures) nh.update() - assert nh.source.data['right'][-1] > 5 * 20 + assert nh.source.data["right"][-1] > 5 * 20 @gen_cluster(client=True) @@ -296,7 +319,7 @@ def test_WorkerTable(c, s, a, b): assert all(wt.source.data.values()) assert all(len(v) == 2 for v in wt.source.data.values()) - ncores = wt.source.data['ncores'] + ncores = wt.source.data["ncores"] assert all(ncores) @@ -308,8 +331,7 @@ def metric_port(worker): def metric_address(worker): return worker.address - metrics = {'metric_port': metric_port, - 'metric_address': metric_address} + metrics = {"metric_port": metric_port, "metric_address": metric_address} for w in [a, b]: for name, func in metrics.items(): @@ -318,8 +340,8 @@ def metric_address(worker): yield [a.heartbeat(), b.heartbeat()] for w in [a, b]: - assert s.workers[w.address].metrics['metric_port'] == w.port - assert s.workers[w.address].metrics['metric_address'] == w.address + assert s.workers[w.address].metrics["metric_port"] == w.port + assert s.workers[w.address].metrics["metric_address"] == w.address wt = WorkerTable(s) wt.update() @@ -330,9 +352,9 @@ def metric_address(worker): assert all(data.values()) assert all(len(v) == 2 for v in data.values()) - my_index = data['worker'].index(a.address), data['worker'].index(b.address) - assert [data['metric_port'][i] for i in my_index] == [a.port, b.port] - assert [data['metric_address'][i] for i in my_index] == [a.address, b.address] + my_index = data["worker"].index(a.address), data["worker"].index(b.address) + assert [data["metric_port"][i] for i in my_index] == [a.port, b.port] + assert [data["metric_address"][i] for i in my_index] == [a.address, b.address] @gen_cluster(client=True) @@ -340,24 +362,24 @@ def test_WorkerTable_different_metrics(c, s, a, b): def metric_port(worker): return worker.port - a.metrics['metric_a'] = metric_port - b.metrics['metric_b'] = metric_port + a.metrics["metric_a"] = metric_port + b.metrics["metric_b"] = metric_port yield [a.heartbeat(), b.heartbeat()] - assert s.workers[a.address].metrics['metric_a'] == a.port - assert s.workers[b.address].metrics['metric_b'] == b.port + assert s.workers[a.address].metrics["metric_a"] == a.port + assert s.workers[b.address].metrics["metric_b"] == b.port wt = WorkerTable(s) wt.update() data = wt.source.data - assert 'metric_a' in data - assert 'metric_b' in data + assert "metric_a" in data + assert "metric_b" in data assert all(data.values()) assert all(len(v) == 2 for v in data.values()) - my_index = data['worker'].index(a.address), data['worker'].index(b.address) - assert [data['metric_a'][i] for i in my_index] == [a.port, None] - assert [data['metric_b'][i] for i in my_index] == [None, b.port] + my_index = data["worker"].index(a.address), data["worker"].index(b.address) + assert [data["metric_a"][i] for i in 
my_index] == [a.port, None] + assert [data["metric_b"][i] for i in my_index] == [None, b.port] @gen_cluster(client=True) @@ -365,51 +387,51 @@ def test_WorkerTable_metrics_with_different_metric_2(c, s, a, b): def metric_port(worker): return worker.port - a.metrics['metric_a'] = metric_port + a.metrics["metric_a"] = metric_port yield [a.heartbeat(), b.heartbeat()] wt = WorkerTable(s) wt.update() data = wt.source.data - assert 'metric_a' in data + assert "metric_a" in data assert all(data.values()) assert all(len(v) == 2 for v in data.values()) - my_index = data['worker'].index(a.address), data['worker'].index(b.address) - assert [data['metric_a'][i] for i in my_index] == [a.port, None] + my_index = data["worker"].index(a.address), data["worker"].index(b.address) + assert [data["metric_a"][i] for i in my_index] == [a.port, None] -@gen_cluster(client=True, worker_kwargs={'metrics': {'my_port': lambda w: w.port}}) +@gen_cluster(client=True, worker_kwargs={"metrics": {"my_port": lambda w: w.port}}) def test_WorkerTable_add_and_remove_metrics(c, s, a, b): def metric_port(worker): return worker.port - a.metrics['metric_a'] = metric_port - b.metrics['metric_b'] = metric_port + a.metrics["metric_a"] = metric_port + b.metrics["metric_b"] = metric_port yield [a.heartbeat(), b.heartbeat()] - assert s.workers[a.address].metrics['metric_a'] == a.port - assert s.workers[b.address].metrics['metric_b'] == b.port + assert s.workers[a.address].metrics["metric_a"] == a.port + assert s.workers[b.address].metrics["metric_b"] == b.port wt = WorkerTable(s) wt.update() - assert 'metric_a' in wt.source.data - assert 'metric_b' in wt.source.data + assert "metric_a" in wt.source.data + assert "metric_b" in wt.source.data # Remove 'metric_b' from worker b - del b.metrics['metric_b'] + del b.metrics["metric_b"] yield [a.heartbeat(), b.heartbeat()] wt = WorkerTable(s) wt.update() - assert 'metric_a' in wt.source.data + assert "metric_a" in wt.source.data - del a.metrics['metric_a'] + del a.metrics["metric_a"] yield [a.heartbeat(), b.heartbeat()] wt = WorkerTable(s) wt.update() - assert 'metric_a' not in wt.source.data + assert "metric_a" not in wt.source.data @gen_cluster(client=True) @@ -417,14 +439,14 @@ def test_WorkerTable_custom_metric_overlap_with_core_metric(c, s, a, b): def metric(worker): return -999 - a.metrics['executing'] = metric - a.metrics['cpu'] = metric - a.metrics['metric'] = metric + a.metrics["executing"] = metric + a.metrics["cpu"] = metric + a.metrics["metric"] = metric yield [a.heartbeat(), b.heartbeat()] - assert s.workers[a.address].metrics['executing'] != -999 - assert s.workers[a.address].metrics['cpu'] != -999 - assert s.workers[a.address].metrics['metric'] == -999 + assert s.workers[a.address].metrics["executing"] != -999 + assert s.workers[a.address].metrics["cpu"] != -999 + assert s.workers[a.address].metrics["metric"] == -999 @gen_cluster(client=True) @@ -438,7 +460,7 @@ def test_GraphPlot(c, s, a, b): assert set(map(len, gp.node_source.data.values())) == {6} assert set(map(len, gp.edge_source.data.values())) == {5} - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random((20, 20), chunks=(10, 10)).persist() y = (x + x.T) - x.mean(axis=0) y = y.persist() @@ -459,12 +481,12 @@ def test_GraphPlot(c, s, a, b): while key in s.tasks: yield gen.sleep(0.01) - assert 'memory' in gp.node_source.data['state'] + assert "memory" in gp.node_source.data["state"] gp.update() gp.update() - assert not all(x == 'False' for x in gp.edge_source.data['visible']) + 
assert not all(x == "False" for x in gp.edge_source.data["visible"]) @gen_cluster(client=True) @@ -493,31 +515,33 @@ def test_GraphPlot_clear(c, s, a, b): @gen_cluster(client=True, timeout=30) def test_GraphPlot_complex(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") gp = GraphPlot(s) x = da.random.random((2000, 2000), chunks=(1000, 1000)) y = ((x + x.T) - x.mean(axis=0)).persist() yield wait(y) gp.update() - assert len(gp.layout.index) == len(gp.node_source.data['x']) + assert len(gp.layout.index) == len(gp.node_source.data["x"]) assert len(gp.layout.index) == len(s.tasks) z = (x - y).sum().persist() yield wait(z) gp.update() - assert len(gp.layout.index) == len(gp.node_source.data['x']) + assert len(gp.layout.index) == len(gp.node_source.data["x"]) assert len(gp.layout.index) == len(s.tasks) del z yield gen.sleep(0.2) gp.update() - assert len(gp.layout.index) == sum(v == 'True' for v in gp.node_source.data['visible']) + assert len(gp.layout.index) == sum( + v == "True" for v in gp.node_source.data["visible"] + ) assert len(gp.layout.index) == len(s.tasks) - assert max(gp.layout.index.values()) < len(gp.node_source.data['visible']) - assert gp.layout.next_index == len(gp.node_source.data['visible']) + assert max(gp.layout.index.values()) < len(gp.node_source.data["visible"]) + assert gp.layout.next_index == len(gp.node_source.data["visible"]) gp.update() assert set(gp.layout.index.values()) == set(range(len(gp.layout.index))) - visible = gp.node_source.data['visible'] + visible = gp.node_source.data["visible"] keys = list(map(tokey, flatten(y.__dask_keys__()))) - assert all(visible[gp.layout.index[key]] == 'True' for key in keys) + assert all(visible[gp.layout.index[key]] == "True" for key in keys) @gen_cluster(client=True) @@ -529,24 +553,29 @@ def test_GraphPlot_order(c, s, a, b): gp = GraphPlot(s) gp.update() - assert gp.node_source.data['state'][gp.layout.index[y.key]] == 'erred' + assert gp.node_source.data["state"][gp.layout.index[y.key]] == "erred" -@gen_cluster(client=True, - config={'distributed.worker.profile.interval': '10ms', - 'distributed.worker.profile.cycle': '50ms'}) +@gen_cluster( + client=True, + config={ + "distributed.worker.profile.interval": "10ms", + "distributed.worker.profile.cycle": "50ms", + }, +) def test_profile_server(c, s, a, b): ptp = ProfileServer(s) ptp.trigger_update() yield gen.sleep(0.200) ptp.trigger_update() - assert 2 < len(ptp.ts_source.data['time']) < 20 + assert 2 < len(ptp.ts_source.data["time"]) < 20 -@gen_cluster(client=True, - scheduler_kwargs={'services': {('bokeh', 0): BokehScheduler}}) +@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) def test_root_redirect(c, s, a, b): http_client = AsyncHTTPClient() - response = yield http_client.fetch('http://localhost:%d/' % s.services['bokeh'].port) + response = yield http_client.fetch( + "http://localhost:%d/" % s.services["bokeh"].port + ) assert response.code == 200 assert "/status" in response.effective_url diff --git a/distributed/bokeh/tests/test_scheduler_bokeh_html.py b/distributed/bokeh/tests/test_scheduler_bokeh_html.py index 96dc71c6e67..d5ca1ee7f05 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh_html.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh_html.py @@ -5,7 +5,8 @@ import xml.etree.ElementTree import pytest -pytest.importorskip('bokeh') + +pytest.importorskip("bokeh") from tornado.escape import url_escape from tornado.httpclient import AsyncHTTPClient @@ -14,59 +15,62 @@ from 
distributed.bokeh.scheduler import BokehScheduler -@gen_cluster(client=True, - scheduler_kwargs={'services': {('bokeh', 0): BokehScheduler}}) +@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) def test_connect(c, s, a, b): future = c.submit(lambda x: x + 1, 1) x = c.submit(slowinc, 1, delay=1, retries=5) yield future http_client = AsyncHTTPClient() - for suffix in ['info/main/workers.html', - 'info/worker/' + url_escape(a.address) + '.html', - 'info/task/' + url_escape(future.key) + '.html', - 'info/main/logs.html', - 'info/logs/' + url_escape(a.address) + '.html', - 'info/call-stack/' + url_escape(x.key) + '.html', - 'info/call-stacks/' + url_escape(a.address) + '.html', - 'json/counts.json', - 'json/identity.json', - 'json/index.html', - 'individual-plots.json', - ]: - response = yield http_client.fetch('http://localhost:%d/%s' - % (s.services['bokeh'].port, suffix)) + for suffix in [ + "info/main/workers.html", + "info/worker/" + url_escape(a.address) + ".html", + "info/task/" + url_escape(future.key) + ".html", + "info/main/logs.html", + "info/logs/" + url_escape(a.address) + ".html", + "info/call-stack/" + url_escape(x.key) + ".html", + "info/call-stacks/" + url_escape(a.address) + ".html", + "json/counts.json", + "json/identity.json", + "json/index.html", + "individual-plots.json", + ]: + response = yield http_client.fetch( + "http://localhost:%d/%s" % (s.services["bokeh"].port, suffix) + ) assert response.code == 200 body = response.body.decode() - if suffix.endswith('.json'): + if suffix.endswith(".json"): json.loads(body) else: assert xml.etree.ElementTree.fromstring(body) is not None assert not re.search("href=./", body) # no absolute links -@gen_cluster(client=True, - scheduler_kwargs={'services': {('bokeh', 0): (BokehScheduler, - {'prefix': '/foo'})}}) +@gen_cluster( + client=True, + scheduler_kwargs={"services": {("bokeh", 0): (BokehScheduler, {"prefix": "/foo"})}}, +) def test_prefix(c, s, a, b): http_client = AsyncHTTPClient() - for suffix in ['foo/info/main/workers.html', - 'foo/json/index.html', - 'foo/system']: - response = yield http_client.fetch('http://localhost:%d/%s' - % (s.services['bokeh'].port, suffix)) + for suffix in ["foo/info/main/workers.html", "foo/json/index.html", "foo/system"]: + response = yield http_client.fetch( + "http://localhost:%d/%s" % (s.services["bokeh"].port, suffix) + ) assert response.code == 200 body = response.body.decode() - if suffix.endswith('.json'): + if suffix.endswith(".json"): json.loads(body) else: assert xml.etree.ElementTree.fromstring(body) is not None -@gen_cluster(client=True, - check_new_threads=False, - scheduler_kwargs={'services': {('bokeh', 0): BokehScheduler}}) +@gen_cluster( + client=True, + check_new_threads=False, + scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, +) def test_prometheus(c, s, a, b): - pytest.importorskip('prometheus_client') + pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families http_client = AsyncHTTPClient() @@ -74,29 +78,30 @@ def test_prometheus(c, s, a, b): # request data twice since there once was a case where metrics got registered multiple times resulting in # prometheus_client errors for _ in range(2): - response = yield http_client.fetch('http://localhost:%d/metrics' - % s.services['bokeh'].port) + response = yield http_client.fetch( + "http://localhost:%d/metrics" % s.services["bokeh"].port + ) assert response.code == 200 - assert response.headers['Content-Type'] == 'text/plain; 
version=0.0.4' + assert response.headers["Content-Type"] == "text/plain; version=0.0.4" - txt = response.body.decode('utf8') - families = { - familiy.name - for familiy in text_string_to_metric_families(txt) - } - assert 'dask_scheduler_workers' in families + txt = response.body.decode("utf8") + families = {familiy.name for familiy in text_string_to_metric_families(txt)} + assert "dask_scheduler_workers" in families -@gen_cluster(client=True, +@gen_cluster( + client=True, check_new_threads=False, - scheduler_kwargs={'services': {('bokeh', 0): BokehScheduler}}) + scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, +) def test_health(c, s, a, b): http_client = AsyncHTTPClient() - response = yield http_client.fetch('http://localhost:%d/health' - % s.services['bokeh'].port) + response = yield http_client.fetch( + "http://localhost:%d/health" % s.services["bokeh"].port + ) assert response.code == 200 - assert response.headers['Content-Type'] == 'text/plain' + assert response.headers["Content-Type"] == "text/plain" - txt = response.body.decode('utf8') - assert txt == 'ok' + txt = response.body.decode("utf8") + assert txt == "ok" diff --git a/distributed/bokeh/tests/test_worker_bokeh.py b/distributed/bokeh/tests/test_worker_bokeh.py index 32c14d4fa50..03a7ed3861b 100644 --- a/distributed/bokeh/tests/test_worker_bokeh.py +++ b/distributed/bokeh/tests/test_worker_bokeh.py @@ -4,7 +4,8 @@ from time import sleep import pytest -pytest.importorskip('bokeh') + +pytest.importorskip("bokeh") import sys from toolz import first from tornado import gen @@ -13,40 +14,52 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec -from distributed.bokeh.worker import (BokehWorker, StateTable, CrossFilter, - CommunicatingStream, ExecutingTimeSeries, CommunicatingTimeSeries, - SystemMonitor, Counters) - - -@pytest.mark.skipif(sys.version_info[0] == 2, - reason='https://github.com/bokeh/bokeh/issues/5494') -@gen_cluster(client=True, - worker_kwargs={'services': {('bokeh', 0): BokehWorker}}) +from distributed.bokeh.worker import ( + BokehWorker, + StateTable, + CrossFilter, + CommunicatingStream, + ExecutingTimeSeries, + CommunicatingTimeSeries, + SystemMonitor, + Counters, +) + + +@pytest.mark.skipif( + sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" +) +@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): BokehWorker}}) def test_simple(c, s, a, b): - assert s.workers[a.address].services == {'bokeh': a.services['bokeh'].port} - assert s.workers[b.address].services == {'bokeh': b.services['bokeh'].port} + assert s.workers[a.address].services == {"bokeh": a.services["bokeh"].port} + assert s.workers[b.address].services == {"bokeh": b.services["bokeh"].port} future = c.submit(sleep, 1) yield gen.sleep(0.1) http_client = AsyncHTTPClient() - for suffix in ['main', 'crossfilter', 'system']: - response = yield http_client.fetch('http://localhost:%d/%s' - % (a.services['bokeh'].port, suffix)) - assert 'bokeh' in response.body.decode().lower() + for suffix in ["main", "crossfilter", "system"]: + response = yield http_client.fetch( + "http://localhost:%d/%s" % (a.services["bokeh"].port, suffix) + ) + assert "bokeh" in response.body.decode().lower() -@gen_cluster(client=True, - worker_kwargs={'services': {('bokeh', 0): (BokehWorker, {})}}) +@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): (BokehWorker, {})}}) def test_services_kwargs(c, s, a, b): - assert 
s.workers[a.address].services == {'bokeh': a.services['bokeh'].port} - assert isinstance(a.services['bokeh'], BokehWorker) + assert s.workers[a.address].services == {"bokeh": a.services["bokeh"].port} + assert isinstance(a.services["bokeh"], BokehWorker) @gen_cluster(client=True) def test_basic(c, s, a, b): - for component in [StateTable, ExecutingTimeSeries, - CommunicatingTimeSeries, CrossFilter, SystemMonitor]: + for component in [ + StateTable, + ExecutingTimeSeries, + CommunicatingTimeSeries, + CrossFilter, + SystemMonitor, + ]: aa = component(a) bb = component(b) @@ -64,14 +77,15 @@ def slowall(*args): aa.update() bb.update() - assert (len(first(aa.source.data.values())) and - len(first(bb.source.data.values()))) + assert len(first(aa.source.data.values())) and len( + first(bb.source.data.values()) + ) @gen_cluster(client=True) def test_counters(c, s, a, b): - pytest.importorskip('crick') - while 'tick-duration' not in a.digests: + pytest.importorskip("crick") + while "tick-duration" not in a.digests: yield gen.sleep(0.01) aa = Counters(a) @@ -80,18 +94,18 @@ def test_counters(c, s, a, b): aa.update() start = time() - while not len(aa.digest_sources['tick-duration'][0].data['x']): + while not len(aa.digest_sources["tick-duration"][0].data["x"]): yield gen.sleep(1) assert time() < start + 5 - a.digests['foo'].add(1) - a.digests['foo'].add(2) - aa.add_digest_figure('foo') + a.digests["foo"].add(1) + a.digests["foo"].add(2) + aa.add_digest_figure("foo") - a.counters['bar'].add(1) - a.counters['bar'].add(2) - a.counters['bar'].add(2) - aa.add_counter_figure('bar') + a.counters["bar"].add(1) + a.counters["bar"].add(2) + a.counters["bar"].add(2) + aa.add_counter_figure("bar") for x in [aa.counter_sources.values(), aa.digest_sources.values()]: for y in x: @@ -114,21 +128,26 @@ def test_CommunicatingStream(c, s, a, b): aa.update() bb.update() - assert (len(first(aa.outgoing.data.values())) and - len(first(bb.outgoing.data.values()))) - assert (len(first(aa.incoming.data.values())) and - len(first(bb.incoming.data.values()))) + assert len(first(aa.outgoing.data.values())) and len( + first(bb.outgoing.data.values()) + ) + assert len(first(aa.incoming.data.values())) and len( + first(bb.incoming.data.values()) + ) -@gen_cluster(client=True, - check_new_threads=False, - worker_kwargs={'services': {('bokeh', 0): BokehWorker}}) +@gen_cluster( + client=True, + check_new_threads=False, + worker_kwargs={"services": {("bokeh", 0): BokehWorker}}, +) def test_prometheus(c, s, a, b): - pytest.importorskip('prometheus_client') - assert s.workers[a.address].services == {'bokeh': a.services['bokeh'].port} + pytest.importorskip("prometheus_client") + assert s.workers[a.address].services == {"bokeh": a.services["bokeh"].port} http_client = AsyncHTTPClient() - for suffix in ['metrics']: - response = yield http_client.fetch('http://localhost:%d/%s' - % (a.services['bokeh'].port, suffix)) + for suffix in ["metrics"]: + response = yield http_client.fetch( + "http://localhost:%d/%s" % (a.services["bokeh"].port, suffix) + ) assert response.code == 200 diff --git a/distributed/bokeh/tests/test_worker_bokeh_html.py b/distributed/bokeh/tests/test_worker_bokeh_html.py index 80819972050..d59fec8d2d8 100644 --- a/distributed/bokeh/tests/test_worker_bokeh_html.py +++ b/distributed/bokeh/tests/test_worker_bokeh_html.py @@ -1,15 +1,15 @@ import pytest -pytest.importorskip('bokeh') + +pytest.importorskip("bokeh") from tornado.httpclient import AsyncHTTPClient from distributed.utils_test import gen_cluster from 
distributed.bokeh.worker import BokehWorker -@gen_cluster(client=True, - worker_kwargs={'services': {('bokeh', 0): BokehWorker}}) +@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): BokehWorker}}) def test_prometheus(c, s, a, b): - pytest.importorskip('prometheus_client') + pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families http_client = AsyncHTTPClient() @@ -17,28 +17,26 @@ def test_prometheus(c, s, a, b): # request data twice since there once was a case where metrics got registered multiple times resulting in # prometheus_client errors for _ in range(2): - response = yield http_client.fetch('http://localhost:%d/metrics' - % a.services['bokeh'].port) + response = yield http_client.fetch( + "http://localhost:%d/metrics" % a.services["bokeh"].port + ) assert response.code == 200 - assert response.headers['Content-Type'] == 'text/plain; version=0.0.4' + assert response.headers["Content-Type"] == "text/plain; version=0.0.4" - txt = response.body.decode('utf8') - families = { - familiy.name - for familiy in text_string_to_metric_families(txt) - } + txt = response.body.decode("utf8") + families = {familiy.name for familiy in text_string_to_metric_families(txt)} assert len(families) > 0 -@gen_cluster(client=True, - worker_kwargs={'services': {('bokeh', 0): BokehWorker}}) +@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): BokehWorker}}) def test_health(c, s, a, b): http_client = AsyncHTTPClient() - response = yield http_client.fetch('http://localhost:%d/health' - % a.services['bokeh'].port) + response = yield http_client.fetch( + "http://localhost:%d/health" % a.services["bokeh"].port + ) assert response.code == 200 - assert response.headers['Content-Type'] == 'text/plain' + assert response.headers["Content-Type"] == "text/plain" - txt = response.body.decode('utf8') - assert txt == 'ok' + txt = response.body.decode("utf8") + assert txt == "ok" diff --git a/distributed/bokeh/utils.py b/distributed/bokeh/utils.py index 3bfada9402e..516ca5bfb88 100644 --- a/distributed/bokeh/utils.py +++ b/distributed/bokeh/utils.py @@ -10,11 +10,12 @@ BOKEH_VERSION = LooseVersion(bokeh.__version__) -if BOKEH_VERSION >= '1.0.0' and not PY2: +if BOKEH_VERSION >= "1.0.0" and not PY2: # This decorator is only available in bokeh >= 1.0.0, and doesn't work for # callbacks in Python 2, since the signature introspection won't line up. 
from bokeh.core.properties import without_property_validation else: + def without_property_validation(f): return f diff --git a/distributed/bokeh/worker.py b/distributed/bokeh/worker.py index ef6c27e0404..c7ced4d90fc 100644 --- a/distributed/bokeh/worker.py +++ b/distributed/bokeh/worker.py @@ -6,9 +6,17 @@ import os from bokeh.layouts import row, column, widgetbox -from bokeh.models import (ColumnDataSource, DataRange1d, HoverTool, - BoxZoomTool, ResetTool, PanTool, WheelZoomTool, NumeralTickFormatter, - Select) +from bokeh.models import ( + ColumnDataSource, + DataRange1d, + HoverTool, + BoxZoomTool, + ResetTool, + PanTool, + WheelZoomTool, + NumeralTickFormatter, + Select, +) from bokeh.models.widgets import DataTable, TableColumn from bokeh.plotting import figure @@ -16,27 +24,34 @@ from bokeh.themes import Theme from toolz import merge, partition_all -from .components import (DashboardComponent, ProfileTimePlot, ProfileServer, - add_periodic_callback) +from .components import ( + DashboardComponent, + ProfileTimePlot, + ProfileServer, + add_periodic_callback, +) from .core import BokehServer from .utils import transpose, without_property_validation from ..compatibility import WINDOWS from ..diagnostics.progress_stream import color_of from ..metrics import time -from ..utils import (log_errors, key_split, format_bytes, format_time) +from ..utils import log_errors, key_split, format_bytes, format_time logger = logging.getLogger(__name__) -with open(os.path.join(os.path.dirname(__file__), 'templates', 'base.html')) as f: +with open(os.path.join(os.path.dirname(__file__), "templates", "base.html")) as f: template_source = f.read() from jinja2 import Environment, FileSystemLoader -env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), 'templates'))) -BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), 'theme.yaml')) +env = Environment( + loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "templates")) +) -template_variables = {'pages': ['main', 'system', 'profile', 'crossfilter']} +BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "theme.yaml")) + +template_variables = {"pages": ["main", "system", "profile", "crossfilter"]} class StateTable(DashboardComponent): @@ -45,15 +60,13 @@ class StateTable(DashboardComponent): def __init__(self, worker): self.worker = worker - names = ['Stored', 'Executing', 'Ready', 'Waiting', 'Connections', 'Serving'] + names = ["Stored", "Executing", "Ready", "Waiting", "Connections", "Serving"] self.source = ColumnDataSource({name: [] for name in names}) - columns = {name: TableColumn(field=name, title=name) - for name in names} + columns = {name: TableColumn(field=name, title=name) for name in names} table = DataTable( - source=self.source, columns=[columns[n] for n in names], - height=70, + source=self.source, columns=[columns[n] for n in names], height=70 ) self.root = table @@ -61,12 +74,14 @@ def __init__(self, worker): def update(self): with log_errors(): w = self.worker - d = {'Stored': [len(w.data)], - 'Executing': ['%d / %d' % (len(w.executing), w.ncores)], - 'Ready': [len(w.ready)], - 'Waiting': [len(w.waiting_for_data)], - 'Connections': [len(w.in_flight_workers)], - 'Serving': [len(w._comms)]} + d = { + "Stored": [len(w.data)], + "Executing": ["%d / %d" % (len(w.executing), w.ncores)], + "Ready": [len(w.ready)], + "Waiting": [len(w.waiting_for_data)], + "Connections": [len(w.in_flight_workers)], + "Serving": [len(w._comms)], + } self.source.data.update(d) @@ -74,8 +89,18 @@ class 
CommunicatingStream(DashboardComponent): def __init__(self, worker, height=300, **kwargs): with log_errors(): self.worker = worker - names = ['start', 'stop', 'middle', 'duration', 'who', 'y', - 'hover', 'alpha', 'bandwidth', 'total'] + names = [ + "start", + "stop", + "middle", + "duration", + "who", + "y", + "hover", + "alpha", + "bandwidth", + "total", + ] self.incoming = ColumnDataSource({name: [] for name in names}) self.outgoing = ColumnDataSource({name: [] for name in names}) @@ -83,24 +108,41 @@ def __init__(self, worker, height=300, **kwargs): x_range = DataRange1d(range_padding=0) y_range = DataRange1d(range_padding=0) - fig = figure(title='Peer Communications', - x_axis_type='datetime', x_range=x_range, y_range=y_range, - height=height, tools='', **kwargs) - - fig.rect(source=self.incoming, x='middle', y='y', width='duration', - height=0.9, color='red', alpha='alpha') - fig.rect(source=self.outgoing, x='middle', y='y', width='duration', - height=0.9, color='blue', alpha='alpha') + fig = figure( + title="Peer Communications", + x_axis_type="datetime", + x_range=x_range, + y_range=y_range, + height=height, + tools="", + **kwargs + ) - hover = HoverTool( - point_policy="follow_mouse", - tooltips="""@hover""" + fig.rect( + source=self.incoming, + x="middle", + y="y", + width="duration", + height=0.9, + color="red", + alpha="alpha", + ) + fig.rect( + source=self.outgoing, + x="middle", + y="y", + width="duration", + height=0.9, + color="blue", + alpha="alpha", ) + + hover = HoverTool(point_policy="follow_mouse", tooltips="""@hover""") fig.add_tools( hover, ResetTool(), PanTool(dimensions="width"), - WheelZoomTool(dimensions="width") + WheelZoomTool(dimensions="width"), ) self.root = fig @@ -122,35 +164,40 @@ def update(self): incoming = [incoming[-i].copy() for i in range(1, n + 1)] self.last_incoming = self.worker.incoming_count - for [msgs, source] in [[incoming, self.incoming], - [outgoing, self.outgoing]]: + for [msgs, source] in [ + [incoming, self.incoming], + [outgoing, self.outgoing], + ]: for msg in msgs: - if 'compressed' in msg: - del msg['compressed'] - del msg['keys'] + if "compressed" in msg: + del msg["compressed"] + del msg["keys"] - bandwidth = msg['total'] / (msg['duration'] or 0.5) + bandwidth = msg["total"] / (msg["duration"] or 0.5) bw = max(min(bandwidth / 500e6, 1), 0.3) - msg['alpha'] = bw + msg["alpha"] = bw try: - msg['y'] = self.who[msg['who']] + msg["y"] = self.who[msg["who"]] except KeyError: - self.who[msg['who']] = len(self.who) - msg['y'] = self.who[msg['who']] + self.who[msg["who"]] = len(self.who) + msg["y"] = self.who[msg["who"]] - msg['hover'] = '%s / %s = %s/s' % ( - format_bytes(msg['total']), - format_time(msg['duration']), - format_bytes(msg['total'] / msg['duration'])) + msg["hover"] = "%s / %s = %s/s" % ( + format_bytes(msg["total"]), + format_time(msg["duration"]), + format_bytes(msg["total"] / msg["duration"]), + ) - for k in ['middle', 'duration', 'start', 'stop']: + for k in ["middle", "duration", "start", "stop"]: msg[k] = msg[k] * 1000 if msgs: msgs = transpose(msgs) - if (len(source.data['stop']) and - min(msgs['start']) > source.data['stop'][-1] + 10000): + if ( + len(source.data["stop"]) + and min(msgs["start"]) > source.data["stop"][-1] + 10000 + ): source.data.update(msgs) else: source.stream(msgs, rollover=10000) @@ -159,21 +206,24 @@ def update(self): class CommunicatingTimeSeries(DashboardComponent): def __init__(self, worker, **kwargs): self.worker = worker - self.source = ColumnDataSource({'x': [], 'in': [], 'out': []}) - - 
x_range = DataRange1d(follow='end', follow_interval=20000, range_padding=0) - - fig = figure(title="Communication History", - x_axis_type='datetime', - y_range=[-0.1, worker.total_out_connections + 0.5], - height=150, tools='', x_range=x_range, **kwargs) - fig.line(source=self.source, x='x', y='in', color='red') - fig.line(source=self.source, x='x', y='out', color='blue') + self.source = ColumnDataSource({"x": [], "in": [], "out": []}) + + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + + fig = figure( + title="Communication History", + x_axis_type="datetime", + y_range=[-0.1, worker.total_out_connections + 0.5], + height=150, + tools="", + x_range=x_range, + **kwargs + ) + fig.line(source=self.source, x="x", y="in", color="red") + fig.line(source=self.source, x="x", y="out", color="blue") fig.add_tools( - ResetTool(), - PanTool(dimensions="width"), - WheelZoomTool(dimensions="width") + ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") ) self.root = fig @@ -181,28 +231,36 @@ def __init__(self, worker, **kwargs): @without_property_validation def update(self): with log_errors(): - self.source.stream({'x': [time() * 1000], - 'out': [len(self.worker._comms)], - 'in': [len(self.worker.in_flight_workers)]}, - 10000) + self.source.stream( + { + "x": [time() * 1000], + "out": [len(self.worker._comms)], + "in": [len(self.worker.in_flight_workers)], + }, + 10000, + ) class ExecutingTimeSeries(DashboardComponent): def __init__(self, worker, **kwargs): self.worker = worker - self.source = ColumnDataSource({'x': [], 'y': []}) - - x_range = DataRange1d(follow='end', follow_interval=20000, range_padding=0) - - fig = figure(title="Executing History", - x_axis_type='datetime', y_range=[-0.1, worker.ncores + 0.1], - height=150, tools='', x_range=x_range, **kwargs) - fig.line(source=self.source, x='x', y='y') + self.source = ColumnDataSource({"x": [], "y": []}) + + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + + fig = figure( + title="Executing History", + x_axis_type="datetime", + y_range=[-0.1, worker.ncores + 0.1], + height=150, + tools="", + x_range=x_range, + **kwargs + ) + fig.line(source=self.source, x="x", y="y") fig.add_tools( - ResetTool(), - PanTool(dimensions="width"), - WheelZoomTool(dimensions="width") + ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") ) self.root = fig @@ -210,8 +268,9 @@ def __init__(self, worker, **kwargs): @without_property_validation def update(self): with log_errors(): - self.source.stream({'x': [time() * 1000], - 'y': [len(self.worker.executing)]}, 1000) + self.source.stream( + {"x": [time() * 1000], "y": [len(self.worker.executing)]}, 1000 + ) class CrossFilter(DashboardComponent): @@ -219,52 +278,55 @@ def __init__(self, worker, **kwargs): with log_errors(): self.worker = worker - quantities = ['nbytes', 'duration', 'bandwidth', 'count', - 'start', 'stop'] - colors = ['inout-color', 'type-color', 'key-color'] + quantities = ["nbytes", "duration", "bandwidth", "count", "start", "stop"] + colors = ["inout-color", "type-color", "key-color"] # self.source = ColumnDataSource({name: [] for name in names}) - self.source = ColumnDataSource({ - 'nbytes': [1, 2], - 'duration': [0.01, 0.02], - 'bandwidth': [0.01, 0.02], - 'count': [1, 2], - 'type': ['int', 'str'], - 'inout-color': ['blue', 'red'], - 'type-color': ['blue', 'red'], - 'key': ['add', 'inc'], - 'start': [1, 2], - 'stop': [1, 2] - }) - - self.x = Select(title='X-Axis', value='nbytes', options=quantities) - 
self.x.on_change('value', self.update_figure) - - self.y = Select(title='Y-Axis', value='bandwidth', options=quantities) - self.y.on_change('value', self.update_figure) - - self.size = Select(title='Size', value='None', - options=['None'] + quantities) - self.size.on_change('value', self.update_figure) - - self.color = Select(title='Color', value='inout-color', - options=['black'] + colors) - self.color.on_change('value', self.update_figure) - - if 'sizing_mode' in kwargs: - kw = {'sizing_mode': kwargs['sizing_mode']} + self.source = ColumnDataSource( + { + "nbytes": [1, 2], + "duration": [0.01, 0.02], + "bandwidth": [0.01, 0.02], + "count": [1, 2], + "type": ["int", "str"], + "inout-color": ["blue", "red"], + "type-color": ["blue", "red"], + "key": ["add", "inc"], + "start": [1, 2], + "stop": [1, 2], + } + ) + + self.x = Select(title="X-Axis", value="nbytes", options=quantities) + self.x.on_change("value", self.update_figure) + + self.y = Select(title="Y-Axis", value="bandwidth", options=quantities) + self.y.on_change("value", self.update_figure) + + self.size = Select( + title="Size", value="None", options=["None"] + quantities + ) + self.size.on_change("value", self.update_figure) + + self.color = Select( + title="Color", value="inout-color", options=["black"] + colors + ) + self.color.on_change("value", self.update_figure) + + if "sizing_mode" in kwargs: + kw = {"sizing_mode": kwargs["sizing_mode"]} else: kw = {} - self.control = widgetbox([self.x, self.y, self.size, self.color], - width=200, **kw) + self.control = widgetbox( + [self.x, self.y, self.size, self.color], width=200, **kw + ) self.last_outgoing = 0 self.last_incoming = 0 self.kwargs = kwargs - self.layout = row(self.control, self.create_figure(**self.kwargs), - **kw) + self.layout = row(self.control, self.create_figure(**self.kwargs), **kw) self.root = self.layout @@ -286,36 +348,44 @@ def update(self): out = [] for msg in incoming: - if msg['keys']: + if msg["keys"]: d = self.process_msg(msg) - d['inout-color'] = 'red' + d["inout-color"] = "red" out.append(d) for msg in outgoing: - if msg['keys']: + if msg["keys"]: d = self.process_msg(msg) - d['inout-color'] = 'blue' + d["inout-color"] = "blue" out.append(d) if out: out = transpose(out) - if (len(self.source.data['stop']) and - min(out['start']) > self.source.data['stop'][-1] + 10): + if ( + len(self.source.data["stop"]) + and min(out["start"]) > self.source.data["stop"][-1] + 10 + ): self.source.data.update(out) else: self.source.stream(out, rollover=1000) def create_figure(self, **kwargs): with log_errors(): - fig = figure(title='', tools='', **kwargs) + fig = figure(title="", tools="", **kwargs) size = self.size.value - if size == 'None': + if size == "None": size = 1 - fig.circle(source=self.source, x=self.x.value, y=self.y.value, - color=self.color.value, size=10, alpha=0.5, - hover_alpha=1) + fig.circle( + source=self.source, + x=self.x.value, + y=self.y.value, + color=self.color.value, + size=10, + alpha=0.5, + hover_alpha=1, + ) fig.xaxis.axis_label = self.x.value fig.yaxis.axis_label = self.y.value @@ -336,22 +406,24 @@ def update_figure(self, attr, old, new): def process_msg(self, msg): try: + def func(k): - return msg['keys'].get(k, 0) - main_key = max(msg['keys'], key=func) + return msg["keys"].get(k, 0) + + main_key = max(msg["keys"], key=func) typ = self.worker.types.get(main_key, object).__name__ keyname = key_split(main_key) d = { - 'nbytes': msg['total'], - 'duration': msg['duration'], - 'bandwidth': msg['bandwidth'], - 'count': len(msg['keys']), - 
'type': typ, - 'type-color': color_of(typ), - 'key': keyname, - 'key-color': color_of(keyname), - 'start': msg['start'], - 'stop': msg['stop'] + "nbytes": msg["total"], + "duration": msg["duration"], + "bandwidth": msg["bandwidth"], + "count": len(msg["keys"]), + "type": typ, + "type-color": color_of(typ), + "key": keyname, + "key-color": color_of(keyname), + "start": msg["start"], + "stop": msg["stop"], } return d except Exception as e: @@ -368,44 +440,63 @@ def __init__(self, worker, height=150, **kwargs): self.source = ColumnDataSource({name: [] for name in names}) self.source.data.update(self.get_data()) - x_range = DataRange1d(follow='end', follow_interval=20000, - range_padding=0) - - tools = 'reset,xpan,xwheel_zoom' - - self.cpu = figure(title="CPU", x_axis_type='datetime', - height=height, tools=tools, x_range=x_range, **kwargs) - self.cpu.line(source=self.source, x='time', y='cpu') - self.cpu.yaxis.axis_label = 'Percentage' - self.mem = figure(title="Memory", x_axis_type='datetime', - height=height, tools=tools, x_range=x_range, **kwargs) - self.mem.line(source=self.source, x='time', y='memory') - self.mem.yaxis.axis_label = 'Bytes' - self.bandwidth = figure(title='Bandwidth', x_axis_type='datetime', - height=height, - x_range=x_range, tools=tools, **kwargs) - self.bandwidth.line(source=self.source, x='time', y='read_bytes', - color='red') - self.bandwidth.line(source=self.source, x='time', y='write_bytes', - color='blue') - self.bandwidth.yaxis.axis_label = 'Bytes / second' + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + + tools = "reset,xpan,xwheel_zoom" + + self.cpu = figure( + title="CPU", + x_axis_type="datetime", + height=height, + tools=tools, + x_range=x_range, + **kwargs + ) + self.cpu.line(source=self.source, x="time", y="cpu") + self.cpu.yaxis.axis_label = "Percentage" + self.mem = figure( + title="Memory", + x_axis_type="datetime", + height=height, + tools=tools, + x_range=x_range, + **kwargs + ) + self.mem.line(source=self.source, x="time", y="memory") + self.mem.yaxis.axis_label = "Bytes" + self.bandwidth = figure( + title="Bandwidth", + x_axis_type="datetime", + height=height, + x_range=x_range, + tools=tools, + **kwargs + ) + self.bandwidth.line(source=self.source, x="time", y="read_bytes", color="red") + self.bandwidth.line(source=self.source, x="time", y="write_bytes", color="blue") + self.bandwidth.yaxis.axis_label = "Bytes / second" # self.cpu.yaxis[0].formatter = NumeralTickFormatter(format='0%') - self.bandwidth.yaxis[0].formatter = NumeralTickFormatter(format='0.0b') - self.mem.yaxis[0].formatter = NumeralTickFormatter(format='0.0b') + self.bandwidth.yaxis[0].formatter = NumeralTickFormatter(format="0.0b") + self.mem.yaxis[0].formatter = NumeralTickFormatter(format="0.0b") plots = [self.cpu, self.mem, self.bandwidth] if not WINDOWS: - self.num_fds = figure(title='Number of File Descriptors', - x_axis_type='datetime', height=height, - x_range=x_range, tools=tools, **kwargs) + self.num_fds = figure( + title="Number of File Descriptors", + x_axis_type="datetime", + height=height, + x_range=x_range, + tools=tools, + **kwargs + ) - self.num_fds.line(source=self.source, x='time', y='num_fds') + self.num_fds.line(source=self.source, x="time", y="num_fds") plots.append(self.num_fds) - if 'sizing_mode' in kwargs: - kw = {'sizing_mode': kwargs['sizing_mode']} + if "sizing_mode" in kwargs: + kw = {"sizing_mode": kwargs["sizing_mode"]} else: kw = {} @@ -420,7 +511,7 @@ def __init__(self, worker, height=150, **kwargs): def get_data(self): 
d = self.worker.monitor.range_query(start=self.last) - d['time'] = [x * 1000 for x in d['time']] + d["time"] = [x * 1000 for x in d["time"]] self.last = self.worker.monitor.count return d @@ -431,7 +522,7 @@ def update(self): class Counters(DashboardComponent): - def __init__(self, server, sizing_mode='stretch_both', **kwargs): + def __init__(self, server, sizing_mode="stretch_both", **kwargs): self.server = server self.counter_figures = {} self.counter_sources = {} @@ -451,31 +542,40 @@ def __init__(self, server, sizing_mode='stretch_both', **kwargs): if len(figures) <= 5: self.root = column(figures, sizing_mode=sizing_mode) else: - self.root = column(*[row(*pair, sizing_mode=sizing_mode) - for pair in partition_all(2, figures)], - sizing_mode=sizing_mode) + self.root = column( + *[ + row(*pair, sizing_mode=sizing_mode) + for pair in partition_all(2, figures) + ], + sizing_mode=sizing_mode + ) def add_digest_figure(self, name): with log_errors(): n = len(self.server.digests[name].intervals) - sources = {i: ColumnDataSource({'x': [], 'y': []}) - for i in range(n)} + sources = {i: ColumnDataSource({"x": [], "y": []}) for i in range(n)} kwargs = {} - if name.endswith('duration'): - kwargs['x_axis_type'] = 'datetime' + if name.endswith("duration"): + kwargs["x_axis_type"] = "datetime" - fig = figure(title=name, tools='', height=150, - sizing_mode=self.sizing_mode, **kwargs) + fig = figure( + title=name, tools="", height=150, sizing_mode=self.sizing_mode, **kwargs + ) fig.yaxis.visible = False fig.ygrid.visible = False - if name.endswith('bandwidth') or name.endswith('bytes'): - fig.xaxis[0].formatter = NumeralTickFormatter(format='0.0b') + if name.endswith("bandwidth") or name.endswith("bytes"): + fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0b") for i in range(n): alpha = 0.3 + 0.3 * (n - i) / n - fig.line(source=sources[i], x='x', y='y', - alpha=alpha, color=RdBu[max(n, 3)][-i]) + fig.line( + source=sources[i], + x="x", + y="y", + alpha=alpha, + color=RdBu[max(n, 3)][-i], + ) fig.xaxis.major_label_orientation = math.pi / 12 fig.toolbar.logo = None @@ -486,22 +586,33 @@ def add_digest_figure(self, name): def add_counter_figure(self, name): with log_errors(): n = len(self.server.counters[name].intervals) - sources = {i: ColumnDataSource({'x': [], 'y': [], - 'y-center': [], 'counts': []}) - for i in range(n)} + sources = { + i: ColumnDataSource({"x": [], "y": [], "y-center": [], "counts": []}) + for i in range(n) + } - fig = figure(title=name, tools='', height=150, - sizing_mode=self.sizing_mode, - x_range=sorted(map(str, self.server.counters[name].components[0]))) + fig = figure( + title=name, + tools="", + height=150, + sizing_mode=self.sizing_mode, + x_range=sorted(map(str, self.server.counters[name].components[0])), + ) fig.ygrid.visible = False for i in range(n): width = 0.5 + 0.4 * i / n - fig.rect(source=sources[i], x='x', y='y-center', width=width, - height='y', alpha=0.3, color=RdBu[max(n, 3)][-i]) + fig.rect( + source=sources[i], + x="x", + y="y-center", + width=width, + height="y", + alpha=0.3, + color=RdBu[max(n, 3)][-i], + ) hover = HoverTool( - point_policy="follow_mouse", - tooltips="""@x : @counts""" + point_policy="follow_mouse", tooltips="""@x : @counts""" ) fig.add_tools(hover) fig.xaxis.major_label_orientation = math.pi / 12 @@ -522,10 +633,10 @@ def update(self): if d.size(): ys, xs = d.histogram(100) xs = xs[1:] - if name.endswith('duration'): + if name.endswith("duration"): xs *= 1000 - self.digest_sources[name][i].data.update({'x': xs, 'y': ys}) - 
fig.title.text = '%s: %d' % (name, digest.size()) + self.digest_sources[name][i].data.update({"x": xs, "y": ys}) + fig.title.text = "%s: %d" % (name, digest.size()) for name, fig in self.counter_figures.items(): counter = self.server.counters[name] @@ -538,10 +649,9 @@ def update(self): ys = [factor * c for c in counts] y_centers = [y / 2 for y in ys] xs = list(map(str, xs)) - d = {'x': xs, 'y': ys, 'y-center': y_centers, - 'counts': counts} + d = {"x": xs, "y": ys, "y-center": y_centers, "counts": counts} self.counter_sources[name][i].data.update(d) - fig.title.text = '%s: %d' % (name, counter.size()) + fig.title.text = "%s: %d" % (name, counter.size()) fig.x_range.factors = list(map(str, xs)) @@ -552,11 +662,9 @@ def update(self): def main_doc(worker, extra, doc): with log_errors(): statetable = StateTable(worker) - executing_ts = ExecutingTimeSeries(worker, sizing_mode='scale_width') - communicating_ts = CommunicatingTimeSeries(worker, - sizing_mode='scale_width') - communicating_stream = CommunicatingStream(worker, - sizing_mode='scale_width') + executing_ts = ExecutingTimeSeries(worker, sizing_mode="scale_width") + communicating_ts = CommunicatingTimeSeries(worker, sizing_mode="scale_width") + communicating_stream = CommunicatingStream(worker, sizing_mode="scale_width") xr = executing_ts.root.x_range communicating_ts.root.x_range = xr @@ -567,13 +675,17 @@ def main_doc(worker, extra, doc): add_periodic_callback(doc, executing_ts, 200) add_periodic_callback(doc, communicating_ts, 200) add_periodic_callback(doc, communicating_stream, 200) - doc.add_root(column(statetable.root, - executing_ts.root, - communicating_ts.root, - communicating_stream.root, - sizing_mode='scale_width')) - doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'main' + doc.add_root( + column( + statetable.root, + executing_ts.root, + communicating_ts.root, + communicating_stream.root, + sizing_mode="scale_width", + ) + ) + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "main" doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -588,21 +700,21 @@ def crossfilter_doc(worker, extra, doc): add_periodic_callback(doc, crossfilter, 500) doc.add_root(column(statetable.root, crossfilter.root)) - doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'crossfilter' + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "crossfilter" doc.template_variables.update(extra) doc.theme = BOKEH_THEME def systemmonitor_doc(worker, extra, doc): with log_errors(): - sysmon = SystemMonitor(worker, sizing_mode='scale_width') + sysmon = SystemMonitor(worker, sizing_mode="scale_width") doc.title = "Dask Worker Monitor" add_periodic_callback(doc, sysmon, 500) doc.add_root(sysmon.root) - doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'system' + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "system" doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -610,12 +722,12 @@ def systemmonitor_doc(worker, extra, doc): def counters_doc(server, extra, doc): with log_errors(): doc.title = "Dask Worker Counters" - counter = Counters(server, sizing_mode='stretch_both') + counter = Counters(server, sizing_mode="stretch_both") add_periodic_callback(doc, counter, 500) doc.add_root(counter.root) - doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'counters' + 
doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "counters" doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -623,12 +735,12 @@ def counters_doc(server, extra, doc): def profile_doc(server, extra, doc): with log_errors(): doc.title = "Dask Worker Profile" - profile = ProfileTimePlot(server, sizing_mode='scale_width') + profile = ProfileTimePlot(server, sizing_mode="scale_width") profile.trigger_update() doc.add_root(profile.root) - doc.template = env.get_template('simple.html') - doc.template_variables['active_page'] = 'profile' + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "profile" doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -636,9 +748,9 @@ def profile_doc(server, extra, doc): def profile_server_doc(server, extra, doc): with log_errors(): doc.title = "Dask: Profile of Event Loop" - prof = ProfileServer(server, sizing_mode='scale_width', doc=doc) + prof = ProfileServer(server, sizing_mode="scale_width", doc=doc) doc.add_root(prof.root) - doc.template = env.get_template('simple.html') + doc.template = env.get_template("simple.html") # doc.template_variables['active_page'] = '' doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -647,40 +759,48 @@ def profile_server_doc(server, extra, doc): class BokehWorker(BokehServer): - def __init__(self, worker, io_loop=None, prefix='', **kwargs): + def __init__(self, worker, io_loop=None, prefix="", **kwargs): self.worker = worker self.server_kwargs = kwargs - self.server_kwargs['prefix'] = prefix or None - prefix = prefix or '' - prefix = prefix.rstrip('/') - if prefix and not prefix.startswith('/'): - prefix = '/' + prefix + self.server_kwargs["prefix"] = prefix or None + prefix = prefix or "" + prefix = prefix.rstrip("/") + if prefix and not prefix.startswith("/"): + prefix = "/" + prefix self.prefix = prefix - extra = {'prefix': prefix} + extra = {"prefix": prefix} extra.update(template_variables) main = Application(FunctionHandler(partial(main_doc, worker, extra))) - crossfilter = Application(FunctionHandler(partial(crossfilter_doc, worker, extra))) - systemmonitor = Application(FunctionHandler(partial(systemmonitor_doc, worker, extra))) + crossfilter = Application( + FunctionHandler(partial(crossfilter_doc, worker, extra)) + ) + systemmonitor = Application( + FunctionHandler(partial(systemmonitor_doc, worker, extra)) + ) counters = Application(FunctionHandler(partial(counters_doc, worker, extra))) profile = Application(FunctionHandler(partial(profile_doc, worker, extra))) - profile_server = Application(FunctionHandler(partial(profile_server_doc, worker, extra))) + profile_server = Application( + FunctionHandler(partial(profile_server_doc, worker, extra)) + ) - self.apps = {'/main': main, - '/counters': counters, - '/crossfilter': crossfilter, - '/system': systemmonitor, - '/profile': profile, - '/profile-server': profile_server} + self.apps = { + "/main": main, + "/counters": counters, + "/crossfilter": crossfilter, + "/system": systemmonitor, + "/profile": profile, + "/profile-server": profile_server, + } self.loop = io_loop or worker.loop self.server = None @property def extra(self): - return merge({'prefix': self.prefix}, template_variables) + return merge({"prefix": self.prefix}, template_variables) @property def my_server(self): @@ -690,7 +810,14 @@ def listen(self, *args, **kwargs): super(BokehWorker, self).listen(*args, **kwargs) from .worker_html import routes - handlers = [(self.prefix + '/' + url, cls, 
{'server': self.my_server, 'extra': self.extra}) - for url, cls in routes] - self.server._tornado.add_handlers(r'.*', handlers) + handlers = [ + ( + self.prefix + "/" + url, + cls, + {"server": self.my_server, "extra": self.extra}, + ) + for url, cls in routes + ] + + self.server._tornado.add_handlers(r".*", handlers) diff --git a/distributed/bokeh/worker_html.py b/distributed/bokeh/worker_html.py index 5a956231c6a..3ddf9490c4d 100644 --- a/distributed/bokeh/worker_html.py +++ b/distributed/bokeh/worker_html.py @@ -11,7 +11,7 @@ def initialize(self, server=None, extra=None): self.extra = extra or {} def get_template_path(self): - return os.path.join(dirname, 'templates') + return os.path.join(dirname, "templates") class _PrometheusCollector(object): @@ -39,6 +39,7 @@ class PrometheusHandler(RequestHandler): def __init__(self, *args, **kwargs): import prometheus_client # keep out of global namespace + self.prometheus_client = prometheus_client super(PrometheusHandler, self).__init__(*args, **kwargs) @@ -50,30 +51,24 @@ def _init(self): return self.prometheus_client.REGISTRY.register( - _PrometheusCollector( - self.server, - self.prometheus_client, - ) + _PrometheusCollector(self.server, self.prometheus_client) ) PrometheusHandler._initialized = True def get(self): self.write(self.prometheus_client.generate_latest()) - self.set_header('Content-Type', 'text/plain; version=0.0.4') + self.set_header("Content-Type", "text/plain; version=0.0.4") class HealthHandler(RequestHandler): def get(self): - self.write('ok') - self.set_header('Content-Type', 'text/plain') + self.write("ok") + self.set_header("Content-Type", "text/plain") -routes = [ - (r'metrics', PrometheusHandler), - (r'health', HealthHandler), -] +routes = [(r"metrics", PrometheusHandler), (r"health", HealthHandler)] def get_handlers(server): - return [(url, cls, {'server': server}) for url, cls in routes] + return [(url, cls, {"server": server}) for url, cls in routes] diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index b22f36783c4..eb7bbf05646 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -21,9 +21,9 @@ def _cascade_future(future, cf_future): """ result = yield future._result(raiseit=False) status = future.status - if status == 'finished': + if status == "finished": cf_future.set_result(result) - elif status == 'cancelled': + elif status == "cancelled": cf_future.cancel() # Necessary for wait() and as_completed() to wake up cf_future.set_running_or_notify_cancel() @@ -48,13 +48,17 @@ class ClientExecutor(cf.Executor): A concurrent.futures Executor that executes tasks on a dask.distributed Client. 
""" - _allowed_kwargs = frozenset(['pure', 'workers', 'resources', 'allow_other_workers', 'retries']) + _allowed_kwargs = frozenset( + ["pure", "workers", "resources", "allow_other_workers", "retries"] + ) def __init__(self, client, **kwargs): sk = set(kwargs) if not sk <= self._allowed_kwargs: - raise TypeError("unsupported arguments to ClientExecutor: %s" - % sorted(sk - self._allowed_kwargs)) + raise TypeError( + "unsupported arguments to ClientExecutor: %s" + % sorted(sk - self._allowed_kwargs) + ) self._client = client self._futures = weakref.WeakSet() self._shutdown = False @@ -68,7 +72,7 @@ def _wrap_future(self, future): # Support cancelling task through .cancel() on c.f.Future def cf_callback(cf_future): - if cf_future.cancelled() and future.status != 'cancelled': + if cf_future.cancelled() and future.status != "cancelled": future.cancel() cf_future.add_done_callback(cf_callback) @@ -87,7 +91,7 @@ def submit(self, fn, *args, **kwargs): A Future representing the given call. """ if self._shutdown: - raise RuntimeError('cannot schedule new futures after shutdown') + raise RuntimeError("cannot schedule new futures after shutdown") future = self._client.submit(fn, *args, **merge(self._kwargs, kwargs)) self._futures.add(future) return self._wrap_future(future) @@ -115,14 +119,13 @@ def map(self, fn, *iterables, **kwargs): before the given timeout. Exception: If ``fn(*args)`` raises for any values. """ - timeout = kwargs.pop('timeout', None) + timeout = kwargs.pop("timeout", None) if timeout is not None: end_time = timeout + time() - if 'chunksize' in kwargs: - del kwargs['chunksize'] + if "chunksize" in kwargs: + del kwargs["chunksize"] if kwargs: - raise TypeError("unexpected arguments to map(): %s" - % sorted(kwargs)) + raise TypeError("unexpected arguments to map(): %s" % sorted(kwargs)) fs = self._client.map(fn, *iterables, **self._kwargs) diff --git a/distributed/cli/dask_mpi.py b/distributed/cli/dask_mpi.py index e26e0bc91e1..ef7dd0c59fa 100644 --- a/distributed/cli/dask_mpi.py +++ b/distributed/cli/dask_mpi.py @@ -18,32 +18,62 @@ @click.command() -@click.option('--scheduler-file', type=str, default='scheduler.json', - help='Filename to JSON encoded scheduler information. ') -@click.option('--interface', type=str, default=None, - help="Network interface like 'eth0' or 'ib0'") -@click.option('--nthreads', type=int, default=0, - help="Number of threads per worker.") -@click.option('--memory-limit', default='auto', - help="Number of bytes before spilling data to disk. " - "This can be an integer (nbytes) " - "float (fraction of total memory) " - "or 'auto'") -@click.option('--local-directory', default='', type=str, - help="Directory to place worker files") -@click.option('--scheduler/--no-scheduler', default=True, - help=("Whether or not to include a scheduler. 
" - "Use --no-scheduler to increase an existing dask cluster")) -@click.option('--nanny/--no-nanny', default=True, - help="Start workers in nanny process for management") -@click.option('--bokeh-port', type=int, default=8787, - help="Bokeh port for visual diagnostics") -@click.option('--bokeh-worker-port', type=int, default=8789, - help="Worker's Bokeh port for visual diagnostics") -@click.option('--bokeh-prefix', type=str, default=None, - help="Prefix for the bokeh app") -def main(scheduler_file, interface, nthreads, local_directory, memory_limit, - scheduler, bokeh_port, bokeh_prefix, nanny, bokeh_worker_port): +@click.option( + "--scheduler-file", + type=str, + default="scheduler.json", + help="Filename to JSON encoded scheduler information. ", +) +@click.option( + "--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'" +) +@click.option("--nthreads", type=int, default=0, help="Number of threads per worker.") +@click.option( + "--memory-limit", + default="auto", + help="Number of bytes before spilling data to disk. " + "This can be an integer (nbytes) " + "float (fraction of total memory) " + "or 'auto'", +) +@click.option( + "--local-directory", default="", type=str, help="Directory to place worker files" +) +@click.option( + "--scheduler/--no-scheduler", + default=True, + help=( + "Whether or not to include a scheduler. " + "Use --no-scheduler to increase an existing dask cluster" + ), +) +@click.option( + "--nanny/--no-nanny", + default=True, + help="Start workers in nanny process for management", +) +@click.option( + "--bokeh-port", type=int, default=8787, help="Bokeh port for visual diagnostics" +) +@click.option( + "--bokeh-worker-port", + type=int, + default=8789, + help="Worker's Bokeh port for visual diagnostics", +) +@click.option("--bokeh-prefix", type=str, default=None, help="Prefix for the bokeh app") +def main( + scheduler_file, + interface, + nthreads, + local_directory, + memory_limit, + scheduler, + bokeh_port, + bokeh_prefix, + nanny, + bokeh_worker_port, +): if interface: host = get_ip_interface(interface) else: @@ -55,11 +85,12 @@ def main(scheduler_file, interface, nthreads, local_directory, memory_limit, except ImportError: services = {} else: - services = {('bokeh', bokeh_port): partial(BokehScheduler, - prefix=bokeh_prefix)} - scheduler = Scheduler(scheduler_file=scheduler_file, - loop=loop, - services=services) + services = { + ("bokeh", bokeh_port): partial(BokehScheduler, prefix=bokeh_prefix) + } + scheduler = Scheduler( + scheduler_file=scheduler_file, loop=loop, services=services + ) addr = uri_from_host_port(host, None, 8786) scheduler.start(addr) try: @@ -69,19 +100,21 @@ def main(scheduler_file, interface, nthreads, local_directory, memory_limit, scheduler.stop() else: W = Nanny if nanny else Worker - worker = W(scheduler_file=scheduler_file, - loop=loop, - name=rank if scheduler else None, - ncores=nthreads, - local_dir=local_directory, - services={('bokeh', bokeh_worker_port): BokehWorker}, - memory_limit=memory_limit) + worker = W( + scheduler_file=scheduler_file, + loop=loop, + name=rank if scheduler else None, + ncores=nthreads, + local_dir=local_directory, + services={("bokeh", bokeh_worker_port): BokehWorker}, + memory_limit=memory_limit, + ) addr = uri_from_host_port(host, None, 0) @gen.coroutine def run(): yield worker._start(addr) - while worker.status != 'closed': + while worker.status != "closed": yield gen.sleep(0.2) try: @@ -99,12 +132,14 @@ def close(): def go(): check_python_3() - warn("The dask-mpi command 
line utility in the `distributed` " - "package is deprecated. " - "Please install the `dask-mpi` package instead. " - "More information is available at https://mpi.dask.org") + warn( + "The dask-mpi command line utility in the `distributed` " + "package is deprecated. " + "Please install the `dask-mpi` package instead. " + "More information is available at https://mpi.dask.org" + ) main() -if __name__ == '__main__': +if __name__ == "__main__": go() diff --git a/distributed/cli/dask_remote.py b/distributed/cli/dask_remote.py index 2d94d0e1142..933d8d318b0 100644 --- a/distributed/cli/dask_remote.py +++ b/distributed/cli/dask_remote.py @@ -6,9 +6,8 @@ @click.command() -@click.option('--host', type=str, default=None, - help="IP or hostname of this server") -@click.option('--port', type=int, default=8788, help="Remote Client Port") +@click.option("--host", type=str, default=None, help="IP or hostname of this server") +@click.option("--port", type=int, default=8788, help="Remote Client Port") def main(host, port): _remote(host, port) @@ -19,5 +18,5 @@ def go(): main() -if __name__ == '__main__': +if __name__ == "__main__": go() diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index f48e8b4e26d..0e8415ac132 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -16,61 +16,127 @@ from distributed import Scheduler from distributed.security import Security from distributed.utils import get_ip_interface -from distributed.cli.utils import (check_python_3, install_signal_handlers, - uri_from_host_port) +from distributed.cli.utils import ( + check_python_3, + install_signal_handlers, + uri_from_host_port, +) from distributed.preloading import preload_modules, validate_preload_argv -from distributed.proctitle import (enable_proctitle_on_children, - enable_proctitle_on_current) +from distributed.proctitle import ( + enable_proctitle_on_children, + enable_proctitle_on_current, +) -logger = logging.getLogger('distributed.scheduler') +logger = logging.getLogger("distributed.scheduler") pem_file_option_type = click.Path(exists=True, resolve_path=True) @click.command(context_settings=dict(ignore_unknown_options=True)) -@click.option('--host', type=str, default='', - help="URI, IP or hostname of this server") -@click.option('--port', type=int, default=None, help="Serving port") -@click.option('--interface', type=str, default=None, - help="Preferred network interface like 'eth0' or 'ib0'") -@click.option('--tls-ca-file', type=pem_file_option_type, default=None, - help="CA cert(s) file for TLS (in PEM format)") -@click.option('--tls-cert', type=pem_file_option_type, default=None, - help="certificate file for TLS (in PEM format)") -@click.option('--tls-key', type=pem_file_option_type, default=None, - help="private key file for TLS (in PEM format)") +@click.option("--host", type=str, default="", help="URI, IP or hostname of this server") +@click.option("--port", type=int, default=None, help="Serving port") +@click.option( + "--interface", + type=str, + default=None, + help="Preferred network interface like 'eth0' or 'ib0'", +) +@click.option( + "--tls-ca-file", + type=pem_file_option_type, + default=None, + help="CA cert(s) file for TLS (in PEM format)", +) +@click.option( + "--tls-cert", + type=pem_file_option_type, + default=None, + help="certificate file for TLS (in PEM format)", +) +@click.option( + "--tls-key", + type=pem_file_option_type, + default=None, + help="private key file for TLS (in PEM format)", +) # XXX default port (or URI) 
values should be centralized somewhere -@click.option('--bokeh-port', type=int, default=None, - help="Deprecated. See --dashboard-address") -@click.option('--dashboard-address', type=str, default=':8787', - help="Address on which to listen for diagnostics dashboard") -@click.option('--bokeh/--no-bokeh', '_bokeh', default=True, show_default=True, - required=False, help="Launch Bokeh Web UI") -@click.option('--show/--no-show', default=False, help="Show web UI") -@click.option('--bokeh-whitelist', default=None, multiple=True, - help="IP addresses to whitelist for bokeh.") -@click.option('--bokeh-prefix', type=str, default=None, - help="Prefix for the bokeh app") -@click.option('--use-xheaders', type=bool, default=False, show_default=True, - help="User xheaders in bokeh app for ssl termination in header") -@click.option('--pid-file', type=str, default='', - help="File to write the process PID") -@click.option('--scheduler-file', type=str, default='', - help="File to write connection information. " - "This may be a good way to share connection information if your " - "cluster is on a shared network file system.") -@click.option('--local-directory', default='', type=str, - help="Directory to place scheduler files") -@click.option('--preload', type=str, multiple=True, is_eager=True, default='', - help='Module that should be loaded by the scheduler process ' - 'like "foo.bar" or "/path/to/foo.py".') -@click.argument('preload_argv', nargs=-1, - type=click.UNPROCESSED, callback=validate_preload_argv) -def main(host, port, bokeh_port, show, _bokeh, bokeh_whitelist, bokeh_prefix, - use_xheaders, pid_file, scheduler_file, interface, - local_directory, preload, preload_argv, tls_ca_file, tls_cert, tls_key, - dashboard_address): +@click.option( + "--bokeh-port", type=int, default=None, help="Deprecated. See --dashboard-address" +) +@click.option( + "--dashboard-address", + type=str, + default=":8787", + help="Address on which to listen for diagnostics dashboard", +) +@click.option( + "--bokeh/--no-bokeh", + "_bokeh", + default=True, + show_default=True, + required=False, + help="Launch Bokeh Web UI", +) +@click.option("--show/--no-show", default=False, help="Show web UI") +@click.option( + "--bokeh-whitelist", + default=None, + multiple=True, + help="IP addresses to whitelist for bokeh.", +) +@click.option("--bokeh-prefix", type=str, default=None, help="Prefix for the bokeh app") +@click.option( + "--use-xheaders", + type=bool, + default=False, + show_default=True, + help="User xheaders in bokeh app for ssl termination in header", +) +@click.option("--pid-file", type=str, default="", help="File to write the process PID") +@click.option( + "--scheduler-file", + type=str, + default="", + help="File to write connection information. 
" + "This may be a good way to share connection information if your " + "cluster is on a shared network file system.", +) +@click.option( + "--local-directory", default="", type=str, help="Directory to place scheduler files" +) +@click.option( + "--preload", + type=str, + multiple=True, + is_eager=True, + default="", + help="Module that should be loaded by the scheduler process " + 'like "foo.bar" or "/path/to/foo.py".', +) +@click.argument( + "preload_argv", nargs=-1, type=click.UNPROCESSED, callback=validate_preload_argv +) +def main( + host, + port, + bokeh_port, + show, + _bokeh, + bokeh_whitelist, + bokeh_prefix, + use_xheaders, + pid_file, + scheduler_file, + interface, + local_directory, + preload, + preload_argv, + tls_ca_file, + tls_cert, + tls_key, + dashboard_address, +): enable_proctitle_on_current() enable_proctitle_on_children() @@ -82,21 +148,21 @@ def main(host, port, bokeh_port, show, _bokeh, bokeh_whitelist, bokeh_prefix, ) dashboard_address = bokeh_port - sec = Security(tls_ca_file=tls_ca_file, - tls_scheduler_cert=tls_cert, - tls_scheduler_key=tls_key, - ) + sec = Security( + tls_ca_file=tls_ca_file, tls_scheduler_cert=tls_cert, tls_scheduler_key=tls_key + ) if not host and (tls_ca_file or tls_cert or tls_key): - host = 'tls://' + host = "tls://" if pid_file: - with open(pid_file, 'w') as f: + with open(pid_file, "w") as f: f.write(str(os.getpid())) def del_pid_file(): if os.path.exists(pid_file): os.remove(pid_file) + atexit.register(del_pid_file) local_directory_created = False @@ -105,13 +171,14 @@ def del_pid_file(): os.mkdir(local_directory) local_directory_created = True else: - local_directory = tempfile.mkdtemp(prefix='scheduler-') + local_directory = tempfile.mkdtemp(prefix="scheduler-") local_directory_created = True if local_directory not in sys.path: sys.path.insert(0, local_directory) - if sys.platform.startswith('linux'): - import resource # module fails importing on Windows + if sys.platform.startswith("linux"): + import resource # module fails importing on Windows + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) limit = max(soft, hard // 2) resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard)) @@ -125,32 +192,37 @@ def del_pid_file(): addr = uri_from_host_port(host, port, 8786) loop = IOLoop.current() - logger.info('-' * 47) + logger.info("-" * 47) services = {} if _bokeh: try: from distributed.bokeh.scheduler import BokehScheduler - services[('bokeh', dashboard_address)] = (BokehScheduler, - {'prefix': bokeh_prefix}) + + services[("bokeh", dashboard_address)] = ( + BokehScheduler, + {"prefix": bokeh_prefix}, + ) except ImportError as error: - if str(error).startswith('No module named'): - logger.info('Web dashboard not loaded. Unable to import bokeh') + if str(error).startswith("No module named"): + logger.info("Web dashboard not loaded. 
Unable to import bokeh") else: - logger.info('Unable to import bokeh: %s' % str(error)) + logger.info("Unable to import bokeh: %s" % str(error)) - scheduler = Scheduler(loop=loop, services=services, - scheduler_file=scheduler_file, - security=sec) + scheduler = Scheduler( + loop=loop, services=services, scheduler_file=scheduler_file, security=sec + ) scheduler.start(addr) if not preload: - preload = dask.config.get('distributed.scheduler.preload') + preload = dask.config.get("distributed.scheduler.preload") if not preload_argv: - preload_argv = dask.config.get('distributed.scheduler.preload-argv') - preload_modules(preload, parameter=scheduler, file_dir=local_directory, argv=preload_argv) + preload_argv = dask.config.get("distributed.scheduler.preload-argv") + preload_modules( + preload, parameter=scheduler, file_dir=local_directory, argv=preload_argv + ) - logger.info('Local Directory: %26s', local_directory) - logger.info('-' * 47) + logger.info("Local Directory: %26s", local_directory) + logger.info("-" * 47) install_signal_handlers(loop) @@ -170,5 +242,5 @@ def go(): main() -if __name__ == '__main__': +if __name__ == "__main__": go() diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index acb87d21642..df2b1c6fe94 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -6,53 +6,120 @@ from distributed.cli.utils import check_python_3 -@click.command(help="""Launch a distributed cluster over SSH. A 'dask-scheduler' process will run on the +@click.command( + help="""Launch a distributed cluster over SSH. A 'dask-scheduler' process will run on the first host specified in [HOSTNAMES] or in the hostfile (unless --scheduler is specified explicitly). One or more 'dask-worker' processes will be run each host in [HOSTNAMES] or in the hostfile. Use command line flags to adjust how many dask-worker process are run on - each host (--nprocs) and how many cpus are used by each dask-worker process (--nthreads).""") -@click.option('--scheduler', default=None, type=str, - help="Specify scheduler node. Defaults to first address.") -@click.option('--scheduler-port', default=8786, type=int, - help="Specify scheduler port number. Defaults to port 8786.") -@click.option('--nthreads', default=0, type=int, - help=("Number of threads per worker process. " - "Defaults to number of cores divided by the number of " - "processes per host.")) -@click.option('--nprocs', default=1, type=int, - help="Number of worker processes per host. Defaults to one.") -@click.argument('hostnames', nargs=-1, type=str) -@click.option('--hostfile', default=None, type=click.Path(exists=True), - help="Textfile with hostnames/IP addresses") -@click.option('--ssh-username', default=None, type=str, - help="Username to use when establishing SSH connections.") -@click.option('--ssh-port', default=22, type=int, - help="Port to use for SSH connections.") -@click.option('--ssh-private-key', default=None, type=str, - help="Private key file to use for SSH connections.") -@click.option('--nohost', is_flag=True, - help="Do not pass the hostname to the worker.") -@click.option('--log-directory', default=None, type=click.Path(exists=True), - help=("Directory to use on all cluster nodes for the output of " - "dask-scheduler and dask-worker commands.")) -@click.option('--remote-python', default=None, type=str, - help="Path to Python on remote nodes.") -@click.option('--memory-limit', default='auto', - help="Bytes of memory that the worker can use. 
" - "This can be an integer (bytes), " - "float (fraction of total system memory), " - "string (like 5GB or 5000M), " - "'auto', or zero for no memory management") -@click.option('--worker-port', type=int, default=0, - help="Serving computation port, defaults to random") -@click.option('--nanny-port', type=int, default=0, - help="Serving nanny port, defaults to random") -@click.option('--remote-dask-worker', default=None, type=str, - help="Worker to run. Defaults to distributed.cli.dask_worker") + each host (--nprocs) and how many cpus are used by each dask-worker process (--nthreads).""" +) +@click.option( + "--scheduler", + default=None, + type=str, + help="Specify scheduler node. Defaults to first address.", +) +@click.option( + "--scheduler-port", + default=8786, + type=int, + help="Specify scheduler port number. Defaults to port 8786.", +) +@click.option( + "--nthreads", + default=0, + type=int, + help=( + "Number of threads per worker process. " + "Defaults to number of cores divided by the number of " + "processes per host." + ), +) +@click.option( + "--nprocs", + default=1, + type=int, + help="Number of worker processes per host. Defaults to one.", +) +@click.argument("hostnames", nargs=-1, type=str) +@click.option( + "--hostfile", + default=None, + type=click.Path(exists=True), + help="Textfile with hostnames/IP addresses", +) +@click.option( + "--ssh-username", + default=None, + type=str, + help="Username to use when establishing SSH connections.", +) +@click.option( + "--ssh-port", default=22, type=int, help="Port to use for SSH connections." +) +@click.option( + "--ssh-private-key", + default=None, + type=str, + help="Private key file to use for SSH connections.", +) +@click.option("--nohost", is_flag=True, help="Do not pass the hostname to the worker.") +@click.option( + "--log-directory", + default=None, + type=click.Path(exists=True), + help=( + "Directory to use on all cluster nodes for the output of " + "dask-scheduler and dask-worker commands." + ), +) +@click.option( + "--remote-python", default=None, type=str, help="Path to Python on remote nodes." +) +@click.option( + "--memory-limit", + default="auto", + help="Bytes of memory that the worker can use. " + "This can be an integer (bytes), " + "float (fraction of total system memory), " + "string (like 5GB or 5000M), " + "'auto', or zero for no memory management", +) +@click.option( + "--worker-port", + type=int, + default=0, + help="Serving computation port, defaults to random", +) +@click.option( + "--nanny-port", type=int, default=0, help="Serving nanny port, defaults to random" +) +@click.option( + "--remote-dask-worker", + default=None, + type=str, + help="Worker to run. 
Defaults to distributed.cli.dask_worker", +) @click.pass_context -def main(ctx, scheduler, scheduler_port, hostnames, hostfile, nthreads, nprocs, - ssh_username, ssh_port, ssh_private_key, nohost, log_directory, remote_python, - memory_limit, worker_port, nanny_port, remote_dask_worker): +def main( + ctx, + scheduler, + scheduler_port, + hostnames, + hostfile, + nthreads, + nprocs, + ssh_username, + ssh_port, + ssh_private_key, + nohost, + log_directory, + remote_python, + memory_limit, + worker_port, + nanny_port, + remote_dask_worker, +): try: hostnames = list(hostnames) if hostfile: @@ -67,18 +134,37 @@ def main(ctx, scheduler, scheduler_port, hostnames, hostfile, nthreads, nprocs, print(ctx.get_help()) exit(1) - c = SSHCluster(scheduler, scheduler_port, hostnames, nthreads, nprocs, - ssh_username, ssh_port, ssh_private_key, nohost, log_directory, remote_python, - memory_limit, worker_port, nanny_port, remote_dask_worker) + c = SSHCluster( + scheduler, + scheduler_port, + hostnames, + nthreads, + nprocs, + ssh_username, + ssh_port, + ssh_private_key, + nohost, + log_directory, + remote_python, + memory_limit, + worker_port, + nanny_port, + remote_dask_worker, + ) import distributed - print('\n---------------------------------------------------------------') - print(' Dask.distributed v{version}\n'.format(version=distributed.__version__)) - print('Worker nodes:'.format(n=len(hostnames))) + + print("\n---------------------------------------------------------------") + print( + " Dask.distributed v{version}\n".format( + version=distributed.__version__ + ) + ) + print("Worker nodes:".format(n=len(hostnames))) for i, host in enumerate(hostnames): - print(' {num}: {host}'.format(num=i, host=host)) - print('\nscheduler node: {addr}:{port}'.format(addr=scheduler, port=scheduler_port)) - print('---------------------------------------------------------------\n\n') + print(" {num}: {host}".format(num=i, host=host)) + print("\nscheduler node: {addr}:{port}".format(addr=scheduler, port=scheduler_port)) + print("---------------------------------------------------------------\n\n") # Monitor the output of remote processes. This blocks until the user issues a KeyboardInterrupt. 
c.monitor_remote_processes() @@ -94,5 +180,5 @@ def go(): main() -if __name__ == '__main__': +if __name__ == "__main__": go() diff --git a/distributed/cli/dask_submit.py b/distributed/cli/dask_submit.py index 1bb507dd9fe..1ef759407c6 100644 --- a/distributed/cli/dask_submit.py +++ b/distributed/cli/dask_submit.py @@ -7,8 +7,8 @@ @click.command() -@click.argument('remote_client_address', type=str, required=True) -@click.argument('filepath', type=str, required=True) +@click.argument("remote_client_address", type=str, required=True) +@click.argument("filepath", type=str, required=True) def main(remote_client_address, filepath): @gen.coroutine def f(): @@ -27,5 +27,5 @@ def go(): main() -if __name__ == '__main__': +if __name__ == "__main__": go() diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index ebc0fb441b6..0eb5a7973fb 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -12,100 +12,201 @@ from distributed.utils import get_ip_interface, parse_timedelta from distributed.worker import _ncores from distributed.security import Security -from distributed.cli.utils import (check_python_3, uri_from_host_port, - install_signal_handlers) +from distributed.cli.utils import ( + check_python_3, + uri_from_host_port, + install_signal_handlers, +) from distributed.comm import get_address_host_port from distributed.preloading import validate_preload_argv -from distributed.proctitle import (enable_proctitle_on_children, - enable_proctitle_on_current) +from distributed.proctitle import ( + enable_proctitle_on_children, + enable_proctitle_on_current, +) from toolz import valmap from tornado.ioloop import IOLoop, TimeoutError from tornado import gen -logger = logging.getLogger('distributed.dask_worker') +logger = logging.getLogger("distributed.dask_worker") pem_file_option_type = click.Path(exists=True, resolve_path=True) @click.command(context_settings=dict(ignore_unknown_options=True)) -@click.argument('scheduler', type=str, required=False) -@click.option('--tls-ca-file', type=pem_file_option_type, default=None, - help="CA cert(s) file for TLS (in PEM format)") -@click.option('--tls-cert', type=pem_file_option_type, default=None, - help="certificate file for TLS (in PEM format)") -@click.option('--tls-key', type=pem_file_option_type, default=None, - help="private key file for TLS (in PEM format)") -@click.option('--worker-port', type=int, default=0, - help="Serving computation port, defaults to random") -@click.option('--nanny-port', type=int, default=0, - help="Serving nanny port, defaults to random") -@click.option('--bokeh-port', type=int, default=None, - help="Deprecated. See --dashboard-address") -@click.option('--dashboard-address', type=str, default=':0', - help="Address on which to listen for diagnostics dashboard") -@click.option('--bokeh/--no-bokeh', 'bokeh', default=True, show_default=True, - required=False, help="Launch Bokeh Web UI") -@click.option('--listen-address', type=str, default=None, - help="The address to which the worker binds. " - "Example: tcp://0.0.0.0:9000") -@click.option('--contact-address', type=str, default=None, - help="The address the worker advertises to the scheduler for " - "communication with it and other workers. " - "Example: tcp://127.0.0.1:9000") -@click.option('--host', type=str, default=None, - help="Serving host. Should be an ip address that is" - " visible to the scheduler and other workers. " - "See --listen-address and --contact-address if you " - "need different listen and contact addresses. 
" - "See --interface.") -@click.option('--interface', type=str, default=None, - help="Network interface like 'eth0' or 'ib0'") -@click.option('--nthreads', type=int, default=0, - help="Number of threads per process.") -@click.option('--nprocs', type=int, default=1, - help="Number of worker processes to launch. Defaults to one.") -@click.option('--name', type=str, default='', - help="A unique name for this worker like 'worker-1'. " - "If used with --nprocs then the process number " - "will be appended like name-0, name-1, name-2, ...") -@click.option('--memory-limit', default='auto', - help="Bytes of memory per process that the worker can use. " - "This can be an integer (bytes), " - "float (fraction of total system memory), " - "string (like 5GB or 5000M), " - "'auto', or zero for no memory management") -@click.option('--reconnect/--no-reconnect', default=True, - help="Reconnect to scheduler if disconnected") -@click.option('--nanny/--no-nanny', default=True, - help="Start workers in nanny process for management") -@click.option('--pid-file', type=str, default='', - help="File to write the process PID") -@click.option('--local-directory', default='', type=str, - help="Directory to place worker files") -@click.option('--resources', type=str, default='', - help='Resources for task constraints like "GPU=2 MEM=10e9". ' - 'Resources are applied separately to each worker process ' - "(only relevant when starting multiple worker processes with '--nprocs').") -@click.option('--scheduler-file', type=str, default='', - help='Filename to JSON encoded scheduler information. ' - 'Use with dask-scheduler --scheduler-file') -@click.option('--death-timeout', type=str, default=None, - help="Seconds to wait for a scheduler before closing") -@click.option('--bokeh-prefix', type=str, default=None, - help="Prefix for the bokeh app") -@click.option('--preload', type=str, multiple=True, is_eager=True, - help='Module that should be loaded by each worker process ' - 'like "foo.bar" or "/path/to/foo.py"') -@click.argument('preload_argv', nargs=-1, - type=click.UNPROCESSED, callback=validate_preload_argv) -def main(scheduler, host, worker_port, listen_address, contact_address, - nanny_port, nthreads, nprocs, nanny, name, - memory_limit, pid_file, reconnect, resources, bokeh, - bokeh_port, local_directory, scheduler_file, interface, - death_timeout, preload, preload_argv, bokeh_prefix, tls_ca_file, - tls_cert, tls_key, dashboard_address): +@click.argument("scheduler", type=str, required=False) +@click.option( + "--tls-ca-file", + type=pem_file_option_type, + default=None, + help="CA cert(s) file for TLS (in PEM format)", +) +@click.option( + "--tls-cert", + type=pem_file_option_type, + default=None, + help="certificate file for TLS (in PEM format)", +) +@click.option( + "--tls-key", + type=pem_file_option_type, + default=None, + help="private key file for TLS (in PEM format)", +) +@click.option( + "--worker-port", + type=int, + default=0, + help="Serving computation port, defaults to random", +) +@click.option( + "--nanny-port", type=int, default=0, help="Serving nanny port, defaults to random" +) +@click.option( + "--bokeh-port", type=int, default=None, help="Deprecated. 
See --dashboard-address" +) +@click.option( + "--dashboard-address", + type=str, + default=":0", + help="Address on which to listen for diagnostics dashboard", +) +@click.option( + "--bokeh/--no-bokeh", + "bokeh", + default=True, + show_default=True, + required=False, + help="Launch Bokeh Web UI", +) +@click.option( + "--listen-address", + type=str, + default=None, + help="The address to which the worker binds. " "Example: tcp://0.0.0.0:9000", +) +@click.option( + "--contact-address", + type=str, + default=None, + help="The address the worker advertises to the scheduler for " + "communication with it and other workers. " + "Example: tcp://127.0.0.1:9000", +) +@click.option( + "--host", + type=str, + default=None, + help="Serving host. Should be an ip address that is" + " visible to the scheduler and other workers. " + "See --listen-address and --contact-address if you " + "need different listen and contact addresses. " + "See --interface.", +) +@click.option( + "--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'" +) +@click.option("--nthreads", type=int, default=0, help="Number of threads per process.") +@click.option( + "--nprocs", + type=int, + default=1, + help="Number of worker processes to launch. Defaults to one.", +) +@click.option( + "--name", + type=str, + default="", + help="A unique name for this worker like 'worker-1'. " + "If used with --nprocs then the process number " + "will be appended like name-0, name-1, name-2, ...", +) +@click.option( + "--memory-limit", + default="auto", + help="Bytes of memory per process that the worker can use. " + "This can be an integer (bytes), " + "float (fraction of total system memory), " + "string (like 5GB or 5000M), " + "'auto', or zero for no memory management", +) +@click.option( + "--reconnect/--no-reconnect", + default=True, + help="Reconnect to scheduler if disconnected", +) +@click.option( + "--nanny/--no-nanny", + default=True, + help="Start workers in nanny process for management", +) +@click.option("--pid-file", type=str, default="", help="File to write the process PID") +@click.option( + "--local-directory", default="", type=str, help="Directory to place worker files" +) +@click.option( + "--resources", + type=str, + default="", + help='Resources for task constraints like "GPU=2 MEM=10e9". ' + "Resources are applied separately to each worker process " + "(only relevant when starting multiple worker processes with '--nprocs').", +) +@click.option( + "--scheduler-file", + type=str, + default="", + help="Filename to JSON encoded scheduler information. 
" + "Use with dask-scheduler --scheduler-file", +) +@click.option( + "--death-timeout", + type=str, + default=None, + help="Seconds to wait for a scheduler before closing", +) +@click.option("--bokeh-prefix", type=str, default=None, help="Prefix for the bokeh app") +@click.option( + "--preload", + type=str, + multiple=True, + is_eager=True, + help="Module that should be loaded by each worker process " + 'like "foo.bar" or "/path/to/foo.py"', +) +@click.argument( + "preload_argv", nargs=-1, type=click.UNPROCESSED, callback=validate_preload_argv +) +def main( + scheduler, + host, + worker_port, + listen_address, + contact_address, + nanny_port, + nthreads, + nprocs, + nanny, + name, + memory_limit, + pid_file, + reconnect, + resources, + bokeh, + bokeh_port, + local_directory, + scheduler_file, + interface, + death_timeout, + preload, + preload_argv, + bokeh_prefix, + tls_ca_file, + tls_cert, + tls_key, + dashboard_address, +): enable_proctitle_on_current() enable_proctitle_on_children() @@ -116,32 +217,41 @@ def main(scheduler, host, worker_port, listen_address, contact_address, ) dashboard_address = bokeh_port - sec = Security(tls_ca_file=tls_ca_file, - tls_worker_cert=tls_cert, - tls_worker_key=tls_key, - ) + sec = Security( + tls_ca_file=tls_ca_file, tls_worker_cert=tls_cert, tls_worker_key=tls_key + ) if nprocs > 1 and worker_port != 0: - logger.error("Failed to launch worker. You cannot use the --port argument when nprocs > 1.") + logger.error( + "Failed to launch worker. You cannot use the --port argument when nprocs > 1." + ) exit(1) if nprocs > 1 and not nanny: - logger.error("Failed to launch worker. You cannot use the --no-nanny argument when nprocs > 1.") + logger.error( + "Failed to launch worker. You cannot use the --no-nanny argument when nprocs > 1." + ) exit(1) if contact_address and not listen_address: - logger.error("Failed to launch worker. " - "Must specify --listen-address when --contact-address is given") + logger.error( + "Failed to launch worker. " + "Must specify --listen-address when --contact-address is given" + ) exit(1) if nprocs > 1 and listen_address: - logger.error("Failed to launch worker. " - "You cannot specify --listen-address when nprocs > 1.") + logger.error( + "Failed to launch worker. " + "You cannot specify --listen-address when nprocs > 1." + ) exit(1) if (worker_port or host) and listen_address: - logger.error("Failed to launch worker. " - "You cannot specify --listen-address when --worker-port or --host is given.") + logger.error( + "Failed to launch worker. " + "You cannot specify --listen-address when --worker-port or --host is given." 
+ ) exit(1) try: @@ -167,12 +277,13 @@ def main(scheduler, host, worker_port, listen_address, contact_address, nthreads = _ncores // nprocs if pid_file: - with open(pid_file, 'w') as f: + with open(pid_file, "w") as f: f.write(str(os.getpid())) def del_pid_file(): if os.path.exists(pid_file): os.remove(pid_file) + atexit.register(del_pid_file) services = {} @@ -184,14 +295,14 @@ def del_pid_file(): pass else: if bokeh_prefix: - result = (BokehWorker, {'prefix': bokeh_prefix}) + result = (BokehWorker, {"prefix": bokeh_prefix}) else: result = BokehWorker - services[('bokeh', dashboard_address)] = result + services[("bokeh", dashboard_address)] = result if resources: - resources = resources.replace(',', ' ').split() - resources = dict(pair.split('=') for pair in resources) + resources = resources.replace(",", " ").split() + resources = dict(pair.split("=") for pair in resources) resources = valmap(float, resources) else: resources = None @@ -199,17 +310,19 @@ def del_pid_file(): loop = IOLoop.current() if nanny: - kwargs = {'worker_port': worker_port, 'listen_address': listen_address} + kwargs = {"worker_port": worker_port, "listen_address": listen_address} t = Nanny else: kwargs = {} if nanny_port: - kwargs['service_ports'] = {'nanny': nanny_port} + kwargs["service_ports"] = {"nanny": nanny_port} t = Worker - if not scheduler and not scheduler_file and 'scheduler-address' not in config: - raise ValueError("Need to provide scheduler address like\n" - "dask-worker SCHEDULER_ADDRESS:8786") + if not scheduler and not scheduler_file and "scheduler-address" not in config: + raise ValueError( + "Need to provide scheduler address like\n" + "dask-worker SCHEDULER_ADDRESS:8786" + ) if interface: if host: @@ -224,17 +337,29 @@ def del_pid_file(): addr = None if death_timeout is not None: - death_timeout = parse_timedelta(death_timeout, 's') - - nannies = [t(scheduler, scheduler_file=scheduler_file, ncores=nthreads, - services=services, loop=loop, resources=resources, - memory_limit=memory_limit, reconnect=reconnect, - local_dir=local_directory, death_timeout=death_timeout, - preload=preload, preload_argv=preload_argv, - security=sec, contact_address=contact_address, - name=name if nprocs == 1 or not name else name + '-' + str(i), - **kwargs) - for i in range(nprocs)] + death_timeout = parse_timedelta(death_timeout, "s") + + nannies = [ + t( + scheduler, + scheduler_file=scheduler_file, + ncores=nthreads, + services=services, + loop=loop, + resources=resources, + memory_limit=memory_limit, + reconnect=reconnect, + local_dir=local_directory, + death_timeout=death_timeout, + preload=preload, + preload_argv=preload_argv, + security=sec, + contact_address=contact_address, + name=name if nprocs == 1 or not name else name + "-" + str(i), + **kwargs + ) + for i in range(nprocs) + ] @gen.coroutine def close_all(): @@ -249,7 +374,7 @@ def on_signal(signum): @gen.coroutine def run(): yield [n._start(addr) for n in nannies] - while all(n.status != 'closed' for n in nannies): + while all(n.status != "closed" for n in nannies): yield gen.sleep(0.2) install_signal_handlers(loop, cleanup=on_signal) @@ -267,5 +392,5 @@ def go(): main() -if __name__ == '__main__': +if __name__ == "__main__": go() diff --git a/distributed/cli/tests/test_cli_utils.py b/distributed/cli/tests/test_cli_utils.py index fcc3507b028..4f07f699de5 100644 --- a/distributed/cli/tests/test_cli_utils.py +++ b/distributed/cli/tests/test_cli_utils.py @@ -1,7 +1,8 @@ from __future__ import print_function, division, absolute_import import pytest 
-pytest.importorskip('requests') + +pytest.importorskip("requests") from distributed.cli.utils import uri_from_host_port from distributed.utils import get_ip @@ -13,41 +14,41 @@ def test_uri_from_host_port(): f = uri_from_host_port - assert f('', 456, None) == 'tcp://:456' - assert f('', 456, 123) == 'tcp://:456' - assert f('', None, 123) == 'tcp://:123' - assert f('', None, 0) == 'tcp://' - assert f('', 0, 123) == 'tcp://' + assert f("", 456, None) == "tcp://:456" + assert f("", 456, 123) == "tcp://:456" + assert f("", None, 123) == "tcp://:123" + assert f("", None, 0) == "tcp://" + assert f("", 0, 123) == "tcp://" - assert f('localhost', 456, None) == 'tcp://localhost:456' - assert f('localhost', 456, 123) == 'tcp://localhost:456' - assert f('localhost', None, 123) == 'tcp://localhost:123' - assert f('localhost', None, 0) == 'tcp://localhost' + assert f("localhost", 456, None) == "tcp://localhost:456" + assert f("localhost", 456, 123) == "tcp://localhost:456" + assert f("localhost", None, 123) == "tcp://localhost:123" + assert f("localhost", None, 0) == "tcp://localhost" - assert f('192.168.1.2', 456, None) == 'tcp://192.168.1.2:456' - assert f('192.168.1.2', 456, 123) == 'tcp://192.168.1.2:456' - assert f('192.168.1.2', None, 123) == 'tcp://192.168.1.2:123' - assert f('192.168.1.2', None, 0) == 'tcp://192.168.1.2' + assert f("192.168.1.2", 456, None) == "tcp://192.168.1.2:456" + assert f("192.168.1.2", 456, 123) == "tcp://192.168.1.2:456" + assert f("192.168.1.2", None, 123) == "tcp://192.168.1.2:123" + assert f("192.168.1.2", None, 0) == "tcp://192.168.1.2" - assert f('tcp://192.168.1.2', 456, None) == 'tcp://192.168.1.2:456' - assert f('tcp://192.168.1.2', 456, 123) == 'tcp://192.168.1.2:456' - assert f('tcp://192.168.1.2', None, 123) == 'tcp://192.168.1.2:123' - assert f('tcp://192.168.1.2', None, 0) == 'tcp://192.168.1.2' + assert f("tcp://192.168.1.2", 456, None) == "tcp://192.168.1.2:456" + assert f("tcp://192.168.1.2", 456, 123) == "tcp://192.168.1.2:456" + assert f("tcp://192.168.1.2", None, 123) == "tcp://192.168.1.2:123" + assert f("tcp://192.168.1.2", None, 0) == "tcp://192.168.1.2" - assert f('tcp://192.168.1.2:456', None, None) == 'tcp://192.168.1.2:456' - assert f('tcp://192.168.1.2:456', 0, 0) == 'tcp://192.168.1.2:456' - assert f('tcp://192.168.1.2:456', 0, 123) == 'tcp://192.168.1.2:456' - assert f('tcp://192.168.1.2:456', 456, 123) == 'tcp://192.168.1.2:456' + assert f("tcp://192.168.1.2:456", None, None) == "tcp://192.168.1.2:456" + assert f("tcp://192.168.1.2:456", 0, 0) == "tcp://192.168.1.2:456" + assert f("tcp://192.168.1.2:456", 0, 123) == "tcp://192.168.1.2:456" + assert f("tcp://192.168.1.2:456", 456, 123) == "tcp://192.168.1.2:456" with pytest.raises(ValueError): # Two incompatible port values - f('tcp://192.168.1.2:456', 123, None) + f("tcp://192.168.1.2:456", 123, None) - assert f('tls://192.168.1.2:456', None, None) == 'tls://192.168.1.2:456' - assert f('tls://192.168.1.2:456', 0, 0) == 'tls://192.168.1.2:456' - assert f('tls://192.168.1.2:456', 0, 123) == 'tls://192.168.1.2:456' - assert f('tls://192.168.1.2:456', 456, 123) == 'tls://192.168.1.2:456' + assert f("tls://192.168.1.2:456", None, None) == "tls://192.168.1.2:456" + assert f("tls://192.168.1.2:456", 0, 0) == "tls://192.168.1.2:456" + assert f("tls://192.168.1.2:456", 0, 123) == "tls://192.168.1.2:456" + assert f("tls://192.168.1.2:456", 456, 123) == "tls://192.168.1.2:456" - assert f('tcp://[::1]:456', None, None) == 'tcp://[::1]:456' + assert f("tcp://[::1]:456", None, None) == "tcp://[::1]:456" 
- assert f('tls://[::1]:456', None, None) == 'tls://[::1]:456' + assert f("tls://[::1]:456", None, None) == "tls://[::1]:456" diff --git a/distributed/cli/tests/test_dask_mpi.py b/distributed/cli/tests/test_dask_mpi.py index 36298ca5a83..8bc8dddca2e 100644 --- a/distributed/cli/tests/test_dask_mpi.py +++ b/distributed/cli/tests/test_dask_mpi.py @@ -4,7 +4,8 @@ from time import sleep import pytest -pytest.importorskip('mpi4py') + +pytest.importorskip("mpi4py") import requests @@ -15,15 +16,17 @@ from distributed.utils_test import loop # noqa: F401 -@pytest.mark.parametrize('nanny', ['--nanny', '--no-nanny']) +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_basic(loop, nanny): with tmpfile() as fn: - with popen(['mpirun', '--np', '4', 'dask-mpi', '--scheduler-file', fn, nanny], - stdin=subprocess.DEVNULL): + with popen( + ["mpirun", "--np", "4", "dask-mpi", "--scheduler-file", fn, nanny], + stdin=subprocess.DEVNULL, + ): with Client(scheduler_file=fn) as c: start = time() - while len(c.scheduler_info()['workers']) != 3: + while len(c.scheduler_info()["workers"]) != 3: assert time() < start + 10 sleep(0.2) @@ -32,36 +35,59 @@ def test_basic(loop, nanny): def test_no_scheduler(loop): with tmpfile() as fn: - with popen(['mpirun', '--np', '2', 'dask-mpi', '--scheduler-file', fn], - stdin=subprocess.DEVNULL): + with popen( + ["mpirun", "--np", "2", "dask-mpi", "--scheduler-file", fn], + stdin=subprocess.DEVNULL, + ): with Client(scheduler_file=fn) as c: start = time() - while len(c.scheduler_info()['workers']) != 1: + while len(c.scheduler_info()["workers"]) != 1: assert time() < start + 10 sleep(0.2) assert c.submit(lambda x: x + 1, 10).result() == 11 - with popen(['mpirun', '--np', '1', 'dask-mpi', - '--scheduler-file', fn, '--no-scheduler']): + with popen( + [ + "mpirun", + "--np", + "1", + "dask-mpi", + "--scheduler-file", + fn, + "--no-scheduler", + ] + ): start = time() - while len(c.scheduler_info()['workers']) != 2: + while len(c.scheduler_info()["workers"]) != 2: assert time() < start + 10 sleep(0.2) def test_bokeh(loop): with tmpfile() as fn: - with popen(['mpirun', '--np', '2', 'dask-mpi', '--scheduler-file', fn, - '--bokeh-port', '59583', '--bokeh-worker-port', '59584'], - stdin=subprocess.DEVNULL): + with popen( + [ + "mpirun", + "--np", + "2", + "dask-mpi", + "--scheduler-file", + fn, + "--bokeh-port", + "59583", + "--bokeh-worker-port", + "59584", + ], + stdin=subprocess.DEVNULL, + ): for port in [59853, 59584]: start = time() while True: try: - response = requests.get('http://localhost:%d/status/' % port) + response = requests.get("http://localhost:%d/status/" % port) assert response.ok break except Exception: @@ -69,4 +95,4 @@ def test_bokeh(loop): assert time() < start + 20 with pytest.raises(Exception): - requests.get('http://localhost:59583/status/') + requests.get("http://localhost:59583/status/") diff --git a/distributed/cli/tests/test_dask_remote.py b/distributed/cli/tests/test_dask_remote.py index 66ac4006963..04d04d62ecf 100644 --- a/distributed/cli/tests/test_dask_remote.py +++ b/distributed/cli/tests/test_dask_remote.py @@ -4,6 +4,6 @@ def test_dask_remote(): runner = CliRunner() - result = runner.invoke(main, ['--help']) - assert '--host TEXT IP or hostname of this server' in result.output + result = runner.invoke(main, ["--help"]) + assert "--host TEXT IP or hostname of this server" in result.output assert result.exit_code == 0 diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 
7564c3f05bd..26fe607b901 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -1,7 +1,8 @@ from __future__ import print_function, division, absolute_import import pytest -pytest.importorskip('requests') + +pytest.importorskip("requests") import os import requests @@ -15,67 +16,67 @@ from distributed import Scheduler, Client from distributed.utils import get_ip, get_ip_interface, tmpfile -from distributed.utils_test import (popen, - assert_can_connect_from_everywhere_4_6, - assert_can_connect_locally_4, - ) +from distributed.utils_test import ( + popen, + assert_can_connect_from_everywhere_4_6, + assert_can_connect_locally_4, +) from distributed.utils_test import loop # noqa: F401 from distributed.metrics import time def test_defaults(loop): - with popen(['dask-scheduler', '--no-bokeh']) as proc: + with popen(["dask-scheduler", "--no-bokeh"]) as proc: @gen.coroutine def f(): # Default behaviour is to listen on all addresses - yield [ - assert_can_connect_from_everywhere_4_6(8786, 5.0), # main port - ] + yield [assert_can_connect_from_everywhere_4_6(8786, 5.0)] # main port - with Client('127.0.0.1:%d' % Scheduler.default_port, loop=loop) as c: + with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: c.sync(f) with pytest.raises(Exception): - requests.get('http://127.0.0.1:8787/status/') + requests.get("http://127.0.0.1:8787/status/") with pytest.raises(Exception): - response = requests.get('http://127.0.0.1:9786/info.json') + response = requests.get("http://127.0.0.1:9786/info.json") def test_hostport(loop): - with popen(['dask-scheduler', '--no-bokeh', '--host', '127.0.0.1:8978']): + with popen(["dask-scheduler", "--no-bokeh", "--host", "127.0.0.1:8978"]): + @gen.coroutine def f(): yield [ # The scheduler's main port can't be contacted from the outside - assert_can_connect_locally_4(8978, 5.0), + assert_can_connect_locally_4(8978, 5.0) ] - with Client('127.0.0.1:8978', loop=loop) as c: + with Client("127.0.0.1:8978", loop=loop) as c: assert len(c.ncores()) == 0 c.sync(f) def test_no_bokeh(loop): - pytest.importorskip('bokeh') - with popen(['dask-scheduler', '--no-bokeh']) as proc: - with Client('127.0.0.1:%d' % Scheduler.default_port, loop=loop) as c: + pytest.importorskip("bokeh") + with popen(["dask-scheduler", "--no-bokeh"]) as proc: + with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: for i in range(3): line = proc.stderr.readline() - assert b'bokeh' not in line.lower() + assert b"bokeh" not in line.lower() with pytest.raises(Exception): - requests.get('http://127.0.0.1:8787/status/') + requests.get("http://127.0.0.1:8787/status/") def test_bokeh(loop): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") - with popen(['dask-scheduler']) as proc: - with Client('127.0.0.1:%d' % Scheduler.default_port, loop=loop) as c: + with popen(["dask-scheduler"]) as proc: + with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: pass - names = ['localhost', '127.0.0.1', get_ip()] - if 'linux' in sys.platform: + names = ["localhost", "127.0.0.1", get_ip()] + if "linux" in sys.platform: names.append(socket.gethostname()) start = time() @@ -83,58 +84,66 @@ def test_bokeh(loop): try: # All addresses should respond for name in names: - uri = 'http://%s:8787/status/' % name + uri = "http://%s:8787/status/" % name response = requests.get(uri) assert response.ok break except Exception as f: - print('got error on %r: %s' % (uri, f)) + print("got error on %r: %s" % (uri, f)) sleep(0.1) assert 
time() < start + 10 with pytest.raises(Exception): - requests.get('http://127.0.0.1:8787/status/') + requests.get("http://127.0.0.1:8787/status/") def test_bokeh_non_standard_ports(loop): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") - with popen(['dask-scheduler', - '--port', '3448', - '--dashboard-address', ':4832']) as proc: - with Client('127.0.0.1:3448', loop=loop) as c: + with popen( + ["dask-scheduler", "--port", "3448", "--dashboard-address", ":4832"] + ) as proc: + with Client("127.0.0.1:3448", loop=loop) as c: pass start = time() while True: try: - response = requests.get('http://localhost:4832/status/') + response = requests.get("http://localhost:4832/status/") assert response.ok break except Exception: sleep(0.1) assert time() < start + 20 with pytest.raises(Exception): - requests.get('http://localhost:4832/status/') + requests.get("http://localhost:4832/status/") -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) def test_bokeh_whitelist(loop): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") with pytest.raises(Exception): - requests.get('http://localhost:8787/status/').ok - - with popen(['dask-scheduler', '--bokeh-whitelist', '127.0.0.2:8787', - '--bokeh-whitelist', '127.0.0.3:8787']) as proc: - with Client('127.0.0.1:%d' % Scheduler.default_port, loop=loop) as c: + requests.get("http://localhost:8787/status/").ok + + with popen( + [ + "dask-scheduler", + "--bokeh-whitelist", + "127.0.0.2:8787", + "--bokeh-whitelist", + "127.0.0.3:8787", + ] + ) as proc: + with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: pass start = time() while True: try: - for name in ['127.0.0.2', '127.0.0.3']: - response = requests.get('http://%s:8787/status/' % name) + for name in ["127.0.0.2", "127.0.0.3"]: + response = requests.get("http://%s:8787/status/" % name) assert response.ok break except Exception as f: @@ -144,10 +153,10 @@ def test_bokeh_whitelist(loop): def test_multiple_workers(loop): - with popen(['dask-scheduler', '--no-bokeh']) as s: - with popen(['dask-worker', 'localhost:8786', '--no-bokeh']) as a: - with popen(['dask-worker', 'localhost:8786', '--no-bokeh']) as b: - with Client('127.0.0.1:%d' % Scheduler.default_port, loop=loop) as c: + with popen(["dask-scheduler", "--no-bokeh"]) as s: + with popen(["dask-worker", "localhost:8786", "--no-bokeh"]) as a: + with popen(["dask-worker", "localhost:8786", "--no-bokeh"]) as b: + with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: start = time() while len(c.ncores()) < 2: sleep(0.1) @@ -155,7 +164,7 @@ def test_multiple_workers(loop): def test_interface(loop): - psutil = pytest.importorskip('psutil') + psutil = pytest.importorskip("psutil") if_names = sorted(psutil.net_if_addrs()) for if_name in if_names: try: @@ -163,23 +172,26 @@ def test_interface(loop): except ValueError: pass else: - if ipv4_addr == '127.0.0.1': + if ipv4_addr == "127.0.0.1": break else: - pytest.skip("Could not find loopback interface. " - "Available interfaces are: %s." % (if_names,)) - - with popen(['dask-scheduler', '--no-bokeh', '--interface', if_name]) as s: - with popen(['dask-worker', '127.0.0.1:8786', '--no-bokeh', '--interface', if_name]) as a: - with Client('tcp://127.0.0.1:%d' % Scheduler.default_port, loop=loop) as c: + pytest.skip( + "Could not find loopback interface. " + "Available interfaces are: %s." 
% (if_names,) + ) + + with popen(["dask-scheduler", "--no-bokeh", "--interface", if_name]) as s: + with popen( + ["dask-worker", "127.0.0.1:8786", "--no-bokeh", "--interface", if_name] + ) as a: + with Client("tcp://127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: start = time() while not len(c.ncores()): sleep(0.1) assert time() - start < 5 info = c.scheduler_info() - assert 'tcp://127.0.0.1' in info['address'] - assert all('127.0.0.1' == d['host'] - for d in info['workers'].values()) + assert "tcp://127.0.0.1" in info["address"] + assert all("127.0.0.1" == d["host"] for d in info["workers"].values()) def test_pid_file(loop): @@ -197,7 +209,7 @@ def check_pidfile(proc, pidfile): with open(pidfile) as f: text = f.read() pid = int(text) - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): # On Windows, `dask-XXX` invokes the dask-XXX.exe # shim, but the PID is written out by the child Python process assert pid @@ -205,35 +217,37 @@ def check_pidfile(proc, pidfile): assert proc.pid == pid with tmpfile() as s: - with popen(['dask-scheduler', '--pid-file', s, '--no-bokeh']) as sched: + with popen(["dask-scheduler", "--pid-file", s, "--no-bokeh"]) as sched: check_pidfile(sched, s) with tmpfile() as w: - with popen(['dask-worker', '127.0.0.1:8786', '--pid-file', w, - '--no-bokeh']) as worker: + with popen( + ["dask-worker", "127.0.0.1:8786", "--pid-file", w, "--no-bokeh"] + ) as worker: check_pidfile(worker, w) def test_scheduler_port_zero(loop): with tmpfile() as fn: - with popen(['dask-scheduler', '--no-bokeh', '--scheduler-file', fn, - '--port', '0']) as sched: + with popen( + ["dask-scheduler", "--no-bokeh", "--scheduler-file", fn, "--port", "0"] + ) as sched: with Client(scheduler_file=fn, loop=loop) as c: assert c.scheduler.port assert c.scheduler.port != 8786 def test_bokeh_port_zero(loop): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") with tmpfile() as fn: - with popen(['dask-scheduler', '--dashboard-address', ':0']) as proc: + with popen(["dask-scheduler", "--dashboard-address", ":0"]) as proc: count = 0 while count < 1: line = proc.stderr.readline() - if b'bokeh' in line.lower() or b'web' in line.lower(): + if b"bokeh" in line.lower() or b"web" in line.lower(): sleep(0.01) count += 1 - assert b':0' not in line + assert b":0" not in line PRELOAD_TEXT = """ @@ -248,49 +262,53 @@ def get_scheduler_address(): def test_preload_file(loop): - def check_scheduler(): import scheduler_info + return scheduler_info.get_scheduler_address() tmpdir = tempfile.mkdtemp() try: - path = os.path.join(tmpdir, 'scheduler_info.py') - with open(path, 'w') as f: + path = os.path.join(tmpdir, "scheduler_info.py") + with open(path, "w") as f: f.write(PRELOAD_TEXT) with tmpfile() as fn: - with popen(['dask-scheduler', '--scheduler-file', fn, - '--preload', path]): + with popen(["dask-scheduler", "--scheduler-file", fn, "--preload", path]): with Client(scheduler_file=fn, loop=loop) as c: - assert c.run_on_scheduler(check_scheduler) == \ - c.scheduler.address + assert c.run_on_scheduler(check_scheduler) == c.scheduler.address finally: shutil.rmtree(tmpdir) def test_preload_module(loop): - def check_scheduler(): import scheduler_info + return scheduler_info.get_scheduler_address() tmpdir = tempfile.mkdtemp() try: - path = os.path.join(tmpdir, 'scheduler_info.py') - with open(path, 'w') as f: + path = os.path.join(tmpdir, "scheduler_info.py") + with open(path, "w") as f: f.write(PRELOAD_TEXT) env = os.environ.copy() - if 'PYTHONPATH' in env: - env['PYTHONPATH'] = tmpdir + 
':' + env['PYTHONPATH'] + if "PYTHONPATH" in env: + env["PYTHONPATH"] = tmpdir + ":" + env["PYTHONPATH"] else: - env['PYTHONPATH'] = tmpdir + env["PYTHONPATH"] = tmpdir with tmpfile() as fn: - with popen(['dask-scheduler', '--scheduler-file', fn, - '--preload', 'scheduler_info'], - env=env): + with popen( + [ + "dask-scheduler", + "--scheduler-file", + fn, + "--preload", + "scheduler_info", + ], + env=env, + ): with Client(scheduler_file=fn, loop=loop) as c: - assert c.run_on_scheduler(check_scheduler) == \ - c.scheduler.address + assert c.run_on_scheduler(check_scheduler) == c.scheduler.address finally: shutil.rmtree(tmpdir) @@ -310,47 +328,57 @@ def get_passthrough(): def test_preload_command(loop): - def check_passthrough(): import passthrough_info + return passthrough_info.get_passthrough() tmpdir = tempfile.mkdtemp() try: - path = os.path.join(tmpdir, 'passthrough_info.py') - with open(path, 'w') as f: + path = os.path.join(tmpdir, "passthrough_info.py") + with open(path, "w") as f: f.write(PRELOAD_COMMAND_TEXT) with tmpfile() as fn: print(fn) - with popen(['dask-scheduler', '--scheduler-file', fn, - '--preload', path, "--passthrough", "foobar"]): + with popen( + [ + "dask-scheduler", + "--scheduler-file", + fn, + "--preload", + path, + "--passthrough", + "foobar", + ] + ): with Client(scheduler_file=fn, loop=loop) as c: - assert c.run_on_scheduler(check_passthrough) == \ - "foobar" + assert c.run_on_scheduler(check_passthrough) == "foobar" finally: shutil.rmtree(tmpdir) def test_preload_command_default(loop): - def check_passthrough(): import passthrough_info + return passthrough_info.get_passthrough() tmpdir = tempfile.mkdtemp() try: - path = os.path.join(tmpdir, 'passthrough_info.py') - with open(path, 'w') as f: + path = os.path.join(tmpdir, "passthrough_info.py") + with open(path, "w") as f: f.write(PRELOAD_COMMAND_TEXT) with tmpfile() as fn2: print(fn2) - with popen(['dask-scheduler', '--scheduler-file', fn2, - '--preload', path], stdout=sys.stdout, stderr=sys.stderr): + with popen( + ["dask-scheduler", "--scheduler-file", fn2, "--preload", path], + stdout=sys.stdout, + stderr=sys.stderr, + ): with Client(scheduler_file=fn2, loop=loop) as c: - assert c.run_on_scheduler(check_passthrough) == \ - "default" + assert c.run_on_scheduler(check_passthrough) == "default" finally: shutil.rmtree(tmpdir) diff --git a/distributed/cli/tests/test_dask_submit.py b/distributed/cli/tests/test_dask_submit.py index 9f3aef0ed34..83c7c1067fa 100644 --- a/distributed/cli/tests/test_dask_submit.py +++ b/distributed/cli/tests/test_dask_submit.py @@ -4,6 +4,6 @@ def test_submit_runs_as_a_cli(): runner = CliRunner() - result = runner.invoke(main, ['--help']) + result = runner.invoke(main, ["--help"]) assert result.exit_code == 0 - assert 'Usage: main [OPTIONS] REMOTE_CLIENT_ADDRESS FILEPATH' in result.output + assert "Usage: main [OPTIONS] REMOTE_CLIENT_ADDRESS FILEPATH" in result.output diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 5797905c7cd..72084e53141 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -1,7 +1,8 @@ from __future__ import print_function, division, absolute_import import pytest -pytest.importorskip('requests') + +pytest.importorskip("requests") import requests import sys @@ -11,56 +12,73 @@ from distributed import Client from distributed.metrics import time from distributed.utils import sync, tmpfile -from distributed.utils_test import (popen, slow, terminate_process, - 
wait_for_port) +from distributed.utils_test import popen, slow, terminate_process, wait_for_port from distributed.utils_test import loop # noqa: F401 def test_nanny_worker_ports(loop): - with popen(['dask-scheduler', '--port', '9359', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:9359', '--host', '127.0.0.1', - '--worker-port', '9684', '--nanny-port', '5273', - '--no-bokeh']) as worker: - with Client('127.0.0.1:9359', loop=loop) as c: + with popen(["dask-scheduler", "--port", "9359", "--no-bokeh"]) as sched: + with popen( + [ + "dask-worker", + "127.0.0.1:9359", + "--host", + "127.0.0.1", + "--worker-port", + "9684", + "--nanny-port", + "5273", + "--no-bokeh", + ] + ) as worker: + with Client("127.0.0.1:9359", loop=loop) as c: start = time() while True: d = sync(c.loop, c.scheduler.identity) - if d['workers']: + if d["workers"]: break else: assert time() - start < 5 sleep(0.1) - assert d['workers']['tcp://127.0.0.1:9684']['services']['nanny'] == 5273 + assert d["workers"]["tcp://127.0.0.1:9684"]["services"]["nanny"] == 5273 def test_memory_limit(loop): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:8786', '--memory-limit', '2e3MB', - '--no-bokeh']) as worker: - with Client('127.0.0.1:8786', loop=loop) as c: + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + ["dask-worker", "127.0.0.1:8786", "--memory-limit", "2e3MB", "--no-bokeh"] + ) as worker: + with Client("127.0.0.1:8786", loop=loop) as c: while not c.ncores(): sleep(0.1) info = c.scheduler_info() - d = first(info['workers'].values()) - assert isinstance(d['memory_limit'], int) - assert d['memory_limit'] == 2e9 + d = first(info["workers"].values()) + assert isinstance(d["memory_limit"], int) + assert d["memory_limit"] == 2e9 def test_no_nanny(loop): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:8786', '--no-nanny', - '--no-bokeh']) as worker: - assert any(b'Registered' in worker.stderr.readline() - for i in range(15)) + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + ["dask-worker", "127.0.0.1:8786", "--no-nanny", "--no-bokeh"] + ) as worker: + assert any(b"Registered" in worker.stderr.readline() for i in range(15)) @slow -@pytest.mark.parametrize('nanny', ['--nanny', '--no-nanny']) +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_no_reconnect(nanny, loop): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - wait_for_port(('127.0.0.1', 8786)) - with popen(['dask-worker', 'tcp://127.0.0.1:8786', '--no-reconnect', nanny, - '--no-bokeh']) as worker: + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + wait_for_port(("127.0.0.1", 8786)) + with popen( + [ + "dask-worker", + "tcp://127.0.0.1:8786", + "--no-reconnect", + nanny, + "--no-bokeh", + ] + ) as worker: sleep(2) terminate_process(sched) start = time() @@ -70,88 +88,116 @@ def test_no_reconnect(nanny, loop): def test_resources(loop): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', 'tcp://127.0.0.1:8786', '--no-bokeh', - '--resources', 'A=1 B=2,C=3']) as worker: - with Client('127.0.0.1:8786', loop=loop) as c: - while not c.scheduler_info()['workers']: + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + [ + "dask-worker", + "tcp://127.0.0.1:8786", + "--no-bokeh", + "--resources", + "A=1 B=2,C=3", + ] + ) as worker: + with Client("127.0.0.1:8786", loop=loop) as c: + while not c.scheduler_info()["workers"]: sleep(0.1) info = 
c.scheduler_info() - worker = list(info['workers'].values())[0] - assert worker['resources'] == {'A': 1, 'B': 2, 'C': 3} + worker = list(info["workers"].values())[0] + assert worker["resources"] == {"A": 1, "B": 2, "C": 3} -@pytest.mark.parametrize('nanny', ['--nanny', '--no-nanny']) +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_local_directory(loop, nanny): with tmpfile() as fn: - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:8786', nanny, - '--no-bokeh', '--local-directory', fn]) as worker: - with Client('127.0.0.1:8786', loop=loop, timeout=10) as c: + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + [ + "dask-worker", + "127.0.0.1:8786", + nanny, + "--no-bokeh", + "--local-directory", + fn, + ] + ) as worker: + with Client("127.0.0.1:8786", loop=loop, timeout=10) as c: start = time() - while not c.scheduler_info()['workers']: + while not c.scheduler_info()["workers"]: sleep(0.1) assert time() < start + 8 info = c.scheduler_info() - worker = list(info['workers'].values())[0] - assert worker['local_directory'].startswith(fn) + worker = list(info["workers"].values())[0] + assert worker["local_directory"].startswith(fn) -@pytest.mark.parametrize('nanny', ['--nanny', '--no-nanny']) +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_scheduler_file(loop, nanny): with tmpfile() as fn: - with popen(['dask-scheduler', '--no-bokeh', '--scheduler-file', fn]) as sched: - with popen(['dask-worker', '--scheduler-file', fn, nanny, '--no-bokeh']): + with popen(["dask-scheduler", "--no-bokeh", "--scheduler-file", fn]) as sched: + with popen(["dask-worker", "--scheduler-file", fn, nanny, "--no-bokeh"]): with Client(scheduler_file=fn, loop=loop) as c: start = time() - while not c.scheduler_info()['workers']: + while not c.scheduler_info()["workers"]: sleep(0.1) assert time() < start + 10 def test_nprocs_requires_nanny(loop): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:8786', '--nprocs=2', - '--no-nanny']) as worker: - assert any(b'Failed to launch worker' in worker.stderr.readline() - for i in range(15)) + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + ["dask-worker", "127.0.0.1:8786", "--nprocs=2", "--no-nanny"] + ) as worker: + assert any( + b"Failed to launch worker" in worker.stderr.readline() + for i in range(15) + ) def test_nprocs_expands_name(loop): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:8786', '--nprocs', '2', - '--name', 'foo']) as worker: - with popen(['dask-worker', '127.0.0.1:8786', '--nprocs', '2']) as worker: - with Client('tcp://127.0.0.1:8786', loop=loop) as c: + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + ["dask-worker", "127.0.0.1:8786", "--nprocs", "2", "--name", "foo"] + ) as worker: + with popen(["dask-worker", "127.0.0.1:8786", "--nprocs", "2"]) as worker: + with Client("tcp://127.0.0.1:8786", loop=loop) as c: start = time() - while len(c.scheduler_info()['workers']) < 4: + while len(c.scheduler_info()["workers"]) < 4: sleep(0.2) assert time() < start + 10 info = c.scheduler_info() - names = [d['name'] for d in info['workers'].values()] - foos = [n for n in names if n.startswith('foo')] + names = [d["name"] for d in info["workers"].values()] + foos = [n for n in names if n.startswith("foo")] assert len(foos) == 2 assert len(set(names)) == 4 -@pytest.mark.skipif(not sys.platform.startswith('linux'), - 
reason="Need 127.0.0.2 to mean localhost") -@pytest.mark.parametrize('nanny', ['--nanny', '--no-nanny']) -@pytest.mark.parametrize('listen_address', [ - 'tcp://0.0.0.0:39837', - 'tcp://127.0.0.2:39837']) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) +@pytest.mark.parametrize( + "listen_address", ["tcp://0.0.0.0:39837", "tcp://127.0.0.2:39837"] +) def test_contact_listen_address(loop, nanny, listen_address): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:8786', - nanny, '--no-bokeh', - '--contact-address', 'tcp://127.0.0.2:39837', - '--listen-address', listen_address]) as worker: - with Client('127.0.0.1:8786') as client: + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + [ + "dask-worker", + "127.0.0.1:8786", + nanny, + "--no-bokeh", + "--contact-address", + "tcp://127.0.0.2:39837", + "--listen-address", + listen_address, + ] + ) as worker: + with Client("127.0.0.1:8786") as client: while not client.ncores(): sleep(0.1) info = client.scheduler_info() - assert 'tcp://127.0.0.2:39837' in info['workers'] + assert "tcp://127.0.0.2:39837" in info["workers"] # roundtrip works assert client.submit(lambda x: x + 1, 10).result() == 11 @@ -159,19 +205,20 @@ def test_contact_listen_address(loop, nanny, listen_address): def func(dask_worker): return dask_worker.listener.listen_address - assert client.run(func) == {'tcp://127.0.0.2:39837': listen_address} + assert client.run(func) == {"tcp://127.0.0.2:39837": listen_address} -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@pytest.mark.parametrize('nanny', ['--nanny', '--no-nanny']) -@pytest.mark.parametrize('host', ['127.0.0.2', '0.0.0.0']) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) +@pytest.mark.parametrize("host", ["127.0.0.2", "0.0.0.0"]) def test_respect_host_listen_address(loop, nanny, host): - with popen(['dask-scheduler', '--no-bokeh']) as sched: - with popen(['dask-worker', '127.0.0.1:8786', - nanny, '--no-bokeh', - '--host', host]) as worker: - with Client('127.0.0.1:8786') as client: + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen( + ["dask-worker", "127.0.0.1:8786", nanny, "--no-bokeh", "--host", host] + ) as worker: + with Client("127.0.0.1:8786") as client: while not client.ncores(): sleep(0.1) info = client.scheduler_info() @@ -187,22 +234,23 @@ def func(dask_worker): def test_bokeh_non_standard_ports(loop): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") - with popen(['dask-scheduler', '--port', '3449', '--no-bokeh']): - with popen(['dask-worker', 'tcp://127.0.0.1:3449', - '--dashboard-address', ':4833']) as proc: - with Client('127.0.0.1:3449', loop=loop) as c: + with popen(["dask-scheduler", "--port", "3449", "--no-bokeh"]): + with popen( + ["dask-worker", "tcp://127.0.0.1:3449", "--dashboard-address", ":4833"] + ) as proc: + with Client("127.0.0.1:3449", loop=loop) as c: pass start = time() while True: try: - response = requests.get('http://127.0.0.1:4833/main') + response = requests.get("http://127.0.0.1:4833/main") assert response.ok break except Exception: sleep(0.5) assert time() < start + 20 with pytest.raises(Exception): - requests.get('http://localhost:4833/status/') + requests.get("http://localhost:4833/status/") diff 
--git a/distributed/cli/tests/test_tls_cli.py b/distributed/cli/tests/test_tls_cli.py index 017a982d39e..d983039c962 100644 --- a/distributed/cli/tests/test_tls_cli.py +++ b/distributed/cli/tests/test_tls_cli.py @@ -4,20 +4,25 @@ from distributed import Client -from distributed.utils_test import (popen, get_cert, new_config_file, - tls_security, tls_only_config) +from distributed.utils_test import ( + popen, + get_cert, + new_config_file, + tls_security, + tls_only_config, +) from distributed.utils_test import loop # noqa: F401 from distributed.metrics import time -ca_file = get_cert('tls-ca-cert.pem') -cert = get_cert('tls-cert.pem') -key = get_cert('tls-key.pem') -keycert = get_cert('tls-key-cert.pem') +ca_file = get_cert("tls-ca-cert.pem") +cert = get_cert("tls-cert.pem") +key = get_cert("tls-key.pem") +keycert = get_cert("tls-key-cert.pem") -tls_args = ['--tls-ca-file', ca_file, '--tls-cert', keycert] -tls_args_2 = ['--tls-ca-file', ca_file, '--tls-cert', cert, '--tls-key', key] +tls_args = ["--tls-ca-file", ca_file, "--tls-cert", keycert] +tls_args_2 = ["--tls-ca-file", ca_file, "--tls-cert", cert, "--tls-key", key] def wait_for_cores(c, ncores=1): @@ -28,33 +33,43 @@ def wait_for_cores(c, ncores=1): def test_basic(loop): - with popen(['dask-scheduler', '--no-bokeh'] + tls_args) as s: - with popen(['dask-worker', '--no-bokeh', 'tls://127.0.0.1:8786'] + tls_args) as w: - with Client('tls://127.0.0.1:8786', loop=loop, - security=tls_security()) as c: + with popen(["dask-scheduler", "--no-bokeh"] + tls_args) as s: + with popen( + ["dask-worker", "--no-bokeh", "tls://127.0.0.1:8786"] + tls_args + ) as w: + with Client( + "tls://127.0.0.1:8786", loop=loop, security=tls_security() + ) as c: wait_for_cores(c) def test_nanny(loop): - with popen(['dask-scheduler', '--no-bokeh'] + tls_args) as s: - with popen(['dask-worker', '--no-bokeh', '--nanny', 'tls://127.0.0.1:8786'] + tls_args) as w: - with Client('tls://127.0.0.1:8786', loop=loop, - security=tls_security()) as c: + with popen(["dask-scheduler", "--no-bokeh"] + tls_args) as s: + with popen( + ["dask-worker", "--no-bokeh", "--nanny", "tls://127.0.0.1:8786"] + tls_args + ) as w: + with Client( + "tls://127.0.0.1:8786", loop=loop, security=tls_security() + ) as c: wait_for_cores(c) def test_separate_key_cert(loop): - with popen(['dask-scheduler', '--no-bokeh'] + tls_args_2) as s: - with popen(['dask-worker', '--no-bokeh', 'tls://127.0.0.1:8786'] + tls_args_2) as w: - with Client('tls://127.0.0.1:8786', loop=loop, - security=tls_security()) as c: + with popen(["dask-scheduler", "--no-bokeh"] + tls_args_2) as s: + with popen( + ["dask-worker", "--no-bokeh", "tls://127.0.0.1:8786"] + tls_args_2 + ) as w: + with Client( + "tls://127.0.0.1:8786", loop=loop, security=tls_security() + ) as c: wait_for_cores(c) def test_use_config_file(loop): with new_config_file(tls_only_config()): - with popen(['dask-scheduler', '--no-bokeh', '--host', 'tls://']) as s: - with popen(['dask-worker', '--no-bokeh', 'tls://127.0.0.1:8786']) as w: - with Client('tls://127.0.0.1:8786', loop=loop, - security=tls_security()) as c: + with popen(["dask-scheduler", "--no-bokeh", "--host", "tls://"]) as s: + with popen(["dask-worker", "--no-bokeh", "tls://127.0.0.1:8786"]) as w: + with Client( + "tls://127.0.0.1:8786", loop=loop, security=tls_security() + ) as c: wait_for_cores(c) diff --git a/distributed/cli/utils.py b/distributed/cli/utils.py index 86250fd21ae..4ce1d845821 100644 --- a/distributed/cli/utils.py +++ b/distributed/cli/utils.py @@ -3,8 +3,12 @@ from tornado 
import gen from tornado.ioloop import IOLoop -from distributed.comm import (parse_address, unparse_address, - parse_host_port, unparse_host_port) +from distributed.comm import ( + parse_address, + unparse_address, + parse_host_port, + unparse_host_port, +) py3_err_msg = """ @@ -30,13 +34,16 @@ def check_python_3(): """Ensures that the environment is good for unicode on Python 3.""" # https://github.com/pallets/click/issues/448#issuecomment-246029304 import click.core + click.core._verify_python3_env = lambda: None try: from click import _unicodefun + _unicodefun._verify_python3_env() except (TypeError, RuntimeError) as e: import click + click.echo(py3_err_msg, err=True) @@ -78,16 +85,20 @@ def uri_from_host_port(host_arg, port_arg, default_port): # Much of distributed depends on a well-known IP being assigned to # each entity (Worker, Scheduler, etc.), so avoid "universal" addresses # like '' which would listen on all registered IPs and interfaces. - scheme, loc = parse_address(host_arg or '') + scheme, loc = parse_address(host_arg or "") - host, port = parse_host_port(loc, port_arg if port_arg is not None else default_port) + host, port = parse_host_port( + loc, port_arg if port_arg is not None else default_port + ) if port is None and port_arg is None: port_arg = default_port if port and port_arg and port != port_arg: - raise ValueError("port number given twice in options: " - "host %r and port %r" % (host_arg, port_arg)) + raise ValueError( + "port number given twice in options: " + "host %r and port %r" % (host_arg, port_arg) + ) if port is None and port_arg is not None: port = port_arg # Note `port = 0` means "choose a random port" diff --git a/distributed/client.py b/distributed/client.py index 47bccc7fc82..96d20a7ece2 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -29,6 +29,7 @@ from dask.optimization import SubgraphCallable from dask.compatibility import apply, unicode from dask.utils import ensure_dict + try: from cytoolz import first, groupby, merge, valmap, keymap except ImportError: @@ -44,11 +45,22 @@ from tornado.queues import Queue from .batched import BatchedSend -from .utils_comm import (WrappedKey, unpack_remotedata, pack_data, - scatter_to_workers, gather_from_workers) +from .utils_comm import ( + WrappedKey, + unpack_remotedata, + pack_data, + scatter_to_workers, + gather_from_workers, +) from .cfexecutor import ClientExecutor -from .compatibility import (Queue as pyQueue, Empty, isqueue, html_escape, - StopAsyncIteration, Iterator) +from .compatibility import ( + Queue as pyQueue, + Empty, + isqueue, + html_escape, + StopAsyncIteration, + Iterator, +) from .core import connect, rpc, clean_exception, CommClosedError, PooledRPCCall from .metrics import time from .node import Node @@ -60,10 +72,26 @@ from .sizeof import sizeof from .threadpoolexecutor import rejoin from .worker import dumps_task, get_client, get_worker, secede -from .utils import (All, sync, funcname, ignoring, queue_to_iterator, - tokey, log_errors, str_graph, key_split, format_bytes, asciitable, - thread_state, no_default, PeriodicCallback, LoopRunner, - parse_timedelta, shutting_down, Any) +from .utils import ( + All, + sync, + funcname, + ignoring, + queue_to_iterator, + tokey, + log_errors, + str_graph, + key_split, + format_bytes, + asciitable, + thread_state, + no_default, + PeriodicCallback, + LoopRunner, + parse_timedelta, + shutting_down, + Any, +) from .versions import get_versions @@ -73,16 +101,14 @@ _global_client_index = [0] -DEFAULT_EXTENSIONS = [ - 
PubSubClientExtension, -] +DEFAULT_EXTENSIONS = [PubSubClientExtension] def _get_global_client(): L = sorted(list(_global_clients), reverse=True) for k in L: c = _global_clients[k] - if c.status != 'closed': + if c.status != "closed": return c else: del _global_clients[k] @@ -140,6 +166,7 @@ class Future(WrappedKey): -------- Client: Creates futures """ + _cb_executor = None _cb_executor_pid = None @@ -157,9 +184,13 @@ def __init__(self, key, client=None, inform=True, state=None): self._state = self.client.futures[tkey] = FutureState() if inform: - self.client._send_to_scheduler({'op': 'client-desires-keys', - 'keys': [tokey(key)], - 'client': self.client.id}) + self.client._send_to_scheduler( + { + "op": "client-desires-keys", + "keys": [tokey(key)], + "client": self.client.id, + } + ) if state is not None: try: @@ -191,11 +222,10 @@ def result(self, timeout=None): return self.client.sync(self._result, callback_timeout=timeout) # shorten error traceback - result = self.client.sync(self._result, callback_timeout=timeout, - raiseit=False) - if self.status == 'error': + result = self.client.sync(self._result, callback_timeout=timeout, raiseit=False) + if self.status == "error": six.reraise(*result) - elif self.status == 'cancelled': + elif self.status == "cancelled": raise result else: return result @@ -203,14 +233,13 @@ def result(self, timeout=None): @gen.coroutine def _result(self, raiseit=True): yield self._state.wait() - if self.status == 'error': - exc = clean_exception(self._state.exception, - self._state.traceback) + if self.status == "error": + exc = clean_exception(self._state.exception, self._state.traceback) if raiseit: six.reraise(*exc) else: raise gen.Return(exc) - elif self.status == 'cancelled': + elif self.status == "cancelled": exception = CancelledError(self.key) if raiseit: raise exception @@ -223,7 +252,7 @@ def _result(self, raiseit=True): @gen.coroutine def _exception(self): yield self._state.wait() - if self.status == 'error': + if self.status == "error": raise gen.Return(self._state.exception) else: raise gen.Return(None) @@ -238,8 +267,7 @@ def exception(self, timeout=None, **kwargs): -------- Future.traceback """ - return self.client.sync(self._exception, callback_timeout=timeout, - **kwargs) + return self.client.sync(self._exception, callback_timeout=timeout, **kwargs) def add_done_callback(self, fn): """ Call callback on future when callback has finished @@ -253,7 +281,9 @@ def add_done_callback(self, fn): cls = Future if cls._cb_executor is None or cls._cb_executor_pid != os.getpid(): try: - cls._cb_executor = ThreadPoolExecutor(1, thread_name_prefix="Dask-Callback-Thread") + cls._cb_executor = ThreadPoolExecutor( + 1, thread_name_prefix="Dask-Callback-Thread" + ) except TypeError: cls._cb_executor = ThreadPoolExecutor(1) cls._cb_executor_pid = os.getpid() @@ -264,8 +294,9 @@ def execute_callback(fut): except BaseException: logger.exception("Error in callback %s of %s:", fn, fut) - self.client.loop.add_callback(done_callback, self, - partial(cls._cb_executor.submit, execute_callback)) + self.client.loop.add_callback( + done_callback, self, partial(cls._cb_executor.submit, execute_callback) + ) def cancel(self, **kwargs): """ Cancel request to run this future @@ -287,12 +318,12 @@ def retry(self, **kwargs): def cancelled(self): """ Returns True if the future has been cancelled """ - return self._state.status == 'cancelled' + return self._state.status == "cancelled" @gen.coroutine def _traceback(self): yield self._state.wait() - if self.status == 'error': + if 
self.status == "error": raise gen.Return(self._state.traceback) else: raise gen.Return(None) @@ -318,8 +349,7 @@ def traceback(self, timeout=None, **kwargs): -------- Future.exception """ - return self.client.sync(self._traceback, callback_timeout=timeout, - **kwargs) + return self.client.sync(self._traceback, callback_timeout=timeout, **kwargs) @property def type(self): @@ -342,8 +372,14 @@ def __setstate__(self, state): key, address = state c = get_client(address) Future.__init__(self, key, c) - c._send_to_scheduler({'op': 'update-graph', 'tasks': {}, - 'keys': [tokey(self.key)], 'client': c.id}) + c._send_to_scheduler( + { + "op": "update-graph", + "tasks": {}, + "keys": [tokey(self.key)], + "client": c.id, + } + ) def __del__(self): try: @@ -357,17 +393,23 @@ def __repr__(self): typ = self.type.__name__ except AttributeError: typ = str(self.type) - return '' % (self.status, - typ, self.key) + return "" % ( + self.status, + typ, + self.key, + ) else: - return '' % (self.status, self.key) + return "" % (self.status, self.key) def _repr_html_(self): - text = 'Future: %s ' % html_escape(key_split(self.key)) - text += ('status: ' - '%(status)s, ') % { - 'status': self.status, - 'color': 'red' if self.status == 'error' else 'black'} + text = "Future: %s " % html_escape(key_split(self.key)) + text += ( + 'status: ' + '%(status)s, ' + ) % { + "status": self.status, + "color": "red" if self.status == "error" else "black", + } if self.type: try: typ = self.type.__name__ @@ -386,11 +428,12 @@ class FutureState(object): This is shared between all Futures with the same key and client. """ - __slots__ = ('_event', 'status', 'type', 'exception', 'traceback') + + __slots__ = ("_event", "status", "type", "exception", "traceback") def __init__(self): self._event = None - self.status = 'pending' + self.status = "pending" self.type = None def _get_event(self): @@ -403,27 +446,27 @@ def _get_event(self): return event def cancel(self): - self.status = 'cancelled' + self.status = "cancelled" self._get_event().set() def finish(self, type=None): - self.status = 'finished' + self.status = "finished" self._get_event().set() if type is not None: self.type = type def lose(self): - self.status = 'lost' + self.status = "lost" self._get_event().clear() def retry(self): - self.status = 'pending' + self.status = "pending" self._get_event().clear() def set_error(self, exception, traceback): _, exception, traceback = clean_exception(exception, traceback) - self.status = 'error' + self.status = "error" self.exception = exception self.traceback = traceback self._get_event().set() @@ -432,7 +475,7 @@ def done(self): return self._event is not None and self._event.is_set() def reset(self): - self.status = 'pending' + self.status = "pending" if self._event is not None: self._event.clear() @@ -441,13 +484,13 @@ def wait(self, timeout=None): yield self._get_event().wait(timeout) def __repr__(self): - return '<%s: %s>' % (self.__class__.__name__, self.status) + return "<%s: %s>" % (self.__class__.__name__, self.status) @gen.coroutine def done_callback(future, callback): """ Coroutine that waits on future, then calls callback """ - while future.status == 'pending': + while future.status == "pending": yield future._state.wait() callback(future) @@ -520,27 +563,42 @@ class resembles executors in ``concurrent.futures`` but also allows -------- distributed.scheduler.Scheduler: Internal scheduler """ - def __init__(self, address=None, loop=None, timeout=no_default, - set_as_default=True, scheduler_file=None, - security=None, 
asynchronous=False, - name=None, heartbeat_interval=None, - serializers=None, deserializers=None, - extensions=DEFAULT_EXTENSIONS, direct_to_workers=False, - **kwargs): + + def __init__( + self, + address=None, + loop=None, + timeout=no_default, + set_as_default=True, + scheduler_file=None, + security=None, + asynchronous=False, + name=None, + heartbeat_interval=None, + serializers=None, + deserializers=None, + extensions=DEFAULT_EXTENSIONS, + direct_to_workers=False, + **kwargs + ): if timeout == no_default: - timeout = dask.config.get('distributed.comm.timeouts.connect') + timeout = dask.config.get("distributed.comm.timeouts.connect") if timeout is not None: - timeout = parse_timedelta(timeout, 's') + timeout = parse_timedelta(timeout, "s") self._timeout = timeout self.futures = dict() self.refcount = defaultdict(lambda: 0) self.coroutines = [] if name is None: - name = dask.config.get('client-name', None) - self.id = type(self).__name__ + ('-' + name + '-' if name else '-') + str(uuid.uuid1(clock_seq=os.getpid())) + name = dask.config.get("client-name", None) + self.id = ( + type(self).__name__ + + ("-" + name + "-" if name else "-") + + str(uuid.uuid1(clock_seq=os.getpid())) + ) self.generation = 0 - self.status = 'newly-created' + self.status = "newly-created" self._pending_msg_buffer = [] self.extensions = {} self.scheduler_file = scheduler_file @@ -569,16 +627,15 @@ def __init__(self, address=None, loop=None, timeout=no_default, self.scheduler_comm = None assert isinstance(self.security, Security) - if name == 'worker': - self.connection_args = self.security.get_connection_args('worker') + if name == "worker": + self.connection_args = self.security.get_connection_args("worker") else: - self.connection_args = self.security.get_connection_args('client') + self.connection_args = self.security.get_connection_args("client") if address is None: - address = dask.config.get('scheduler-address', None) + address = dask.config.get("scheduler-address", None) if address: - logger.info("Config value `scheduler-address` found: %s", - address) + logger.info("Config value `scheduler-address` found: %s", address) if isinstance(address, (rpc, PooledRPCCall)): self.scheduler = address @@ -595,44 +652,45 @@ def __init__(self, address=None, loop=None, timeout=no_default, self.loop = self._loop_runner.loop if heartbeat_interval is None: - heartbeat_interval = dask.config.get('distributed.client.heartbeat') - heartbeat_interval = parse_timedelta(heartbeat_interval, default='ms') + heartbeat_interval = dask.config.get("distributed.client.heartbeat") + heartbeat_interval = parse_timedelta(heartbeat_interval, default="ms") self._periodic_callbacks = dict() - self._periodic_callbacks['scheduler-info'] = PeriodicCallback( - self._update_scheduler_info, 2000, io_loop=self.loop + self._periodic_callbacks["scheduler-info"] = PeriodicCallback( + self._update_scheduler_info, 2000, io_loop=self.loop ) - self._periodic_callbacks['heartbeat'] = PeriodicCallback( - self._heartbeat, - heartbeat_interval * 1000, - io_loop=self.loop + self._periodic_callbacks["heartbeat"] = PeriodicCallback( + self._heartbeat, heartbeat_interval * 1000, io_loop=self.loop ) self._start_arg = address if set_as_default: - self._set_config = dask.config.set(scheduler='dask.distributed', - shuffle='tasks') + self._set_config = dask.config.set( + scheduler="dask.distributed", shuffle="tasks" + ) self._stream_handlers = { - 'key-in-memory': self._handle_key_in_memory, - 'lost-data': self._handle_lost_data, - 'cancelled-key': 
self._handle_cancelled_key, - 'task-retried': self._handle_retried_key, - 'task-erred': self._handle_task_erred, - 'restart': self._handle_restart, - 'error': self._handle_error + "key-in-memory": self._handle_key_in_memory, + "lost-data": self._handle_lost_data, + "cancelled-key": self._handle_cancelled_key, + "task-retried": self._handle_retried_key, + "task-erred": self._handle_task_erred, + "restart": self._handle_restart, + "error": self._handle_error, } self._state_handlers = { - 'memory': self._handle_key_in_memory, - 'lost': self._handle_lost_data, - 'erred': self._handle_task_erred + "memory": self._handle_key_in_memory, + "lost": self._handle_lost_data, + "erred": self._handle_task_erred, } - super(Client, self).__init__(connection_args=self.connection_args, - io_loop=self.loop, - serializers=serializers, - deserializers=deserializers) + super(Client, self).__init__( + connection_args=self.connection_args, + io_loop=self.loop, + serializers=serializers, + deserializers=deserializers, + ) for ext in extensions: ext(self) @@ -640,6 +698,7 @@ def __init__(self, address=None, loop=None, timeout=no_default, self.start(timeout=timeout) from distributed.recreate_exceptions import ReplayExceptionClient + ReplayExceptionClient(self) @classmethod @@ -665,13 +724,16 @@ def asynchronous(self): return self._asynchronous and self.loop is IOLoop.current() def sync(self, func, *args, **kwargs): - asynchronous = kwargs.pop('asynchronous', None) - if asynchronous or self.asynchronous or getattr(thread_state, 'asynchronous', False): - callback_timeout = kwargs.pop('callback_timeout', None) + asynchronous = kwargs.pop("asynchronous", None) + if ( + asynchronous + or self.asynchronous + or getattr(thread_state, "asynchronous", False) + ): + callback_timeout = kwargs.pop("callback_timeout", None) future = func(*args, **kwargs) if callback_timeout is not None: - future = gen.with_timeout(timedelta(seconds=callback_timeout), - future) + future = gen.with_timeout(timedelta(seconds=callback_timeout), future) return future else: return sync(self.loop, func, *args, **kwargs) @@ -679,26 +741,38 @@ def sync(self, func, *args, **kwargs): def __repr__(self): # Note: avoid doing I/O here... 
info = self._scheduler_identity - addr = info.get('address') + addr = info.get("address") if addr: - workers = info.get('workers', {}) + workers = info.get("workers", {}) nworkers = len(workers) - ncores = sum(w['ncores'] for w in workers.values()) - return '<%s: scheduler=%r processes=%d cores=%d>' % ( - self.__class__.__name__, addr, nworkers, ncores) + ncores = sum(w["ncores"] for w in workers.values()) + return "<%s: scheduler=%r processes=%d cores=%d>" % ( + self.__class__.__name__, + addr, + nworkers, + ncores, + ) elif self.scheduler is not None: - return '<%s: scheduler=%r>' % ( - self.__class__.__name__, self.scheduler.address) + return "<%s: scheduler=%r>" % ( + self.__class__.__name__, + self.scheduler.address, + ) else: - return '<%s: not connected>' % (self.__class__.__name__,) + return "<%s: not connected>" % (self.__class__.__name__,) def _repr_html_(self): - if self.cluster and hasattr(self.cluster, 'scheduler') and self.cluster.scheduler: + if ( + self.cluster + and hasattr(self.cluster, "scheduler") + and self.cluster.scheduler + ): info = self.cluster.scheduler.identity() scheduler = self.cluster.scheduler - elif (self._loop_runner.is_started() and - self.scheduler and - not (self.asynchronous and self.loop is IOLoop.current())): + elif ( + self._loop_runner.is_started() + and self.scheduler + and not (self.asynchronous and self.loop is IOLoop.current()) + ): info = sync(self.loop, self.scheduler.identity) scheduler = self.scheduler else: @@ -706,56 +780,63 @@ def _repr_html_(self): scheduler = self.scheduler if scheduler is not None: - text = ("

<h3>Client</h3>\n" - "<ul>\n" - "  <li><b>Scheduler: </b>%s\n") % scheduler.address + text = ( + "<h3>Client</h3>\n" "<ul>\n" "  <li><b>Scheduler: </b>%s\n" + ) % scheduler.address else: - text = ("<h3>Client</h3>\n" - "<ul>\n" - "  <li><b>Scheduler: not connected</b>\n") - if info and 'bokeh' in info['services']: - protocol, rest = scheduler.address.split('://') - port = info['services']['bokeh'] - if protocol == 'inproc': - host = 'localhost' + text = ( + "<h3>Client</h3>\n" "<ul>\n" "  <li><b>Scheduler: not connected</b>\n" + ) + if info and "bokeh" in info["services"]: + protocol, rest = scheduler.address.split("://") + port = info["services"]["bokeh"] + if protocol == "inproc": + host = "localhost" else: - host = rest.split(':')[0] - template = dask.config.get('distributed.dashboard.link') + host = rest.split(":")[0] + template = dask.config.get("distributed.dashboard.link") address = template.format(host=host, port=port, **os.environ) - text += "  <li><b>Dashboard: </b><a href='%(web)s' target='_blank'>%(web)s</a>\n" % {'web': address} + text += ( + "  <li><b>Dashboard: </b><a href='%(web)s' target='_blank'>%(web)s</a>\n" + % {"web": address} + ) text += "</ul>\n" if info: - workers = len(info['workers']) - cores = sum(w['ncores'] for w in info['workers'].values()) - memory = sum(w['memory_limit'] for w in info['workers'].values()) + workers = len(info["workers"]) + cores = sum(w["ncores"] for w in info["workers"].values()) + memory = sum(w["memory_limit"] for w in info["workers"].values()) memory = format_bytes(memory) - text2 = ("<h3>Cluster</h3>\n" - "<ul>\n" - "  <li><b>Workers: </b>%d</li>\n" - "  <li><b>Cores: </b>%d</li>\n" - "  <li><b>Memory: </b>%s</li>\n" - "</ul>\n") % (workers, cores, memory) - - return ('<table style="border: 2px solid white;">\n' - '<tr>\n' - '<td style="vertical-align: top; border: 0px solid white">\n%s</td>\n' - '<td style="vertical-align: top; border: 0px solid white">\n%s</td>\n' - '</tr>\n</table>') % (text, text2) + text2 = ( + "<h3>Cluster</h3>\n" + "<ul>\n" + "  <li><b>Workers: </b>%d</li>\n" + "  <li><b>Cores: </b>%d</li>\n" + "  <li><b>Memory: </b>%s</li>\n" + "</ul>\n" + ) % (workers, cores, memory) + + return ( + '<table style="border: 2px solid white;">\n' + "<tr>\n" + '<td style="vertical-align: top; border: 0px solid white">\n%s</td>\n' + '<td style="vertical-align: top; border: 0px solid white">\n%s</td>\n' + "</tr>\n</table>
        " + ) % (text, text2) else: return text def start(self, **kwargs): """ Start scheduler running in separate thread """ - if self.status != 'newly-created': + if self.status != "newly-created": return self._loop_runner.start() _set_global_client(self) - self.status = 'connecting' + self.status = "connecting" if self.asynchronous: self._started = self._start(**kwargs) @@ -763,37 +844,41 @@ def start(self, **kwargs): sync(self.loop, self._start, **kwargs) def __await__(self): - if hasattr(self, '_started'): + if hasattr(self, "_started"): return self._started.__await__() else: + @gen.coroutine def _(): raise gen.Return(self) + return _().__await__() def _send_to_scheduler_safe(self, msg): - if self.status in ('running', 'closing'): + if self.status in ("running", "closing"): try: self.scheduler_comm.send(msg) except (CommClosedError, AttributeError): - if self.status == 'running': + if self.status == "running": raise - elif self.status in ('connecting', 'newly-created'): + elif self.status in ("connecting", "newly-created"): self._pending_msg_buffer.append(msg) def _send_to_scheduler(self, msg): - if self.status in ('running', 'closing', 'connecting', 'newly-created'): + if self.status in ("running", "closing", "connecting", "newly-created"): self.loop.add_callback(self._send_to_scheduler_safe, msg) else: - raise Exception("Tried sending message after closing. Status: %s\n" - "Message: %s" % (self.status, msg)) + raise Exception( + "Tried sending message after closing. Status: %s\n" + "Message: %s" % (self.status, msg) + ) @gen.coroutine def _start(self, timeout=no_default, **kwargs): if timeout == no_default: timeout = self._timeout if timeout is not None: - timeout = parse_timedelta(timeout, 's') + timeout = parse_timedelta(timeout, "s") address = self._start_arg if self.cluster is not None: @@ -803,8 +888,10 @@ def _start(self, timeout=no_default, **kwargs): except AttributeError: # Some clusters don't have this method pass except Exception: - logger.info("Tried to start cluster and received an error. " - "Proceeding.", exc_info=True) + logger.info( + "Tried to start cluster and received an error. 
" "Proceeding.", + exc_info=True, + ) address = self.cluster.scheduler_address elif self.scheduler_file is not None: while not os.path.exists(self.scheduler_file): @@ -813,7 +900,7 @@ def _start(self, timeout=no_default, **kwargs): try: with open(self.scheduler_file) as f: cfg = json.load(f) - address = cfg['address'] + address = cfg["address"] break except (ValueError, KeyError): # JSON file not yet flushed yield gen.sleep(0.01) @@ -821,31 +908,39 @@ def _start(self, timeout=no_default, **kwargs): from .deploy import LocalCluster try: - self.cluster = LocalCluster(loop=self.loop, asynchronous=True, - **self._startup_kwargs) + self.cluster = LocalCluster( + loop=self.loop, asynchronous=True, **self._startup_kwargs + ) yield self.cluster except (OSError, socket.error) as e: if e.errno != errno.EADDRINUSE: raise # The default port was taken, use a random one - self.cluster = LocalCluster(scheduler_port=0, loop=self.loop, - asynchronous=True, - **self._startup_kwargs) + self.cluster = LocalCluster( + scheduler_port=0, + loop=self.loop, + asynchronous=True, + **self._startup_kwargs + ) yield self.cluster # Wait for all workers to be ready # XXX should be a LocalCluster method instead - while (not self.cluster.workers or - len(self.cluster.scheduler.workers) < len(self.cluster.workers)): + while not self.cluster.workers or len(self.cluster.scheduler.workers) < len( + self.cluster.workers + ): yield gen.sleep(0.01) address = self.cluster.scheduler_address if self.scheduler is None: - self.scheduler = rpc(address, timeout=timeout, - connection_args=self.connection_args, - serializers=self._serializers, - deserializers=self._deserializers) + self.scheduler = rpc( + address, + timeout=timeout, + connection_args=self.connection_args, + serializers=self._serializers, + deserializers=self._deserializers, + ) self.scheduler_comm = None yield self._ensure_connected(timeout=timeout) @@ -862,14 +957,14 @@ def _start(self, timeout=no_default, **kwargs): def _reconnect(self, timeout=0.1): with log_errors(): assert self.scheduler_comm.comm.closed() - self.status = 'connecting' + self.status = "connecting" self.scheduler_comm = None for st in self.futures.values(): st.cancel() self.futures.clear() - while self.status == 'connecting': + while self.status == "connecting": try: yield self._ensure_connected() break @@ -878,39 +973,46 @@ def _reconnect(self, timeout=0.1): @gen.coroutine def _ensure_connected(self, timeout=None): - if (self.scheduler_comm and not self.scheduler_comm.closed() or - self._connecting_to_scheduler or self.scheduler is None): + if ( + self.scheduler_comm + and not self.scheduler_comm.closed() + or self._connecting_to_scheduler + or self.scheduler is None + ): return self._connecting_to_scheduler = True try: - comm = yield connect(self.scheduler.address, timeout=timeout, - connection_args=self.connection_args) + comm = yield connect( + self.scheduler.address, + timeout=timeout, + connection_args=self.connection_args, + ) if timeout is not None: - yield gen.with_timeout(timedelta(seconds=timeout), - self._update_scheduler_info()) + yield gen.with_timeout( + timedelta(seconds=timeout), self._update_scheduler_info() + ) else: yield self._update_scheduler_info() - yield comm.write({'op': 'register-client', - 'client': self.id, - 'reply': False}) + yield comm.write( + {"op": "register-client", "client": self.id, "reply": False} + ) finally: self._connecting_to_scheduler = False if timeout is not None: - msg = yield gen.with_timeout(timedelta(seconds=timeout), - comm.read()) + msg = yield 
gen.with_timeout(timedelta(seconds=timeout), comm.read()) else: msg = yield comm.read() assert len(msg) == 1 - assert msg[0]['op'] == 'stream-start' + assert msg[0]["op"] == "stream-start" - bcomm = BatchedSend(interval='10ms', loop=self.loop) + bcomm = BatchedSend(interval="10ms", loop=self.loop) bcomm.start(comm) self.scheduler_comm = bcomm _set_global_client(self) - self.status = 'running' + self.status = "running" for msg in self._pending_msg_buffer: self._send_to_scheduler(msg) @@ -920,7 +1022,7 @@ def _ensure_connected(self, timeout=None): @gen.coroutine def _update_scheduler_info(self): - if self.status not in ('running', 'connecting'): + if self.status not in ("running", "connecting"): return try: self._scheduler_identity = yield self.scheduler.identity() @@ -929,7 +1031,7 @@ def _update_scheduler_info(self): def _heartbeat(self): if self.scheduler_comm: - self.scheduler_comm.send({'op': 'heartbeat-client'}) + self.scheduler_comm.send({"op": "heartbeat-client"}) def __enter__(self): if not self._loop_runner.is_started(): @@ -968,10 +1070,10 @@ def _release_key(self, key): st = self.futures.pop(key, None) if st is not None: st.cancel() - if self.status != 'closed': - self._send_to_scheduler({'op': 'client-releases-keys', - 'keys': [key], - 'client': self.id}) + if self.status != "closed": + self._send_to_scheduler( + {"op": "client-releases-keys", "keys": [key], "client": self.id} + ) @gen.coroutine def _handle_report(self): @@ -984,10 +1086,10 @@ def _handle_report(self): try: msgs = yield self.scheduler_comm.comm.read() except CommClosedError: - if self.status == 'running': + if self.status == "running": logger.info("Client report stream closed to scheduler") logger.info("Reconnecting...") - self.status = 'connecting' + self.status = "connecting" yield self._reconnect() continue else: @@ -999,12 +1101,12 @@ def _handle_report(self): for msg in msgs: logger.debug("Client receives message %s", msg) - if 'status' in msg and 'error' in msg['status']: + if "status" in msg and "error" in msg["status"]: six.reraise(*clean_exception(**msg)) - op = msg.pop('op') + op = msg.pop("op") - if op == 'close' or op == 'stream-closed': + if op == "close" or op == "stream-closed": breakout = True break @@ -1067,7 +1169,7 @@ def _handle_error(self, exception=None): @gen.coroutine def _close(self, fast=False): """ Send close signal and wait until scheduler completes """ - self.status = 'closing' + self.status = "closing" with log_errors(): _del_global_client(self) @@ -1078,30 +1180,40 @@ def _close(self, fast=False): # clear the dask.config set keys with self._set_config: pass - if self.get == dask.config.get('get', None): - del dask.config.config['get'] - if self.status == 'closed': + if self.get == dask.config.get("get", None): + del dask.config.config["get"] + if self.status == "closed": raise gen.Return() - if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed(): - self._send_to_scheduler({'op': 'close-client'}) - self._send_to_scheduler({'op': 'close-stream'}) + if ( + self.scheduler_comm + and self.scheduler_comm.comm + and not self.scheduler_comm.comm.closed() + ): + self._send_to_scheduler({"op": "close-client"}) + self._send_to_scheduler({"op": "close-stream"}) # Give the scheduler 'stream-closed' message 100ms to come through # This makes the shutdown slightly smoother and quieter with ignoring(AttributeError, gen.TimeoutError): - yield gen.with_timeout(timedelta(milliseconds=100), - self._handle_scheduler_coroutine, - quiet_exceptions=(CancelledError,)) 
- - if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed(): + yield gen.with_timeout( + timedelta(milliseconds=100), + self._handle_scheduler_coroutine, + quiet_exceptions=(CancelledError,), + ) + + if ( + self.scheduler_comm + and self.scheduler_comm.comm + and not self.scheduler_comm.comm.closed() + ): yield self.scheduler_comm.close() for key in list(self.futures): self._release_key(key=key) if self._start_arg is None: with ignoring(AttributeError): yield self.cluster._close() - self.status = 'closed' + self.status = "closed" if _get_global_client() is self: _set_global_client(None) coroutines = set(self.coroutines) @@ -1115,13 +1227,12 @@ def _close(self, fast=False): del self.coroutines[:] if not fast: with ignoring(TimeoutError): - yield gen.with_timeout(timedelta(seconds=2), - list(coroutines)) + yield gen.with_timeout(timedelta(seconds=2), list(coroutines)) with ignoring(AttributeError): self.scheduler.close_rpc() self.scheduler = None - self.status = 'closed' + self.status = "closed" _shutdown = _close @@ -1140,9 +1251,9 @@ def close(self, timeout=no_default): if timeout == no_default: timeout = self._timeout * 2 # XXX handling of self.status here is not thread-safe - if self.status == 'closed': + if self.status == "closed": return - self.status = 'closing' + self.status = "closing" if self.asynchronous: future = self._close() @@ -1156,7 +1267,7 @@ def close(self, timeout=no_default): sync(self.loop, self._close, fast=True) - assert self.status == 'closed' + assert self.status == "closed" if self._should_close_loop and not shutting_down(): self._loop_runner.stop() @@ -1229,24 +1340,24 @@ def submit(self, func, *args, **kwargs): if not callable(func): raise TypeError("First input to submit must be a callable function") - key = kwargs.pop('key', None) - workers = kwargs.pop('workers', None) - resources = kwargs.pop('resources', None) - retries = kwargs.pop('retries', None) - priority = kwargs.pop('priority', 0) - fifo_timeout = kwargs.pop('fifo_timeout', '100ms') - allow_other_workers = kwargs.pop('allow_other_workers', False) - actor = kwargs.pop('actor', kwargs.pop('actors', False)) - pure = kwargs.pop('pure', not actor) + key = kwargs.pop("key", None) + workers = kwargs.pop("workers", None) + resources = kwargs.pop("resources", None) + retries = kwargs.pop("retries", None) + priority = kwargs.pop("priority", 0) + fifo_timeout = kwargs.pop("fifo_timeout", "100ms") + allow_other_workers = kwargs.pop("allow_other_workers", False) + actor = kwargs.pop("actor", kwargs.pop("actors", False)) + pure = kwargs.pop("pure", not actor) if allow_other_workers not in (True, False, None): raise TypeError("allow_other_workers= must be True or False") if key is None: if pure: - key = funcname(func) + '-' + tokenize(func, kwargs, *args) + key = funcname(func) + "-" + tokenize(func, kwargs, *args) else: - key = funcname(func) + '-' + str(uuid.uuid4()) + key = funcname(func) + "-" + str(uuid.uuid4()) skey = tokey(key) @@ -1271,13 +1382,18 @@ def submit(self, func, *args, **kwargs): else: dsk = {skey: (func,) + tuple(args)} - futures = self._graph_to_futures(dsk, [skey], restrictions, - loose_restrictions, priority={skey: 0}, - user_priority=priority, - resources={skey: resources} if resources else None, - retries=retries, - fifo_timeout=fifo_timeout, - actors=actor) + futures = self._graph_to_futures( + dsk, + [skey], + restrictions, + loose_restrictions, + priority={skey: 0}, + user_priority=priority, + resources={skey: resources} if resources else None, + 
retries=retries, + fifo_timeout=fifo_timeout, + actors=actor, + ) logger.debug("Submit %s(...), %s", funcname(func), key) @@ -1345,14 +1461,17 @@ def map(self, func, *iterables, **kwargs): if not callable(func): raise TypeError("First input to map must be a callable function") - if (all(map(isqueue, iterables)) or - all(isinstance(i, Iterator) for i in iterables)): - maxsize = kwargs.pop('maxsize', 0) + if all(map(isqueue, iterables)) or all( + isinstance(i, Iterator) for i in iterables + ): + maxsize = kwargs.pop("maxsize", 0) q_out = pyQueue(maxsize=maxsize) - t = threading.Thread(target=self._threaded_map, - name="Threaded map()", - args=(q_out, func, iterables), - kwargs=kwargs) + t = threading.Thread( + target=self._threaded_map, + name="Threaded map()", + args=(q_out, func, iterables), + kwargs=kwargs, + ) t.daemon = True t.start() if isqueue(iterables[0]): @@ -1360,16 +1479,16 @@ def map(self, func, *iterables, **kwargs): else: return queue_to_iterator(q_out) - key = kwargs.pop('key', None) + key = kwargs.pop("key", None) key = key or funcname(func) - workers = kwargs.pop('workers', None) - retries = kwargs.pop('retries', None) - resources = kwargs.pop('resources', None) - user_priority = kwargs.pop('priority', 0) - allow_other_workers = kwargs.pop('allow_other_workers', False) - fifo_timeout = kwargs.pop('fifo_timeout', '100ms') - actor = kwargs.pop('actor', kwargs.pop('actors', False)) - pure = kwargs.pop('pure', not actor) + workers = kwargs.pop("workers", None) + retries = kwargs.pop("retries", None) + resources = kwargs.pop("resources", None) + user_priority = kwargs.pop("priority", 0) + allow_other_workers = kwargs.pop("allow_other_workers", False) + fifo_timeout = kwargs.pop("fifo_timeout", "100ms") + actor = kwargs.pop("actor", kwargs.pop("actors", False)) + pure = kwargs.pop("pure", not actor) if allow_other_workers and workers is None: raise ValueError("Only use allow_other_workers= if using workers=") @@ -1379,16 +1498,23 @@ def map(self, func, *iterables, **kwargs): keys = key else: if pure: - keys = [key + '-' + tokenize(func, kwargs, *args) - for args in zip(*iterables)] + keys = [ + key + "-" + tokenize(func, kwargs, *args) + for args in zip(*iterables) + ] else: uid = str(uuid.uuid4()) - keys = [key + '-' + uid + '-' + str(i) - for i in range(min(map(len, iterables)))] if iterables else [] + keys = ( + [ + key + "-" + uid + "-" + str(i) + for i in range(min(map(len, iterables))) + ] + if iterables + else [] + ) if not kwargs: - dsk = {key: (func,) + args - for key, args in zip(keys, zip(*iterables))} + dsk = {key: (func,) + args for key, args in zip(keys, zip(*iterables))} else: kwargs2 = {} dsk = {} @@ -1399,16 +1525,22 @@ def map(self, func, *iterables, **kwargs): dsk.update(vv.dask) else: kwargs2[k] = v - dsk.update({key: (apply, func, (tuple, list(args)), kwargs2) - for key, args in zip(keys, zip(*iterables))}) + dsk.update( + { + key: (apply, func, (tuple, list(args)), kwargs2) + for key, args in zip(keys, zip(*iterables)) + } + ) if isinstance(workers, six.string_types + (Number,)): workers = [workers] if isinstance(workers, (list, set)): if workers and isinstance(first(workers), (list, set)): if len(workers) != len(keys): - raise ValueError("You only provided %d worker restrictions" - " for a sequence of length %d" % (len(workers), len(keys))) + raise ValueError( + "You only provided %d worker restrictions" + " for a sequence of length %d" % (len(workers), len(keys)) + ) restrictions = dict(zip(keys, workers)) else: restrictions = {k: workers for k in keys} @@ 
-1430,20 +1562,24 @@ def map(self, func, *iterables, **kwargs): else: resources = None - futures = self._graph_to_futures(dsk, keys, restrictions, - loose_restrictions, - priority=priority, - resources=resources, - retries=retries, - user_priority=user_priority, - fifo_timeout=fifo_timeout, - actors=actor) + futures = self._graph_to_futures( + dsk, + keys, + restrictions, + loose_restrictions, + priority=priority, + resources=resources, + retries=retries, + user_priority=user_priority, + fifo_timeout=fifo_timeout, + actors=actor, + ) logger.debug("map(%s, ...)", funcname(func)) return [futures[tokey(k)] for k in keys] @gen.coroutine - def _gather(self, futures, errors='raise', direct=None, local_worker=None): + def _gather(self, futures, errors="raise", direct=None, local_worker=None): unpacked, future_set = unpack_remotedata(futures, byte_keys=True) keys = [tokey(future.key) for future in future_set] bad_data = dict() @@ -1465,38 +1601,35 @@ def wait(k): """ Want to stop the All(...) early if we find an error """ st = self.futures[k] yield st.wait() - if st.status != 'finished' and errors == 'raise' : + if st.status != "finished" and errors == "raise": raise AllExit() while True: logger.debug("Waiting on futures to clear before gather") with ignoring(AllExit): - yield All([wait(key) for key in keys if key in self.futures], - quiet_exceptions=AllExit) + yield All( + [wait(key) for key in keys if key in self.futures], + quiet_exceptions=AllExit, + ) - failed = ('error', 'cancelled') + failed = ("error", "cancelled") exceptions = set() bad_keys = set() for key in keys: - if (key not in self.futures or - self.futures[key].status in failed): + if key not in self.futures or self.futures[key].status in failed: exceptions.add(key) - if errors == 'raise': + if errors == "raise": try: st = self.futures[key] exception = st.exception traceback = st.traceback except (AttributeError, KeyError): - six.reraise(CancelledError, - CancelledError(key), - None) + six.reraise(CancelledError, CancelledError(key), None) else: - six.reraise(type(exception), - exception, - traceback) - if errors == 'skip': + six.reraise(type(exception), exception, traceback) + if errors == "skip": bad_keys.add(key) bad_data[key] = None else: @@ -1505,16 +1638,16 @@ def wait(k): keys = [k for k in keys if k not in bad_keys and k not in data] if local_worker: # look inside local worker - data.update({k: local_worker.data[k] - for k in keys - if k in local_worker.data}) + data.update( + {k: local_worker.data[k] for k in keys if k in local_worker.data} + ) keys = [k for k in keys if k not in data] # We now do an actual remote communication with workers or scheduler if self._gather_future: # attach onto another pending gather request self._gather_keys |= set(keys) response = yield self._gather_future - else: # no one waiting, go ahead + else: # no one waiting, go ahead self._gather_keys = set(keys) future = self._gather_remote(direct, local_worker) if self._gather_keys is None: @@ -1523,13 +1656,16 @@ def wait(k): self._gather_future = future response = yield future - if response['status'] == 'error': - log = logger.warning if errors == 'raise' else logger.debug - log("Couldn't gather %s keys, rescheduling %s", len(response['keys']), response['keys']) - for key in response['keys']: - self._send_to_scheduler({'op': 'report-key', - 'key': key}) - for key in response['keys']: + if response["status"] == "error": + log = logger.warning if errors == "raise" else logger.debug + log( + "Couldn't gather %s keys, rescheduling %s", + 
len(response["keys"]), + response["keys"], + ) + for key in response["keys"]: + self._send_to_scheduler({"op": "report-key", "key": key}) + for key in response["keys"]: try: self.futures[key].reset() except KeyError: # TODO: verify that this is safe @@ -1537,10 +1673,10 @@ def wait(k): else: break - if bad_data and errors == 'skip' and isinstance(unpacked, list): + if bad_data and errors == "skip" and isinstance(unpacked, list): unpacked = [f for f in unpacked if f not in bad_data] - data.update(response['data']) + data.update(response["data"]) result = pack_data(unpacked, merge(data, bad_data)) raise gen.Return(result) @@ -1561,13 +1697,14 @@ def _gather_remote(self, direct, local_worker): if direct or local_worker: # gather directly from workers who_has = yield self.scheduler.who_has(keys=keys) data2, missing_keys, missing_workers = yield gather_from_workers( - who_has, rpc=self.rpc, close=False) - response = {'status': 'OK', 'data': data2} + who_has, rpc=self.rpc, close=False + ) + response = {"status": "OK", "data": data2} if missing_keys: keys2 = [key for key in keys if key not in data2] response = yield self.scheduler.gather(keys=keys2) - if response['status'] == 'OK': - response['data'].update(data2) + if response["status"] == "OK": + response["data"].update(data2) else: # ask scheduler to gather data for us response = yield self.scheduler.gather(keys=keys) @@ -1589,8 +1726,9 @@ def _threaded_gather(self, qin, qout, **kwargs): for item in results: qout.put(item) - def gather(self, futures, errors='raise', maxsize=0, direct=None, - asynchronous=None): + def gather( + self, futures, errors="raise", maxsize=0, direct=None, asynchronous=None + ): """ Gather futures from distributed memory Accepts a future, nested container of futures, iterator, or queue. 
@@ -1637,34 +1775,49 @@ def gather(self, futures, errors='raise', maxsize=0, direct=None, """ if isqueue(futures): qout = pyQueue(maxsize=maxsize) - t = threading.Thread(target=self._threaded_gather, - name="Threaded gather()", - args=(futures, qout), - kwargs={'errors': errors, 'direct': direct}) + t = threading.Thread( + target=self._threaded_gather, + name="Threaded gather()", + args=(futures, qout), + kwargs={"errors": errors, "direct": direct}, + ) t.daemon = True t.start() return qout elif isinstance(futures, Iterator): - return (self.gather(f, errors=errors, direct=direct) - for f in futures) + return (self.gather(f, errors=errors, direct=direct) for f in futures) else: - if hasattr(thread_state, 'execution_state'): # within worker task - local_worker = thread_state.execution_state['worker'] + if hasattr(thread_state, "execution_state"): # within worker task + local_worker = thread_state.execution_state["worker"] else: local_worker = None - return self.sync(self._gather, futures, errors=errors, - direct=direct, local_worker=local_worker, - asynchronous=asynchronous) + return self.sync( + self._gather, + futures, + errors=errors, + direct=direct, + local_worker=local_worker, + asynchronous=asynchronous, + ) @gen.coroutine - def _scatter(self, data, workers=None, broadcast=False, direct=None, - local_worker=None, timeout=no_default, hash=True): + def _scatter( + self, + data, + workers=None, + broadcast=False, + direct=None, + local_worker=None, + timeout=no_default, + hash=True, + ): if timeout == no_default: timeout = self._timeout if isinstance(workers, six.string_types + (Number,)): workers = [workers] - if isinstance(data, dict) and not all(isinstance(k, (bytes, unicode)) - for k in data): + if isinstance(data, dict) and not all( + isinstance(k, (bytes, unicode)) for k in data + ): d = yield self._scatter(keymap(tokey, data), workers, broadcast) raise gen.Return({k: d[tokey(k)] for k in data}) @@ -1682,9 +1835,9 @@ def _scatter(self, data, workers=None, broadcast=False, direct=None, data = [data] if isinstance(data, (list, tuple)): if hash: - names = [type(x).__name__ + '-' + tokenize(x) for x in data] + names = [type(x).__name__ + "-" + tokenize(x) for x in data] else: - names = [type(x).__name__ + '-' + uuid.uuid4().hex for x in data] + names = [type(x).__name__ + "-" + uuid.uuid4().hex for x in data] data = dict(zip(names, data)) assert isinstance(data, dict) @@ -1708,7 +1861,8 @@ def _scatter(self, data, workers=None, broadcast=False, direct=None, yield self.scheduler.update_data( who_has={key: [local_worker.address] for key in data}, nbytes=valmap(sizeof, data), - client=self.id) + client=self.id, + ) else: data2 = valmap(to_serialize, data) @@ -1724,18 +1878,21 @@ def _scatter(self, data, workers=None, broadcast=False, direct=None, if not ncores: raise ValueError("No valid workers") - _, who_has, nbytes = yield scatter_to_workers(ncores, data2, - report=False, - rpc=self.rpc) + _, who_has, nbytes = yield scatter_to_workers( + ncores, data2, report=False, rpc=self.rpc + ) - yield self.scheduler.update_data(who_has=who_has, - nbytes=nbytes, - client=self.id) + yield self.scheduler.update_data( + who_has=who_has, nbytes=nbytes, client=self.id + ) else: - yield self.scheduler.scatter(data=data2, workers=workers, - client=self.id, - broadcast=broadcast, - timeout=timeout) + yield self.scheduler.scatter( + data=data2, + workers=workers, + client=self.id, + broadcast=broadcast, + timeout=timeout, + ) out = {k: Future(k, self, inform=False) for k in data} for key, typ in 
types.items(): @@ -1774,8 +1931,17 @@ def _threaded_scatter(self, q_or_i, qout, **kwargs): for future in futures: qout.put(future) - def scatter(self, data, workers=None, broadcast=False, direct=None, - hash=True, maxsize=0, timeout=no_default, asynchronous=None): + def scatter( + self, + data, + workers=None, + broadcast=False, + direct=None, + hash=True, + maxsize=0, + timeout=no_default, + asynchronous=None, + ): """ Scatter data into distributed memory This moves data from the local client process into the workers of the @@ -1853,11 +2019,12 @@ def scatter(self, data, workers=None, broadcast=False, direct=None, logger.debug("Starting thread for streaming data") qout = pyQueue(maxsize=maxsize) - t = threading.Thread(target=self._threaded_scatter, - name="Threaded scatter()", - args=(data, qout), - kwargs={'workers': workers, - 'broadcast': broadcast}) + t = threading.Thread( + target=self._threaded_scatter, + name="Threaded scatter()", + args=(data, qout), + kwargs={"workers": workers, "broadcast": broadcast}, + ) t.daemon = True t.start() @@ -1866,14 +2033,21 @@ def scatter(self, data, workers=None, broadcast=False, direct=None, else: return queue_to_iterator(qout) else: - if hasattr(thread_state, 'execution_state'): # within worker task - local_worker = thread_state.execution_state['worker'] + if hasattr(thread_state, "execution_state"): # within worker task + local_worker = thread_state.execution_state["worker"] else: local_worker = None - return self.sync(self._scatter, data, workers=workers, - broadcast=broadcast, direct=direct, - local_worker=local_worker, timeout=timeout, - asynchronous=asynchronous, hash=hash) + return self.sync( + self._scatter, + data, + workers=workers, + broadcast=broadcast, + direct=direct, + local_worker=local_worker, + timeout=timeout, + asynchronous=asynchronous, + hash=hash, + ) @gen.coroutine def _cancel(self, futures, force=False): @@ -1898,8 +2072,7 @@ def cancel(self, futures, asynchronous=None, force=False): force: boolean (False) Cancel this future even if other clients desire it """ - return self.sync(self._cancel, futures, asynchronous=asynchronous, - force=force) + return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force) @gen.coroutine def _retry(self, futures): @@ -1926,17 +2099,19 @@ def _publish_dataset(self, *args, **kwargs): def add_coro(name, data): keys = [tokey(f.key) for f in futures_of(data)] - coroutines.append(self.scheduler.publish_put(keys=keys, - name=name, - data=to_serialize(data), - client=self.id)) + coroutines.append( + self.scheduler.publish_put( + keys=keys, name=name, data=to_serialize(data), client=self.id + ) + ) - name = kwargs.pop('name', None) + name = kwargs.pop("name", None) if name: if len(args) == 0: raise ValueError( "If name is provided, expecting call signature like" - " publish_dataset(df, name='ds')") + " publish_dataset(df, name='ds')" + ) # in case this is a singleton, collapse it elif len(args) == 1: args = args[0] @@ -2031,7 +2206,7 @@ def _get_dataset(self, name): raise KeyError("Dataset '%s' not found" % name) with temp_default_client(self): - data = out['data'] + data = out["data"] raise gen.Return(data) def get_dataset(self, name, **kwargs): @@ -2047,15 +2222,14 @@ def get_dataset(self, name, **kwargs): @gen.coroutine def _run_on_scheduler(self, function, *args, **kwargs): - wait = kwargs.pop('wait', True) - response = yield self.scheduler.run_function(function=dumps(function), - args=dumps(args), - kwargs=dumps(kwargs), - wait=wait) - if response['status'] == 'error': + wait = 
kwargs.pop("wait", True) + response = yield self.scheduler.run_function( + function=dumps(function), args=dumps(args), kwargs=dumps(kwargs), wait=wait + ) + if response["status"] == "error": six.reraise(*clean_exception(**response)) else: - raise gen.Return(response['result']) + raise gen.Return(response["result"]) def run_on_scheduler(self, function, *args, **kwargs): """ Run a function on the scheduler process @@ -2087,25 +2261,29 @@ def run_on_scheduler(self, function, *args, **kwargs): Client.run: Run a function on all workers Client.start_ipython_scheduler: Start an IPython session on scheduler """ - return self.sync(self._run_on_scheduler, function, *args, - **kwargs) + return self.sync(self._run_on_scheduler, function, *args, **kwargs) @gen.coroutine def _run(self, function, *args, **kwargs): - nanny = kwargs.pop('nanny', False) - workers = kwargs.pop('workers', None) - wait = kwargs.pop('wait', True) - responses = yield self.scheduler.broadcast(msg=dict(op='run', - function=dumps(function), - args=dumps(args), - wait=wait, - kwargs=dumps(kwargs)), - workers=workers, nanny=nanny) + nanny = kwargs.pop("nanny", False) + workers = kwargs.pop("workers", None) + wait = kwargs.pop("wait", True) + responses = yield self.scheduler.broadcast( + msg=dict( + op="run", + function=dumps(function), + args=dumps(args), + wait=wait, + kwargs=dumps(kwargs), + ), + workers=workers, + nanny=nanny, + ) results = {} for key, resp in responses.items(): - if resp['status'] == 'OK': - results[key] = resp['result'] - elif resp['status'] == 'error': + if resp["status"] == "OK": + results[key] = resp["result"] + elif resp["status"] == "error": six.reraise(*clean_exception(**resp)) if wait: raise gen.Return(results) @@ -2188,23 +2366,36 @@ def run_coroutine(self, function, *args, **kwargs): Workers on which to run the function. Defaults to all known workers. """ - warnings.warn("This method has been deprecated. " - "Instead use Client.run which detects async functions " - "automatically") + warnings.warn( + "This method has been deprecated. 
" + "Instead use Client.run which detects async functions " + "automatically" + ) return self.run(function, *args, **kwargs) - def _graph_to_futures(self, dsk, keys, restrictions=None, - loose_restrictions=None, priority=None, - user_priority=0, resources=None, retries=None, - fifo_timeout=0, actors=None): + def _graph_to_futures( + self, + dsk, + keys, + restrictions=None, + loose_restrictions=None, + priority=None, + user_priority=0, + resources=None, + retries=None, + fifo_timeout=0, + actors=None, + ): with self._refcount_lock: if resources: - resources = self._expand_resources(resources, - all_keys=itertools.chain(dsk, keys)) + resources = self._expand_resources( + resources, all_keys=itertools.chain(dsk, keys) + ) if retries: - retries = self._expand_retries(retries, - all_keys=itertools.chain(dsk, keys)) + retries = self._expand_retries( + retries, all_keys=itertools.chain(dsk, keys) + ) if actors is not None and actors is not True and actors is not False: actors = list(self._expand_key(actors)) @@ -2213,8 +2404,9 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, flatkeys = list(map(tokey, keys)) futures = {key: Future(key, self, inform=False) for key in keyset} - values = {k for k, v in dsk.items() if isinstance(v, Future) - and k not in keyset} + values = { + k for k, v in dsk.items() if isinstance(v, Future) and k not in keyset + } if values: dsk = dask.optimization.inline(dsk, keys=values) @@ -2225,8 +2417,9 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, dsk3 = {k: v for k, v in dsk2.items() if k is not v} for future in extra_futures: if future.client is not self: - msg = ("Inputs contain futures that were created by " - "another client.") + msg = ( + "Inputs contain futures that were created by " "another client." 
+ ) raise ValueError(msg) if restrictions: @@ -2236,7 +2429,9 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, if loose_restrictions is not None: loose_restrictions = list(map(tokey, loose_restrictions)) - future_dependencies = {tokey(k): {tokey(f.key) for f in v[1]} for k, v in d.items()} + future_dependencies = { + tokey(k): {tokey(f.key) for f in v[1]} for k, v in d.items() + } for s in future_dependencies.values(): for v in s: @@ -2249,8 +2444,10 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, priority = dask.order.order(dsk, dependencies=dependencies) priority = keymap(tokey, priority) - dependencies = {tokey(k): [tokey(dep) for dep in deps] - for k, deps in dependencies.items()} + dependencies = { + tokey(k): [tokey(dep) for dep in deps] + for k, deps in dependencies.items() + } for k, deps in future_dependencies.items(): if deps: dependencies[k] = list(set(dependencies.get(k, ())) | deps) @@ -2258,25 +2455,41 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, if isinstance(retries, Number) and retries > 0: retries = {k: retries for k in dsk3} - self._send_to_scheduler({'op': 'update-graph', - 'tasks': valmap(dumps_task, dsk3), - 'dependencies': dependencies, - 'keys': list(flatkeys), - 'restrictions': restrictions or {}, - 'loose_restrictions': loose_restrictions, - 'priority': priority, - 'user_priority': user_priority, - 'resources': resources, - 'submitting_task': getattr(thread_state, 'key', None), - 'retries': retries, - 'fifo_timeout': fifo_timeout, - 'actors': actors}) + self._send_to_scheduler( + { + "op": "update-graph", + "tasks": valmap(dumps_task, dsk3), + "dependencies": dependencies, + "keys": list(flatkeys), + "restrictions": restrictions or {}, + "loose_restrictions": loose_restrictions, + "priority": priority, + "user_priority": user_priority, + "resources": resources, + "submitting_task": getattr(thread_state, "key", None), + "retries": retries, + "fifo_timeout": fifo_timeout, + "actors": actors, + } + ) return futures - def get(self, dsk, keys, restrictions=None, loose_restrictions=None, - resources=None, sync=True, asynchronous=None, direct=None, - retries=None, priority=0, fifo_timeout='60s', actors=None, - **kwargs): + def get( + self, + dsk, + keys, + restrictions=None, + loose_restrictions=None, + resources=None, + sync=True, + asynchronous=None, + direct=None, + retries=None, + priority=0, + fifo_timeout="60s", + actors=None, + **kwargs + ): """ Compute dask graph Parameters @@ -2322,19 +2535,18 @@ def get(self, dsk, keys, restrictions=None, loose_restrictions=None, ) packed = pack_data(keys, futures) if sync: - if getattr(thread_state, 'key', False): + if getattr(thread_state, "key", False): try: secede() should_rejoin = True except Exception: should_rejoin = False try: - results = self.gather(packed, asynchronous=asynchronous, - direct=direct) + results = self.gather(packed, asynchronous=asynchronous, direct=direct) finally: for f in futures.values(): f.release() - if getattr(thread_state, 'key', False) and should_rejoin: + if getattr(thread_state, "key", False) and should_rejoin: rejoin() return results return packed @@ -2394,10 +2606,20 @@ def normalize_collection(self, collection): else: return redict_collection(collection, dsk) - def compute(self, collections, sync=False, optimize_graph=True, - workers=None, allow_other_workers=False, resources=None, - retries=0, priority=0, fifo_timeout='60s', actors=None, - **kwargs): + def compute( + self, + collections, + sync=False, + optimize_graph=True, + workers=None, + 
allow_other_workers=False, + resources=None, + retries=0, + priority=0, + fifo_timeout="60s", + actors=None, + **kwargs + ): """ Compute dask collections on cluster Parameters @@ -2459,16 +2681,19 @@ def compute(self, collections, sync=False, optimize_graph=True, collections = [collections] singleton = True - traverse = kwargs.pop('traverse', True) + traverse = kwargs.pop("traverse", True) if traverse: - collections = tuple(dask.delayed(a) - if isinstance(a, (list, set, tuple, dict, Iterator)) - else a for a in collections) + collections = tuple( + dask.delayed(a) + if isinstance(a, (list, set, tuple, dict, Iterator)) + else a + for a in collections + ) variables = [a for a in collections if dask.is_dask_collection(a)] dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs) - names = ['finalize-%s' % tokenize(v) for v in variables] + names = ["finalize-%s" % tokenize(v) for v in variables] dsk2 = {} for i, (name, v) in enumerate(zip(names, variables)): func, extra_args = v.__dask_postcompute__() @@ -2478,20 +2703,24 @@ def compute(self, collections, sync=False, optimize_graph=True, else: dsk2[name] = (func, keys) + extra_args - restrictions, loose_restrictions = self.get_restrictions(collections, - workers, allow_other_workers) + restrictions, loose_restrictions = self.get_restrictions( + collections, workers, allow_other_workers + ) if not isinstance(priority, Number): - priority = {k: p for c, p in priority.items() - for k in self._expand_key(c)} + priority = {k: p for c, p in priority.items() for k in self._expand_key(c)} - futures_dict = self._graph_to_futures(merge(dsk2, dsk), names, - restrictions, loose_restrictions, - resources=resources, - retries=retries, - user_priority=priority, - fifo_timeout=fifo_timeout, - actors=actors) + futures_dict = self._graph_to_futures( + merge(dsk2, dsk), + names, + restrictions, + loose_restrictions, + resources=resources, + retries=retries, + user_priority=priority, + fifo_timeout=fifo_timeout, + actors=actors, + ) i = 0 futures = [] @@ -2512,9 +2741,19 @@ def compute(self, collections, sync=False, optimize_graph=True, else: return result - def persist(self, collections, optimize_graph=True, workers=None, - allow_other_workers=None, resources=None, retries=None, - priority=0, fifo_timeout='60s', actors=None, **kwargs): + def persist( + self, + collections, + optimize_graph=True, + workers=None, + allow_other_workers=None, + resources=None, + retries=None, + priority=0, + fifo_timeout="60s", + actors=None, + **kwargs + ): """ Persist dask collections on cluster Starts computation of the collection on the cluster in the background. 
@@ -2571,24 +2810,30 @@ def persist(self, collections, optimize_graph=True, workers=None, names = {k for c in collections for k in flatten(c.__dask_keys__())} - restrictions, loose_restrictions = self.get_restrictions(collections, - workers, allow_other_workers) + restrictions, loose_restrictions = self.get_restrictions( + collections, workers, allow_other_workers + ) if not isinstance(priority, Number): - priority = {k: p for c, p in priority.items() - for k in self._expand_key(c)} + priority = {k: p for c, p in priority.items() for k in self._expand_key(c)} - futures = self._graph_to_futures(dsk, names, restrictions, - loose_restrictions, - resources=resources, - retries=retries, - user_priority=priority, - fifo_timeout=fifo_timeout, - actors=actors) + futures = self._graph_to_futures( + dsk, + names, + restrictions, + loose_restrictions, + resources=resources, + retries=retries, + user_priority=priority, + fifo_timeout=fifo_timeout, + actors=actors, + ) postpersists = [c.__dask_postpersist__() for c in collections] - result = [func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args) - for (func, args), c in zip(postpersists, collections)] + result = [ + func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args) + for (func, args), c in zip(postpersists, collections) + ] if singleton: return first(result) @@ -2604,6 +2849,7 @@ def unzip(dask_worker=None): from distributed.utils import log_errors import zipfile import shutil + with log_errors(): a = os.path.join(dask_worker.worker_dir, name) b = os.path.join(dask_worker.local_dir, name) @@ -2613,7 +2859,7 @@ def unzip(dask_worker=None): with zipfile.ZipFile(b) as f: f.extractall(path=c) - for fn in glob(os.path.join(c, name[:-4], 'bin', '*')): + for fn in glob(os.path.join(c, name[:-4], "bin", "*")): st = os.stat(fn) os.chmod(fn, st.st_mode | 64) # chmod u+x fn @@ -2630,7 +2876,7 @@ def upload_environment(self, name, zipfile): def _restart(self, timeout=no_default): if timeout == no_default: timeout = self._timeout * 2 - self._send_to_scheduler({'op': 'restart', 'timeout': timeout}) + self._send_to_scheduler({"op": "restart", "timeout": timeout}) self._restart_event = Event() try: yield self._restart_event.wait(self.loop.time() + timeout) @@ -2653,29 +2899,28 @@ def restart(self, **kwargs): @gen.coroutine def _upload_file(self, filename, raise_on_error=True): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: data = f.read() _, fn = os.path.split(filename) - d = yield self.scheduler.broadcast(msg={'op': 'upload_file', - 'filename': fn, - 'data': to_serialize(data)}) + d = yield self.scheduler.broadcast( + msg={"op": "upload_file", "filename": fn, "data": to_serialize(data)} + ) - if any(v['status'] == 'error' for v in d.values()): - exceptions = [v['exception'] for v in d.values() - if v['status'] == 'error'] + if any(v["status"] == "error" for v in d.values()): + exceptions = [v["exception"] for v in d.values() if v["status"] == "error"] if raise_on_error: raise exceptions[0] else: raise gen.Return(exceptions[0]) - assert all(len(data) == v['nbytes'] for v in d.values()) + assert all(len(data) == v["nbytes"] for v in d.values()) @gen.coroutine def _upload_large_file(self, local_filename, remote_filename=None): if remote_filename is None: remote_filename = os.path.split(local_filename)[1] - with open(local_filename, 'rb') as f: + with open(local_filename, "rb") as f: data = f.read() [future] = yield self._scatter([data]) @@ -2687,7 +2932,7 @@ def dump_to_file(dask_worker=None): fn = 
os.path.join(dask_worker.local_dir, remote_filename) else: fn = remote_filename - with open(fn, 'wb') as f: + with open(fn, "wb") as f: f.write(dask_worker.data[key]) return len(dask_worker.data[key]) @@ -2714,8 +2959,9 @@ def upload_file(self, filename, **kwargs): >>> from mylibrary import myfunc # doctest: +SKIP >>> L = c.map(myfunc, seq) # doctest: +SKIP """ - result = self.sync(self._upload_file, filename, - raise_on_error=self.asynchronous, **kwargs) + result = self.sync( + self._upload_file, filename, raise_on_error=self.asynchronous, **kwargs + ) if isinstance(result, Exception): raise result else: @@ -2726,7 +2972,7 @@ def _rebalance(self, futures=None, workers=None): yield _wait(futures) keys = list({tokey(f.key) for f in self.futures_of(futures)}) result = yield self.scheduler.rebalance(keys=keys, workers=workers) - assert result['status'] == 'OK' + assert result["status"] == "OK" def rebalance(self, futures=None, workers=None, **kwargs): """ Rebalance data within network @@ -2753,11 +2999,11 @@ def _replicate(self, futures, n=None, workers=None, branching_factor=2): futures = self.futures_of(futures) yield _wait(futures) keys = {tokey(f.key) for f in futures} - yield self.scheduler.replicate(keys=list(keys), n=n, workers=workers, - branching_factor=branching_factor) + yield self.scheduler.replicate( + keys=list(keys), n=n, workers=workers, branching_factor=branching_factor + ) - def replicate(self, futures, n=None, workers=None, branching_factor=2, - **kwargs): + def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs): """ Set replication of futures within network Copy data onto many workers. This helps to broadcast frequently @@ -2793,8 +3039,14 @@ def replicate(self, futures, n=None, workers=None, branching_factor=2, -------- Client.rebalance """ - return self.sync(self._replicate, futures, n=n, workers=workers, - branching_factor=branching_factor, **kwargs) + return self.sync( + self._replicate, + futures, + n=n, + workers=workers, + branching_factor=branching_factor, + **kwargs + ) def ncores(self, workers=None, **kwargs): """ The number of threads/cores available on each worker node @@ -2818,8 +3070,9 @@ def ncores(self, workers=None, **kwargs): Client.who_has Client.has_what """ - if (isinstance(workers, tuple) - and all(isinstance(i, (str, tuple)) for i in workers)): + if isinstance(workers, tuple) and all( + isinstance(i, (str, tuple)) for i in workers + ): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -2884,8 +3137,9 @@ def has_what(self, workers=None, **kwargs): Client.ncores Client.processing """ - if (isinstance(workers, tuple) - and all(isinstance(i, (str, tuple)) for i in workers)): + if isinstance(workers, tuple) and all( + isinstance(i, (str, tuple)) for i in workers + ): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -2913,8 +3167,9 @@ def processing(self, workers=None): Client.has_what Client.ncores """ - if (isinstance(workers, tuple) - and all(isinstance(i, (str, tuple)) for i in workers)): + if isinstance(workers, tuple) and all( + isinstance(i, (str, tuple)) for i in workers + ): workers = list(workers) if workers is not None and not isinstance(workers, (tuple, list, set)): workers = [workers] @@ -2948,8 +3203,7 @@ def nbytes(self, keys=None, summary=True, **kwargs): -------- Client.who_has """ - return self.sync(self.scheduler.nbytes, keys=keys, - summary=summary, **kwargs) + return 
self.sync(self.scheduler.nbytes, keys=keys, summary=summary, **kwargs) def call_stack(self, futures=None, keys=None): """ The actively running call stack of all relevant keys @@ -2979,8 +3233,16 @@ def call_stack(self, futures=None, keys=None): keys += list(map(tokey, {f.key for f in futures})) return self.sync(self.scheduler.call_stack, keys=keys or None) - def profile(self, key=None, start=None, stop=None, workers=None, - merge_workers=True, plot=False, filename=None): + def profile( + self, + key=None, + start=None, + stop=None, + workers=None, + merge_workers=True, + plot=False, + filename=None, + ): """ Collect statistical profiling information about recent work Parameters @@ -3005,32 +3267,54 @@ def profile(self, key=None, start=None, stop=None, workers=None, if isinstance(workers, six.string_types + (Number,)): workers = [workers] - return self.sync(self._profile, key=key, workers=workers, - merge_workers=merge_workers, start=start, stop=stop, - plot=plot, filename=filename) + return self.sync( + self._profile, + key=key, + workers=workers, + merge_workers=merge_workers, + start=start, + stop=stop, + plot=plot, + filename=filename, + ) @gen.coroutine - def _profile(self, key=None, start=None, stop=None, workers=None, - merge_workers=True, plot=False, filename=None): + def _profile( + self, + key=None, + start=None, + stop=None, + workers=None, + merge_workers=True, + plot=False, + filename=None, + ): if isinstance(workers, six.string_types + (Number,)): workers = [workers] - state = yield self.scheduler.profile(key=key, workers=workers, - merge_workers=merge_workers, start=start, stop=stop) + state = yield self.scheduler.profile( + key=key, + workers=workers, + merge_workers=merge_workers, + start=start, + stop=stop, + ) if filename: plot = True if plot: from . 
import profile + data = profile.plot_data(state) - figure, source = profile.plot_figure(data, sizing_mode='stretch_both') + figure, source = profile.plot_figure(data, sizing_mode="stretch_both") - if plot == 'save' and not filename: - filename = 'dask-profile.html' + if plot == "save" and not filename: + filename = "dask-profile.html" from bokeh.plotting import save - save(figure, title='Dask Profile', filename=filename) + + save(figure, title="Dask Profile", filename=filename) raise gen.Return((state, figure)) else: @@ -3075,11 +3359,11 @@ def write_scheduler_file(self, scheduler_file): >>> client2 = Client(scheduler_file='scheduler.json') # doctest: +SKIP """ if self.scheduler_file: - raise ValueError('Scheduler file already set') + raise ValueError("Scheduler file already set") else: self.scheduler_file = scheduler_file - with open(self.scheduler_file, 'w') as f: + with open(self.scheduler_file, "w") as f: json.dump(self.scheduler_info(), f, indent=2) def get_metadata(self, keys, default=no_default): @@ -3102,8 +3386,7 @@ def get_metadata(self, keys, default=no_default): """ if not isinstance(keys, (list, tuple)): keys = (keys,) - return self.sync(self.scheduler.get_metadata, keys=keys, - default=default) + return self.sync(self.scheduler.get_metadata, keys=keys, default=default) def get_scheduler_logs(self, n=None): """ Get logs from scheduler @@ -3155,8 +3438,12 @@ def retire_workers(self, workers=None, close_workers=True, **kwargs): -------- dask.distributed.Scheduler.retire_workers """ - return self.sync(self.scheduler.retire_workers, workers=workers, - close_workers=close_workers, **kwargs) + return self.sync( + self.scheduler.retire_workers, + workers=workers, + close_workers=close_workers, + **kwargs + ) def set_metadata(self, key, value): """ Set arbitrary metadata in the scheduler @@ -3224,44 +3511,46 @@ def get_versions(self, check=False, packages=[]): """ client = get_versions(packages=packages) try: - scheduler = sync(self.loop, self.scheduler.versions, - packages=packages) + scheduler = sync(self.loop, self.scheduler.versions, packages=packages) except KeyError: scheduler = None except TypeError: # packages keyword not supported scheduler = sync(self.loop, self.scheduler.versions) # this raises - workers = sync(self.loop, self.scheduler.broadcast, - msg={'op': 'versions', 'packages': packages}) - result = {'scheduler': scheduler, 'workers': workers, 'client': client} + workers = sync( + self.loop, + self.scheduler.broadcast, + msg={"op": "versions", "packages": packages}, + ) + result = {"scheduler": scheduler, "workers": workers, "client": client} if check: # we care about the required & optional packages matching def to_packages(d): - L = list(d['packages'].values()) + L = list(d["packages"].values()) return dict(sum(L, type(L[0])())) - client_versions = to_packages(result['client']) - versions = [('scheduler', to_packages(result['scheduler']))] - versions.extend((w, to_packages(d)) - for w, d in sorted(workers.items())) + + client_versions = to_packages(result["client"]) + versions = [("scheduler", to_packages(result["scheduler"]))] + versions.extend((w, to_packages(d)) for w, d in sorted(workers.items())) mismatched = defaultdict(list) for name, vers in versions: for pkg, cv in client_versions.items(): - v = vers.get(pkg, 'MISSING') + v = vers.get(pkg, "MISSING") if cv != v: mismatched[pkg].append((name, v)) if mismatched: errs = [] for pkg, versions in sorted(mismatched.items()): - rows = [('client', client_versions[pkg])] + rows = [("client", client_versions[pkg])] 
rows.extend(versions) - errs.append("%s\n%s" % (pkg, asciitable(['', 'version'], rows))) + errs.append("%s\n%s" % (pkg, asciitable(["", "version"], rows))) - raise ValueError("Mismatched versions found\n" - "\n" - "%s" % ('\n\n'.join(errs))) + raise ValueError( + "Mismatched versions found\n" "\n" "%s" % ("\n\n".join(errs)) + ) return result @@ -3277,12 +3566,13 @@ def _start_ipython_workers(self, workers): workers = yield self.scheduler.ncores() responses = yield self.scheduler.broadcast( - msg=dict(op='start_ipython'), workers=workers, + msg=dict(op="start_ipython"), workers=workers ) raise gen.Return((workers, responses)) - def start_ipython_workers(self, workers=None, magic_names=False, - qtconsole=False, qtconsole_args=None): + def start_ipython_workers( + self, workers=None, magic_names=False, qtconsole=False, qtconsole_args=None + ): """ Start IPython kernels on workers Parameters @@ -3337,31 +3627,34 @@ def start_ipython_workers(self, workers=None, magic_names=False, (workers, info_dict) = sync(self.loop, self._start_ipython_workers, workers) if magic_names and isinstance(magic_names, six.string_types): - if '*' in magic_names: - magic_names = [magic_names.replace('*', str(i)) - for i in range(len(workers))] + if "*" in magic_names: + magic_names = [ + magic_names.replace("*", str(i)) for i in range(len(workers)) + ] else: magic_names = [magic_names] - if 'IPython' in sys.modules: + if "IPython" in sys.modules: from ._ipython_utils import register_remote_magic + register_remote_magic() if magic_names: from ._ipython_utils import register_worker_magic + for worker, magic_name in zip(workers, magic_names): connection_info = info_dict[worker] register_worker_magic(connection_info, magic_name) if qtconsole: from ._ipython_utils import connect_qtconsole + for worker, connection_info in info_dict.items(): - name = 'dask-' + worker.replace(':', '-').replace('/', '-') - connect_qtconsole(connection_info, name=name, - extra_args=qtconsole_args, - ) + name = "dask-" + worker.replace(":", "-").replace("/", "-") + connect_qtconsole(connection_info, name=name, extra_args=qtconsole_args) return info_dict - def start_ipython_scheduler(self, magic_name='scheduler_if_ipython', - qtconsole=False, qtconsole_args=None): + def start_ipython_scheduler( + self, magic_name="scheduler_if_ipython", qtconsole=False, qtconsole_args=None + ): """ Start IPython kernel on the scheduler Parameters @@ -3397,23 +3690,25 @@ def start_ipython_scheduler(self, magic_name='scheduler_if_ipython', Client.start_ipython_workers: Start IPython on the workers """ info = sync(self.loop, self.scheduler.start_ipython) - if magic_name == 'scheduler_if_ipython': + if magic_name == "scheduler_if_ipython": # default to %scheduler if in IPython, no magic otherwise in_ipython = False - if 'IPython' in sys.modules: + if "IPython" in sys.modules: from IPython import get_ipython + in_ipython = bool(get_ipython()) if in_ipython: - magic_name = 'scheduler' + magic_name = "scheduler" else: magic_name = None if magic_name: from ._ipython_utils import register_worker_magic + register_worker_magic(info, magic_name) if qtconsole: from ._ipython_utils import connect_qtconsole - connect_qtconsole(info, name='dask-scheduler', - extra_args=qtconsole_args,) + + connect_qtconsole(info, name="dask-scheduler", extra_args=qtconsole_args) return info @classmethod @@ -3438,15 +3733,18 @@ def _expand_retries(cls, retries, all_keys): to a {task key: Integral} dictionary. 
""" if retries and isinstance(retries, dict): - result = {name: value - for key, value in retries.items() - for name in cls._expand_key(key)} + result = { + name: value + for key, value in retries.items() + for name in cls._expand_key(key) + } elif isinstance(retries, Integral): # Each task unit may potentially fail, allow retrying all of them result = {name: retries for name in all_keys} else: - raise TypeError("`retries` should be an integer or dict, got %r" - % (type(retries,))) + raise TypeError( + "`retries` should be an integer or dict, got %r" % (type(retries)) + ) return keymap(tokey, result) def _expand_resources(cls, resources, all_keys): @@ -3459,8 +3757,7 @@ def _expand_resources(cls, resources, all_keys): # such as {'x': {'GPU': 1}, 'y': {'SSD': 4}} indicating # per-key requirements if not isinstance(resources, dict): - raise TypeError("`resources` should be a dict, got %r" - % (type(resources,))) + raise TypeError("`resources` should be a dict, got %r" % (type(resources))) per_key_reqs = {} global_reqs = {} @@ -3474,8 +3771,10 @@ def _expand_resources(cls, resources, all_keys): global_reqs.update((kk, {k: v}) for kk in all_keys) if global_reqs and per_key_reqs: - raise ValueError("cannot have both per-key and all-key requirements " - "in resources dict %r" % (resources,)) + raise ValueError( + "cannot have both per-key and all-key requirements " + "in resources dict %r" % (resources,) + ) return global_reqs or per_key_reqs @classmethod @@ -3491,8 +3790,9 @@ def get_restrictions(cls, collections, workers, allow_other_workers): if dask.is_dask_collection(colls): keys = flatten(colls.__dask_keys__()) else: - keys = list({k for c in flatten(colls) - for k in flatten(c.__dask_keys__())}) + keys = list( + {k for c in flatten(colls) for k in flatten(c.__dask_keys__())} + ) restrictions.update({k: ws for k in keys}) else: restrictions = {} @@ -3500,8 +3800,9 @@ def get_restrictions(cls, collections, workers, allow_other_workers): if allow_other_workers is True: loose_restrictions = list(restrictions) elif allow_other_workers: - loose_restrictions = list({k for c in flatten(allow_other_workers) - for k in c.__dask_keys__()}) + loose_restrictions = list( + {k for c in flatten(allow_other_workers) for k in c.__dask_keys__()} + ) else: loose_restrictions = [] @@ -3511,8 +3812,9 @@ def get_restrictions(cls, collections, workers, allow_other_workers): def collections_to_dsk(collections, *args, **kwargs): return collections_to_dsk(collections, *args, **kwargs) - def get_task_stream(self, start=None, stop=None, count=None, plot=False, - filename='task-stream.html'): + def get_task_stream( + self, start=None, stop=None, count=None, plot=False, filename="task-stream.html" + ): """ Get task stream data from scheduler This collects the data present in the diagnostic "Task Stream" plot on @@ -3572,23 +3874,32 @@ def get_task_stream(self, start=None, stop=None, count=None, plot=False, -------- get_task_stream: a context manager version of this method """ - return self.sync(self._get_task_stream, start=start, stop=stop, - count=count, plot=plot, filename=filename) + return self.sync( + self._get_task_stream, + start=start, + stop=stop, + count=count, + plot=plot, + filename=filename, + ) @gen.coroutine - def _get_task_stream(self, start=None, stop=None, count=None, plot=False, - filename='task-stream.html'): - msgs = yield self.scheduler.get_task_stream(start=start, - stop=stop, count=count) + def _get_task_stream( + self, start=None, stop=None, count=None, plot=False, filename="task-stream.html" + 
): + msgs = yield self.scheduler.get_task_stream(start=start, stop=stop, count=count) if plot: from .diagnostics.task_stream import rectangles + rects = rectangles(msgs) from .bokeh.components import task_stream_figure - source, figure = task_stream_figure(sizing_mode='stretch_both') + + source, figure = task_stream_figure(sizing_mode="stretch_both") source.data.update(rects) - if plot == 'save': + if plot == "save": from bokeh.plotting import save - save(figure, title='Dask Task Stream', filename=filename) + + save(figure, title="Dask Task Stream", filename=filename) raise gen.Return((msgs, figure)) else: raise gen.Return(msgs) @@ -3598,9 +3909,9 @@ def _register_worker_callbacks(self, setup=None): responses = yield self.scheduler.register_worker_callbacks(setup=dumps(setup)) results = {} for key, resp in responses.items(): - if resp['status'] == 'OK': - results[key] = resp['result'] - elif resp['status'] == 'error': + if resp["status"] == "OK": + results[key] = resp["result"] + elif resp["status"] == "error": six.reraise(*clean_exception(**resp)) raise gen.Return(results) @@ -3637,35 +3948,39 @@ def CompatibleExecutor(*args, **kwargs): raise Exception("This has been moved to the Client.get_executor() method") -ALL_COMPLETED = 'ALL_COMPLETED' -FIRST_COMPLETED = 'FIRST_COMPLETED' +ALL_COMPLETED = "ALL_COMPLETED" +FIRST_COMPLETED = "FIRST_COMPLETED" @gen.coroutine def _wait(fs, timeout=None, return_when=ALL_COMPLETED): if timeout is not None and not isinstance(timeout, Number): - raise TypeError("timeout= keyword received a non-numeric value.\n" - "Beware that wait expects a list of values\n" - " Bad: wait(x, y, z)\n" - " Good: wait([x, y, z])") + raise TypeError( + "timeout= keyword received a non-numeric value.\n" + "Beware that wait expects a list of values\n" + " Bad: wait(x, y, z)\n" + " Good: wait([x, y, z])" + ) fs = futures_of(fs) if return_when == ALL_COMPLETED: wait_for = All elif return_when == FIRST_COMPLETED: wait_for = Any else: - raise NotImplementedError("Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are " - "supported") + raise NotImplementedError( + "Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are " "supported" + ) future = wait_for({f._state.wait() for f in fs}) if timeout is not None: future = gen.with_timeout(timedelta(seconds=timeout), future) yield future - done, not_done = ({fu for fu in fs if fu.status != 'pending'}, - {fu for fu in fs if fu.status == 'pending'}) - cancelled = [f.key for f in done - if f.status == 'cancelled'] + done, not_done = ( + {fu for fu in fs if fu.status != "pending"}, + {fu for fu in fs if fu.status == "pending"}, + ) + cancelled = [f.key for f in done if f.status == "cancelled"] if cancelled: raise CancelledError(cancelled) @@ -3854,7 +4169,7 @@ def _get_and_raise(self): res = self.queue.get() if self.with_results: future, result = res - if self.raise_errors and future.status == 'error': + if self.raise_errors and future.status == "error": six.reraise(*result) return res @@ -3945,19 +4260,22 @@ def default_client(c=None): if c: return c else: - raise ValueError("No clients found\n" - "Start an client and point it to the scheduler address\n" - " from distributed import Client\n" - " client = Client('ip-addr-of-scheduler:8786')\n") + raise ValueError( + "No clients found\n" + "Start an client and point it to the scheduler address\n" + " from distributed import Client\n" + " client = Client('ip-addr-of-scheduler:8786')\n" + ) def ensure_default_get(client): - dask.config.set(scheduler='dask.distributed') + 
dask.config.set(scheduler="dask.distributed") _set_global_client(client) def redict_collection(c, dsk): from dask.delayed import Delayed + if isinstance(c, Delayed): return Delayed(c.key, dsk) else: @@ -4016,9 +4334,13 @@ def fire_and_forget(obj): """ futures = futures_of(obj) for future in futures: - future.client._send_to_scheduler({'op': 'client-desires-keys', - 'keys': [tokey(future.key)], - 'client': 'fire-and-forget'}) + future.client._send_to_scheduler( + { + "op": "client-desires-keys", + "keys": [tokey(future.key)], + "client": "fire-and-forget", + } + ) class get_task_stream(object): @@ -4069,7 +4391,8 @@ class get_task_stream(object): -------- Client.get_task_stream: Function version of this context manager """ - def __init__(self, client=None, plot=False, filename='task-stream.html'): + + def __init__(self, client=None, plot=False, filename="task-stream.html"): self.data = [] self._plot = plot self._filename = filename @@ -4082,8 +4405,9 @@ def __enter__(self): return self def __exit__(self, typ, value, traceback): - L = self.client.get_task_stream(start=self.start, plot=self._plot, - filename=self._filename) + L = self.client.get_task_stream( + start=self.start, plot=self._plot, filename=self._filename + ) if self._plot: L, self.figure = L self.data.extend(L) @@ -4094,8 +4418,9 @@ def __aenter__(self): @gen.coroutine def __aexit__(self, typ, value, traceback): - L = yield self.client.get_task_stream(start=self.start, plot=self._plot, - filename=self._filename) + L = yield self.client.get_task_stream( + start=self.start, plot=self._plot, filename=self._filename + ) if self._plot: L, self.figure = L self.data.extend(L) diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index dfda0459a54..0f7c701847d 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -1,11 +1,16 @@ from __future__ import print_function, division, absolute_import -from .addressing import (parse_address, unparse_address, - normalize_address, parse_host_port, - unparse_host_port, resolve_address, - get_address_host_port, get_address_host, - get_local_address_for, - ) +from .addressing import ( + parse_address, + unparse_address, + normalize_address, + parse_host_port, + unparse_host_port, + resolve_address, + get_address_host_port, + get_address_host, + get_local_address_for, +) from .core import connect, listen, Comm, CommClosedError diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 8ff401475b9..20ddb2c863f 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -7,7 +7,7 @@ from . import registry -DEFAULT_SCHEME = dask.config.get('distributed.comm.default-scheme') +DEFAULT_SCHEME = dask.config.get("distributed.comm.default-scheme") def parse_address(addr, strict=False): @@ -21,11 +21,13 @@ def parse_address(addr, strict=False): """ if not isinstance(addr, six.string_types): raise TypeError("expected str, got %r" % addr.__class__.__name__) - scheme, sep, loc = addr.rpartition('://') + scheme, sep, loc = addr.rpartition("://") if strict and not sep: - msg = ("Invalid url scheme. " - "Must include protocol like tcp://localhost:8000. " - "Got %s" % addr) + msg = ( + "Invalid url scheme. " + "Must include protocol like tcp://localhost:8000. 
" + "Got %s" % addr + ) raise ValueError(msg) if not sep: scheme = DEFAULT_SCHEME @@ -39,7 +41,7 @@ def unparse_address(scheme, loc): >>> unparse_address('tcp', '127.0.0.1') 'tcp://127.0.0.1' """ - return '%s://%s' % (scheme, loc) + return "%s://%s" % (scheme, loc) def normalize_address(addr): @@ -69,24 +71,24 @@ def _default(): raise ValueError("missing port number in address %r" % (address,)) return default_port - if address.startswith('['): + if address.startswith("["): # IPv6 notation: '[addr]:port' or '[addr]'. # The address may contain multiple colons. - host, sep, tail = address[1:].partition(']') + host, sep, tail = address[1:].partition("]") if not sep: _fail() if not tail: port = _default() else: - if not tail.startswith(':'): + if not tail.startswith(":"): _fail() port = tail[1:] else: # Generic notation: 'addr:port' or 'addr'. - host, sep, port = address.partition(':') + host, sep, port = address.partition(":") if not sep: port = _default() - elif ':' in host: + elif ":" in host: _fail() return host, int(port) @@ -96,10 +98,10 @@ def unparse_host_port(host, port=None): """ Undo parse_host_port(). """ - if ':' in host and not host.startswith('['): - host = '[%s]' % host + if ":" in host and not host.startswith("["): + host = "[%s]" % host if port: - return '%s:%s' % (host, port) + return "%s:%s" % (host, port) else: return host @@ -119,8 +121,9 @@ def get_address_host_port(addr, strict=False): try: return backend.get_address_host_port(loc) except NotImplementedError: - raise ValueError("don't know how to extract host and port " - "for address %r" % (addr,)) + raise ValueError( + "don't know how to extract host and port " "for address %r" % (addr,) + ) def get_address_host(addr): diff --git a/distributed/comm/core.py b/distributed/comm/core.py index a7aaf7217db..b66be0b6dc4 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -116,12 +116,14 @@ def __repr__(self): if self.closed(): return "" % (clsname,) else: - return ("<%s local=%s remote=%s>" - % (clsname, self.local_address, self.peer_address)) + return "<%s local=%s remote=%s>" % ( + clsname, + self.local_address, + self.peer_address, + ) class Listener(with_metaclass(ABCMeta)): - @abstractmethod def start(self): """ @@ -158,7 +160,6 @@ def __exit__(self, *exc): class Connector(with_metaclass(ABCMeta)): - @abstractmethod def connect(self, address, deserialize=True): """ @@ -177,8 +178,8 @@ def connect(addr, timeout=None, deserialize=True, connection_args=None): retried until the *timeout* is expired. 
""" if timeout is None: - timeout = dask.config.get('distributed.comm.timeouts.connect') - timeout = parse_timedelta(timeout, default='seconds') + timeout = dask.config.get("distributed.comm.timeouts.connect") + timeout = parse_timedelta(timeout, default="seconds") scheme, loc = parse_address(addr) backend = registry.get_backend(scheme) @@ -190,18 +191,24 @@ def connect(addr, timeout=None, deserialize=True, connection_args=None): def _raise(error): error = error or "connect() didn't finish in time" - msg = ("Timed out trying to connect to %r after %s s: %s" - % (addr, timeout, error)) + msg = "Timed out trying to connect to %r after %s s: %s" % ( + addr, + timeout, + error, + ) raise IOError(msg) # This starts a thread while True: try: - future = connector.connect(loc, deserialize=deserialize, - **(connection_args or {})) - comm = yield gen.with_timeout(timedelta(seconds=deadline - time()), - future, - quiet_exceptions=EnvironmentError) + future = connector.connect( + loc, deserialize=deserialize, **(connection_args or {}) + ) + comm = yield gen.with_timeout( + timedelta(seconds=deadline - time()), + future, + quiet_exceptions=EnvironmentError, + ) except FatalCommClosedError: raise except EnvironmentError as e: @@ -231,13 +238,14 @@ def listen(addr, handle_comm, deserialize=True, connection_args=None): try: scheme, loc = parse_address(addr, strict=True) except ValueError: - if connection_args and connection_args.get('ssl_context'): - addr = 'tls://' + addr + if connection_args and connection_args.get("ssl_context"): + addr = "tls://" + addr else: - addr = 'tcp://' + addr + addr = "tcp://" + addr scheme, loc = parse_address(addr, strict=True) backend = registry.get_backend(scheme) - return backend.get_listener(loc, handle_comm, deserialize, - **(connection_args or {})) + return backend.get_listener( + loc, handle_comm, deserialize, **(connection_args or {}) + ) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index f5d9adf7de2..8721a3df8ac 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -21,9 +21,9 @@ logger = logging.getLogger(__name__) -ConnectionRequest = namedtuple('ConnectionRequest', - ('c2s_q', 's2c_q', 'c_loop', 'c_addr', - 'conn_event')) +ConnectionRequest = namedtuple( + "ConnectionRequest", ("c2s_q", "s2c_q", "c_loop", "c_addr", "conn_event") +) class Manager(object): @@ -62,10 +62,12 @@ def validate_address(self, addr): """ Validate the address' IP and pid. """ - ip, pid, suffix = addr.split('/') + ip, pid, suffix = addr.split("/") if ip != self.ip or int(pid) != os.getpid(): - raise ValueError("inproc address %r does not match host (%r) or pid (%r)" - % (addr, self.ip, os.getpid())) + raise ValueError( + "inproc address %r does not match host (%r) or pid (%r)" + % (addr, self.ip, os.getpid()) + ) global_manager = Manager() @@ -75,7 +77,7 @@ def new_address(): """ Generate a new address. """ - return 'inproc://' + global_manager.new_address() + return "inproc://" + global_manager.new_address() class QueueEmpty(Exception): @@ -144,10 +146,12 @@ class InProc(Comm): Reminder: a Comm must always be used from a single thread. Its peer Comm can be running in any thread. 
""" + _initialized = False - def __init__(self, local_addr, peer_addr, read_q, write_q, write_loop, - deserialize=True): + def __init__( + self, local_addr, peer_addr, read_q, write_q, write_loop, deserialize=True + ): self._local_addr = local_addr self._peer_addr = peer_addr self.deserialize = deserialize @@ -161,8 +165,7 @@ def __init__(self, local_addr, peer_addr, read_q, write_q, write_loop, self._initialized = True def _get_finalizer(self): - def finalize(write_q=self._write_q, write_loop=self._write_loop, - r=repr(self)): + def finalize(write_q=self._write_q, write_loop=self._write_loop, r=repr(self)): logger.warning("Closing dangling queue in %s" % (r,)) write_loop.add_callback(write_q.put_nowait, _EOF) @@ -177,7 +180,7 @@ def peer_address(self): return self._peer_addr @gen.coroutine - def read(self, deserializers='ignored'): + def read(self, deserializers="ignored"): if self._closed: raise CommClosedError @@ -233,7 +236,7 @@ def closed(self): class InProcListener(Listener): - prefix = 'inproc' + prefix = "inproc" def __init__(self, address, comm_handler, deserialize=True): self.manager = global_manager @@ -248,12 +251,14 @@ def _listen(self): conn_req = yield self.listen_q.get() if conn_req is None: break - comm = InProc(local_addr='inproc://' + self.address, - peer_addr='inproc://' + conn_req.c_addr, - read_q=conn_req.c2s_q, - write_q=conn_req.s2c_q, - write_loop=conn_req.c_loop, - deserialize=self.deserialize) + comm = InProc( + local_addr="inproc://" + self.address, + peer_addr="inproc://" + conn_req.c_addr, + read_q=conn_req.c2s_q, + write_q=conn_req.s2c_q, + write_loop=conn_req.c_loop, + deserialize=self.deserialize, + ) # Notify connector conn_req.c_loop.add_callback(conn_req.conn_event.set) self.comm_handler(comm) @@ -272,15 +277,14 @@ def stop(self): @property def listen_address(self): - return 'inproc://' + self.address + return "inproc://" + self.address @property def contact_address(self): - return 'inproc://' + self.address + return "inproc://" + self.address class InProcConnector(Connector): - def __init__(self, manager): self.manager = manager @@ -290,24 +294,27 @@ def connect(self, address, deserialize=True, **connection_args): if listener is None: raise IOError("no endpoint for inproc address %r" % (address,)) - conn_req = ConnectionRequest(c2s_q=Queue(), - s2c_q=Queue(), - c_loop=IOLoop.current(), - c_addr=self.manager.new_address(), - conn_event=locks.Event(), - ) + conn_req = ConnectionRequest( + c2s_q=Queue(), + s2c_q=Queue(), + c_loop=IOLoop.current(), + c_addr=self.manager.new_address(), + conn_event=locks.Event(), + ) listener.connect_threadsafe(conn_req) # Wait for connection acknowledgement # (do not pretend we're connected if the other comm never gets # created, for example if the listener was stopped in the meantime) yield conn_req.conn_event.wait() - comm = InProc(local_addr='inproc://' + conn_req.c_addr, - peer_addr='inproc://' + address, - read_q=conn_req.s2c_q, - write_q=conn_req.c2s_q, - write_loop=listener.loop, - deserialize=deserialize) + comm = InProc( + local_addr="inproc://" + conn_req.c_addr, + peer_addr="inproc://" + address, + read_q=conn_req.s2c_q, + write_q=conn_req.c2s_q, + write_loop=listener.loop, + deserialize=deserialize, + ) raise gen.Return(comm) @@ -336,4 +343,4 @@ def get_local_address_for(self, loc): return self.manager.new_address() -backends['inproc'] = InProcBackend() +backends["inproc"] = InProcBackend() diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index a47b0f7435d..a646b4d71b9 100644 --- 
a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -67,6 +67,7 @@ def get_backend(scheme): """ backend = backends.get(scheme) if backend is None: - raise ValueError("unknown address scheme %r (known schemes: %s)" - % (scheme, sorted(backends))) + raise ValueError( + "unknown address scheme %r (known schemes: %s)" % (scheme, sorted(backends)) + ) return backend diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index f10af0bc167..6d90a7bc9c7 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -20,14 +20,20 @@ from ..compatibility import finalize, PY3 from ..threadpoolexecutor import ThreadPoolExecutor -from ..utils import (ensure_bytes, ensure_ip, get_ip, get_ipv6, nbytes, - parse_timedelta, shutting_down) +from ..utils import ( + ensure_bytes, + ensure_ip, + get_ip, + get_ipv6, + nbytes, + parse_timedelta, + shutting_down, +) from .registry import Backend, backends from .addressing import parse_host_port, unparse_host_port from .core import Comm, Connector, Listener, CommClosedError, FatalCommClosedError -from .utils import (to_frames, from_frames, - get_tcp_server_address, ensure_concrete_host,) +from .utils import to_frames, from_frames, get_tcp_server_address, ensure_concrete_host logger = logging.getLogger(__name__) @@ -36,6 +42,7 @@ def get_total_physical_memory(): try: import psutil + return psutil.virtual_memory().total / 2 except ImportError: return 2e9 @@ -51,8 +58,8 @@ def set_tcp_timeout(stream): if stream.closed(): return - timeout = dask.config.get('distributed.comm.timeouts.tcp') - timeout = int(parse_timedelta(timeout, default='seconds')) + timeout = dask.config.get("distributed.comm.timeouts.tcp") + timeout = int(parse_timedelta(timeout, default="seconds")) sock = stream.socket @@ -68,8 +75,7 @@ def set_tcp_timeout(stream): try: if sys.platform.startswith("win"): - logger.debug("Setting TCP keepalive: idle=%d, interval=%d", - idle, interval) + logger.debug("Setting TCP keepalive: idle=%d, interval=%d", idle, interval) sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle * 1000, interval * 1000)) else: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) @@ -86,15 +92,18 @@ def set_tcp_timeout(stream): TCP_KEEPIDLE = None if TCP_KEEPIDLE is not None: - logger.debug("Setting TCP keepalive: nprobes=%d, idle=%d, interval=%d", - nprobes, idle, interval) + logger.debug( + "Setting TCP keepalive: nprobes=%d, idle=%d, interval=%d", + nprobes, + idle, + interval, + ) sock.setsockopt(socket.SOL_TCP, TCP_KEEPCNT, nprobes) sock.setsockopt(socket.SOL_TCP, TCP_KEEPIDLE, idle) sock.setsockopt(socket.SOL_TCP, TCP_KEEPINTVL, interval) if sys.platform.startswith("linux"): - logger.debug("Setting TCP user timeout: %d ms", - timeout * 1000) + logger.debug("Setting TCP user timeout: %d ms", timeout * 1000) TCP_USER_TIMEOUT = 18 # since Linux 2.6.37 sock.setsockopt(socket.SOL_TCP, TCP_USER_TIMEOUT, timeout * 1000) except EnvironmentError as e: @@ -123,8 +132,10 @@ def convert_stream_closed_error(obj, exc): # The stream was closed because of an underlying OS error exc = exc.real_error if ssl and isinstance(exc, ssl.SSLError): - if 'UNKNOWN_CA' in exc.reason: - raise FatalCommClosedError("in %s: %s: %s" % (obj, exc.__class__.__name__, exc)) + if "UNKNOWN_CA" in exc.reason: + raise FatalCommClosedError( + "in %s: %s: %s" % (obj, exc.__class__.__name__, exc) + ) raise CommClosedError("in %s: %s: %s" % (obj, exc.__class__.__name__, exc)) else: raise CommClosedError("in %s: %s" % (obj, exc)) @@ -134,6 +145,7 @@ class TCP(Comm): """ An established 
communication based on an underlying Tornado IOStream. """ + _iostream_allows_memoryview = tornado.version_info >= (4, 5) # IOStream.read_into() currently proposed in # https://github.com/tornadoweb/tornado/pull/2193 @@ -179,9 +191,9 @@ def read(self, deserializers=None): try: n_frames = yield stream.read_bytes(8) - n_frames = struct.unpack('Q', n_frames)[0] + n_frames = struct.unpack("Q", n_frames)[0] lengths = yield stream.read_bytes(8 * n_frames) - lengths = struct.unpack('Q' * n_frames, lengths) + lengths = struct.unpack("Q" * n_frames, lengths) frames = [] for length in lengths: @@ -193,7 +205,7 @@ def read(self, deserializers=None): else: frame = yield stream.read_bytes(length) else: - frame = b'' + frame = b"" frames.append(frame) except StreamClosedError as e: self.stream = None @@ -201,9 +213,9 @@ def read(self, deserializers=None): convert_stream_closed_error(self, e) else: try: - msg = yield from_frames(frames, - deserialize=self.deserialize, - deserializers=deserializers) + msg = yield from_frames( + frames, deserialize=self.deserialize, deserializers=deserializers + ) except EOFError: # Frames possibly garbled or truncated by communication error self.abort() @@ -211,27 +223,29 @@ def read(self, deserializers=None): raise gen.Return(msg) @gen.coroutine - def write(self, msg, serializers=None, on_error='message'): + def write(self, msg, serializers=None, on_error="message"): stream = self.stream bytes_since_last_yield = 0 if stream is None: raise CommClosedError - frames = yield to_frames(msg, - serializers=serializers, - on_error=on_error, - context={'sender': self._local_addr, - 'recipient': self._peer_addr}) + frames = yield to_frames( + msg, + serializers=serializers, + on_error=on_error, + context={"sender": self._local_addr, "recipient": self._peer_addr}, + ) try: lengths = [nbytes(frame) for frame in frames] - length_bytes = ([struct.pack('Q', len(frames))] + - [struct.pack('Q', x) for x in lengths]) - if PY3 and sum(lengths) < 2**17: # 128kiB - b = b''.join(length_bytes + frames) # small enough, send in one go + length_bytes = [struct.pack("Q", len(frames))] + [ + struct.pack("Q", x) for x in lengths + ] + if PY3 and sum(lengths) < 2 ** 17: # 128kiB + b = b"".join(length_bytes + frames) # small enough, send in one go stream.write(b) else: - stream.write(b''.join(length_bytes)) # avoid large memcpy, send in many + stream.write(b"".join(length_bytes)) # avoid large memcpy, send in many for frame in frames: # Can't wait for the write() Future as it may be lost @@ -262,7 +276,7 @@ def close(self): try: # Flush the stream's write buffer by waiting for a last write. 
if stream.writing(): - yield stream.write(b'') + yield stream.write(b"") stream.socket.shutdown(socket.SHUT_RDWR) except EnvironmentError: pass @@ -293,37 +307,42 @@ def _read_extra(self): TCP._read_extra(self) sock = self.stream.socket if sock is not None: - self._extra.update(peercert=sock.getpeercert(), - cipher=sock.cipher()) - cipher, proto, bits = self._extra['cipher'] - logger.debug("TLS connection with %r: protocol=%s, cipher=%s, bits=%d", - self._peer_addr, proto, cipher, bits) + self._extra.update(peercert=sock.getpeercert(), cipher=sock.cipher()) + cipher, proto, bits = self._extra["cipher"] + logger.debug( + "TLS connection with %r: protocol=%s, cipher=%s, bits=%d", + self._peer_addr, + proto, + cipher, + bits, + ) def _expect_tls_context(connection_args): - ctx = connection_args.get('ssl_context') + ctx = connection_args.get("ssl_context") if not isinstance(ctx, ssl.SSLContext): - raise TypeError("TLS expects a `ssl_context` argument of type " - "ssl.SSLContext (perhaps check your TLS configuration?)" - " Instead got %s" % str(ctx)) + raise TypeError( + "TLS expects a `ssl_context` argument of type " + "ssl.SSLContext (perhaps check your TLS configuration?)" + " Instead got %s" % str(ctx) + ) return ctx class RequireEncryptionMixin(object): - def _check_encryption(self, address, connection_args): - if not self.encrypted and connection_args.get('require_encryption'): + if not self.encrypted and connection_args.get("require_encryption"): # XXX Should we have a dedicated SecurityError class? - raise RuntimeError("encryption required by Dask configuration, " - "refusing communication from/to %r" - % (self.prefix + address,)) + raise RuntimeError( + "encryption required by Dask configuration, " + "refusing communication from/to %r" % (self.prefix + address,) + ) class BaseTCPConnector(Connector, RequireEncryptionMixin): if PY3: # see github PR #2403 discussion for more info _executor = ThreadPoolExecutor(2, thread_name_prefix="TCP-Executor") - _resolver = netutil.ExecutorResolver(close_executor=False, - executor=_executor) + _resolver = netutil.ExecutorResolver(close_executor=False, executor=_executor) else: _resolver = None client = TCPClient(resolver=_resolver) @@ -335,9 +354,9 @@ def connect(self, address, deserialize=True, **connection_args): kwargs = self._get_connect_args(**connection_args) try: - stream = yield BaseTCPConnector.client.connect(ip, port, - max_buffer_size=MAX_BUFFER_SIZE, - **kwargs) + stream = yield BaseTCPConnector.client.connect( + ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs + ) # Under certain circumstances tornado will have a closed connnection with an error and not raise # a StreamClosedError. 
@@ -351,14 +370,13 @@ def connect(self, address, deserialize=True, **connection_args): convert_stream_closed_error(self, e) local_address = self.prefix + get_stream_address(stream) - raise gen.Return(self.comm_class(stream, - local_address, - self.prefix + address, - deserialize)) + raise gen.Return( + self.comm_class(stream, local_address, self.prefix + address, deserialize) + ) class TCPConnector(BaseTCPConnector): - prefix = 'tcp://' + prefix = "tcp://" comm_class = TCP encrypted = False @@ -367,19 +385,19 @@ def _get_connect_args(self, **connection_args): class TLSConnector(BaseTCPConnector): - prefix = 'tls://' + prefix = "tls://" comm_class = TLS encrypted = True def _get_connect_args(self, **connection_args): ctx = _expect_tls_context(connection_args) - return {'ssl_options': ctx} + return {"ssl_options": ctx} class BaseTCPListener(Listener, RequireEncryptionMixin): - - def __init__(self, address, comm_handler, deserialize=True, - default_port=0, **connection_args): + def __init__( + self, address, comm_handler, deserialize=True, default_port=0, **connection_args + ): self._check_encryption(address, connection_args) self.ip, self.port = parse_host_port(address, default_port) self.comm_handler = comm_handler @@ -389,18 +407,18 @@ def __init__(self, address, comm_handler, deserialize=True, self.bound_address = None def start(self): - self.tcp_server = TCPServer(max_buffer_size=MAX_BUFFER_SIZE, - **self.server_args) + self.tcp_server = TCPServer(max_buffer_size=MAX_BUFFER_SIZE, **self.server_args) self.tcp_server.handle_stream = self._handle_stream - backlog = int(dask.config.get('distributed.comm.socket-backlog')) + backlog = int(dask.config.get("distributed.comm.socket-backlog")) for i in range(5): try: # When shuffling data between workers, there can # really be O(cluster size) connection requests # on a single worker socket, make sure the backlog # is large enough not to lose any. 
- sockets = netutil.bind_sockets(self.port, address=self.ip, - backlog=backlog) + sockets = netutil.bind_sockets( + self.port, address=self.ip, backlog=backlog + ) except EnvironmentError as e: # EADDRINUSE can happen sporadically when trying to bind # to an ephemeral port @@ -429,8 +447,7 @@ def _handle_stream(self, stream, address): if stream is None: # Preparation failed return - logger.debug("Incoming connection from %r to %r", - address, self.contact_address) + logger.debug("Incoming connection from %r to %r", address, self.contact_address) local_address = self.prefix + get_stream_address(stream) comm = self.comm_class(stream, local_address, address, self.deserialize) yield self.comm_handler(comm) @@ -464,7 +481,7 @@ def contact_address(self): class TCPListener(BaseTCPListener): - prefix = 'tcp://' + prefix = "tcp://" comm_class = TCP encrypted = False @@ -477,13 +494,13 @@ def _prepare_stream(self, stream, address): class TLSListener(BaseTCPListener): - prefix = 'tls://' + prefix = "tls://" comm_class = TLS encrypted = True def _get_server_args(self, **connection_args): ctx = _expect_tls_context(connection_args) - return {'ssl_options': ctx} + return {"ssl_options": ctx} @gen.coroutine def _prepare_stream(self, stream, address): @@ -491,9 +508,12 @@ def _prepare_stream(self, stream, address): yield stream.wait_for_handshake() except EnvironmentError as e: # The handshake went wrong, log and ignore - logger.warning("Listener on %r: TLS handshake failed with remote %r: %s", - self.listen_address, address, - getattr(e, "real_error", None) or e) + logger.warning( + "Listener on %r: TLS handshake failed with remote %r: %s", + self.listen_address, + address, + getattr(e, "real_error", None) or e, + ) else: raise gen.Return(stream) @@ -523,7 +543,7 @@ def resolve_address(self, loc): def get_local_address_for(self, loc): host, port = parse_host_port(loc) host = ensure_ip(host) - if ':' in host: + if ":" in host: local_host = get_ipv6(host) else: local_host = get_ip(host) @@ -540,5 +560,5 @@ class TLSBackend(BaseTCPBackend): _listener_class = TLSListener -backends['tcp'] = TCPBackend() -backends['tls'] = TLSBackend() +backends["tcp"] = TCPBackend() +backends["tls"] = TLSBackend() diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index e4aee5805db..0e8782718a0 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -14,82 +14,96 @@ from distributed.compatibility import PY3 from distributed.metrics import time from distributed.utils import get_ip, get_ipv6 -from distributed.utils_test import (gen_test, requires_ipv6, has_ipv6, - get_cert, get_server_ssl_context, - get_client_ssl_context) +from distributed.utils_test import ( + gen_test, + requires_ipv6, + has_ipv6, + get_cert, + get_server_ssl_context, + get_client_ssl_context, +) from distributed.utils_test import loop # noqa: F401 -from distributed.protocol import (to_serialize, Serialized, serialize, - deserialize) - -from distributed.comm import (tcp, inproc, connect, listen, CommClosedError, - parse_address, parse_host_port, - unparse_host_port, resolve_address, - get_address_host, get_local_address_for) +from distributed.protocol import to_serialize, Serialized, serialize, deserialize + +from distributed.comm import ( + tcp, + inproc, + connect, + listen, + CommClosedError, + parse_address, + parse_host_port, + unparse_host_port, + resolve_address, + get_address_host, + get_local_address_for, +) EXTERNAL_IP4 = get_ip() if has_ipv6(): with 
warnings.catch_warnings(record=True): - warnings.simplefilter('always') + warnings.simplefilter("always") EXTERNAL_IP6 = get_ipv6() -ca_file = get_cert('tls-ca-cert.pem') +ca_file = get_cert("tls-ca-cert.pem") # The Subject field of our test certs cert_subject = ( - (('countryName', 'XY'),), - (('localityName', 'Dask-distributed'),), - (('organizationName', 'Dask'),), - (('commonName', 'localhost'),) + (("countryName", "XY"),), + (("localityName", "Dask-distributed"),), + (("organizationName", "Dask"),), + (("commonName", "localhost"),), ) def check_tls_extra(info): assert isinstance(info, dict) - assert info['peercert']['subject'] == cert_subject - assert 'cipher' in info - cipher_name, proto_name, secret_bits = info['cipher'] + assert info["peercert"]["subject"] == cert_subject + assert "cipher" in info + cipher_name, proto_name, secret_bits = info["cipher"] # Most likely - assert 'AES' in cipher_name - assert 'TLS' in proto_name + assert "AES" in cipher_name + assert "TLS" in proto_name assert secret_bits >= 128 -tls_kwargs = dict(listen_args={'ssl_context': get_server_ssl_context()}, - connect_args={'ssl_context': get_client_ssl_context()}) +tls_kwargs = dict( + listen_args={"ssl_context": get_server_ssl_context()}, + connect_args={"ssl_context": get_client_ssl_context()}, +) @gen.coroutine -def get_comm_pair(listen_addr, listen_args=None, connect_args=None, - **kwargs): +def get_comm_pair(listen_addr, listen_args=None, connect_args=None, **kwargs): q = queues.Queue() def handle_comm(comm): q.put(comm) - listener = listen(listen_addr, handle_comm, - connection_args=listen_args, **kwargs) + listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) listener.start() - comm = yield connect(listener.contact_address, - connection_args=connect_args, **kwargs) + comm = yield connect( + listener.contact_address, connection_args=connect_args, **kwargs + ) serv_comm = yield q.get() raise gen.Return((comm, serv_comm)) def get_tcp_comm_pair(**kwargs): - return get_comm_pair('tcp://', **kwargs) + return get_comm_pair("tcp://", **kwargs) def get_tls_comm_pair(**kwargs): kwargs.update(tls_kwargs) - return get_comm_pair('tls://', **kwargs) + return get_comm_pair("tls://", **kwargs) def get_inproc_comm_pair(**kwargs): - return get_comm_pair('inproc://', **kwargs) + return get_comm_pair("inproc://", **kwargs) @gen.coroutine @@ -99,7 +113,7 @@ def debug_loop(): """ while True: loop = ioloop.IOLoop.current() - print('.', loop, loop._handlers) + print(".", loop, loop._handlers) yield gen.sleep(0.50) @@ -107,86 +121,86 @@ def debug_loop(): # Test utility functions # + def test_parse_host_port(): f = parse_host_port - assert f('localhost:123') == ('localhost', 123) - assert f('127.0.0.1:456') == ('127.0.0.1', 456) - assert f('localhost:123', 80) == ('localhost', 123) - assert f('localhost', 80) == ('localhost', 80) + assert f("localhost:123") == ("localhost", 123) + assert f("127.0.0.1:456") == ("127.0.0.1", 456) + assert f("localhost:123", 80) == ("localhost", 123) + assert f("localhost", 80) == ("localhost", 80) with pytest.raises(ValueError): - f('localhost') + f("localhost") - assert f('[::1]:123') == ('::1', 123) - assert f('[fe80::1]:123', 80) == ('fe80::1', 123) - assert f('[::1]', 80) == ('::1', 80) + assert f("[::1]:123") == ("::1", 123) + assert f("[fe80::1]:123", 80) == ("fe80::1", 123) + assert f("[::1]", 80) == ("::1", 80) with pytest.raises(ValueError): - f('[::1]') + f("[::1]") with pytest.raises(ValueError): - f('::1:123') + f("::1:123") with pytest.raises(ValueError): 
- f('::1') + f("::1") def test_unparse_host_port(): f = unparse_host_port - assert f('localhost', 123) == 'localhost:123' - assert f('127.0.0.1', 123) == '127.0.0.1:123' - assert f('::1', 123) == '[::1]:123' - assert f('[::1]', 123) == '[::1]:123' + assert f("localhost", 123) == "localhost:123" + assert f("127.0.0.1", 123) == "127.0.0.1:123" + assert f("::1", 123) == "[::1]:123" + assert f("[::1]", 123) == "[::1]:123" - assert f('127.0.0.1') == '127.0.0.1' - assert f('127.0.0.1', 0) == '127.0.0.1' - assert f('127.0.0.1', None) == '127.0.0.1' - assert f('127.0.0.1', '*') == '127.0.0.1:*' + assert f("127.0.0.1") == "127.0.0.1" + assert f("127.0.0.1", 0) == "127.0.0.1" + assert f("127.0.0.1", None) == "127.0.0.1" + assert f("127.0.0.1", "*") == "127.0.0.1:*" - assert f('::1') == '[::1]' - assert f('[::1]') == '[::1]' - assert f('::1', '*') == '[::1]:*' + assert f("::1") == "[::1]" + assert f("[::1]") == "[::1]" + assert f("::1", "*") == "[::1]:*" def test_get_address_host(): f = get_address_host - assert f('tcp://127.0.0.1:123') == '127.0.0.1' - assert f('inproc://%s/%d/123' % (get_ip(), os.getpid())) == get_ip() + assert f("tcp://127.0.0.1:123") == "127.0.0.1" + assert f("inproc://%s/%d/123" % (get_ip(), os.getpid())) == get_ip() def test_resolve_address(): f = resolve_address - assert f('tcp://127.0.0.1:123') == 'tcp://127.0.0.1:123' - assert f('127.0.0.2:789') == 'tcp://127.0.0.2:789' - assert f('tcp://0.0.0.0:456') == 'tcp://0.0.0.0:456' - assert f('tcp://0.0.0.0:456') == 'tcp://0.0.0.0:456' + assert f("tcp://127.0.0.1:123") == "tcp://127.0.0.1:123" + assert f("127.0.0.2:789") == "tcp://127.0.0.2:789" + assert f("tcp://0.0.0.0:456") == "tcp://0.0.0.0:456" + assert f("tcp://0.0.0.0:456") == "tcp://0.0.0.0:456" if has_ipv6(): - assert f('tcp://[::1]:123') == 'tcp://[::1]:123' - assert f('tls://[::1]:123') == 'tls://[::1]:123' + assert f("tcp://[::1]:123") == "tcp://[::1]:123" + assert f("tls://[::1]:123") == "tls://[::1]:123" # OS X returns '::0.0.0.2' as canonical representation - assert f('[::2]:789') in ('tcp://[::2]:789', - 'tcp://[::0.0.0.2]:789') - assert f('tcp://[::]:123') == 'tcp://[::]:123' + assert f("[::2]:789") in ("tcp://[::2]:789", "tcp://[::0.0.0.2]:789") + assert f("tcp://[::]:123") == "tcp://[::]:123" - assert f('localhost:123') == 'tcp://127.0.0.1:123' - assert f('tcp://localhost:456') == 'tcp://127.0.0.1:456' - assert f('tls://localhost:456') == 'tls://127.0.0.1:456' + assert f("localhost:123") == "tcp://127.0.0.1:123" + assert f("tcp://localhost:456") == "tcp://127.0.0.1:456" + assert f("tls://localhost:456") == "tls://127.0.0.1:456" def test_get_local_address_for(): f = get_local_address_for - assert f('tcp://127.0.0.1:80') == 'tcp://127.0.0.1' - assert f('tcp://8.8.8.8:4444') == 'tcp://' + get_ip() + assert f("tcp://127.0.0.1:80") == "tcp://127.0.0.1" + assert f("tcp://8.8.8.8:4444") == "tcp://" + get_ip() if has_ipv6(): - assert f('tcp://[::1]:123') == 'tcp://[::1]' + assert f("tcp://[::1]:123") == "tcp://[::1]" - inproc_arg = 'inproc://%s/%d/444' % (get_ip(), os.getpid()) + inproc_arg = "inproc://%s/%d/444" % (get_ip(), os.getpid()) inproc_res = f(inproc_arg) - assert inproc_res.startswith('inproc://') + assert inproc_res.startswith("inproc://") assert inproc_res != inproc_arg @@ -194,24 +208,26 @@ def test_get_local_address_for(): # Test concrete transport APIs # + @gen_test() def test_tcp_specific(): """ Test concrete TCP API. 
""" + @gen.coroutine def handle_comm(comm): - assert comm.peer_address.startswith('tcp://' + host) + assert comm.peer_address.startswith("tcp://" + host) assert comm.extra_info == {} msg = yield comm.read() - msg['op'] = 'pong' + msg["op"] = "pong" yield comm.write(msg) yield comm.close() - listener = tcp.TCPListener('localhost', handle_comm) + listener = tcp.TCPListener("localhost", handle_comm) listener.start() host, port = listener.get_host_port() - assert host in ('localhost', '127.0.0.1', '::1') + assert host in ("localhost", "127.0.0.1", "::1") assert port > 0 connector = tcp.TCPConnector() @@ -219,15 +235,15 @@ def handle_comm(comm): @gen.coroutine def client_communicate(key, delay=0): - addr = '%s:%d' % (host, port) + addr = "%s:%d" % (host, port) comm = yield connector.connect(addr) - assert comm.peer_address == 'tcp://' + addr + assert comm.peer_address == "tcp://" + addr assert comm.extra_info == {} - yield comm.write({'op': 'ping', 'data': key}) + yield comm.write({"op": "ping", "data": key}) if delay: yield gen.sleep(delay) msg = yield comm.read() - assert msg == {'op': 'pong', 'data': key} + assert msg == {"op": "pong", "data": key} l.append(key) yield comm.close() @@ -245,23 +261,23 @@ def test_tls_specific(): """ Test concrete TLS API. """ + @gen.coroutine def handle_comm(comm): - assert comm.peer_address.startswith('tls://' + host) + assert comm.peer_address.startswith("tls://" + host) check_tls_extra(comm.extra_info) msg = yield comm.read() - msg['op'] = 'pong' + msg["op"] = "pong" yield comm.write(msg) yield comm.close() server_ctx = get_server_ssl_context() client_ctx = get_client_ssl_context() - listener = tcp.TLSListener('localhost', handle_comm, - ssl_context=server_ctx) + listener = tcp.TLSListener("localhost", handle_comm, ssl_context=server_ctx) listener.start() host, port = listener.get_host_port() - assert host in ('localhost', '127.0.0.1', '::1') + assert host in ("localhost", "127.0.0.1", "::1") assert port > 0 connector = tcp.TLSConnector() @@ -269,15 +285,15 @@ def handle_comm(comm): @gen.coroutine def client_communicate(key, delay=0): - addr = '%s:%d' % (host, port) + addr = "%s:%d" % (host, port) comm = yield connector.connect(addr, ssl_context=client_ctx) - assert comm.peer_address == 'tls://' + addr + assert comm.peer_address == "tls://" + addr check_tls_extra(comm.extra_info) - yield comm.write({'op': 'ping', 'data': key}) + yield comm.write({"op": "ping", "data": key}) if delay: yield gen.sleep(delay) msg = yield comm.read() - assert msg == {'op': 'pong', 'data': key} + assert msg == {"op": "pong", "data": key} l.append(key) yield comm.close() @@ -309,6 +325,7 @@ def sleep_for_60ms(): if thread_count > max_thread_count: max_thread_count = thread_count raise gen.Return(max_thread_count) + original_thread_count = threading.active_count() # tcp.TCPConnector() @@ -323,8 +340,11 @@ def sleep_for_60ms(): # tcp.TLSConnector() sleep_future = sleep_for_60ms() with pytest.raises(IOError): - yield connect("tls://localhost:28400", 0.052, - connection_args={'ssl_context': get_client_ssl_context()}) + yield connect( + "tls://localhost:28400", + 0.052, + connection_args={"ssl_context": get_client_ssl_context()}, + ) max_thread_count = yield sleep_future if PY3: assert max_thread_count <= 2 + original_thread_count @@ -336,7 +356,7 @@ def check_inproc_specific(run_client): Test concrete InProc API. 
""" listener_addr = inproc.global_manager.new_address() - addr_head = listener_addr.rpartition('/')[0] + addr_head = listener_addr.rpartition("/")[0] client_addresses = set() @@ -344,17 +364,21 @@ def check_inproc_specific(run_client): @gen.coroutine def handle_comm(comm): - assert comm.peer_address.startswith('inproc://' + addr_head) + assert comm.peer_address.startswith("inproc://" + addr_head) client_addresses.add(comm.peer_address) for i in range(N_MSGS): msg = yield comm.read() - msg['op'] = 'pong' + msg["op"] = "pong" yield comm.write(msg) yield comm.close() listener = inproc.InProcListener(listener_addr, handle_comm) listener.start() - assert listener.listen_address == listener.contact_address == 'inproc://' + listener_addr + assert ( + listener.listen_address + == listener.contact_address + == "inproc://" + listener_addr + ) connector = inproc.InProcConnector(inproc.global_manager) l = [] @@ -362,13 +386,13 @@ def handle_comm(comm): @gen.coroutine def client_communicate(key, delay=0): comm = yield connector.connect(listener_addr) - assert comm.peer_address == 'inproc://' + listener_addr + assert comm.peer_address == "inproc://" + listener_addr for i in range(N_MSGS): - yield comm.write({'op': 'ping', 'data': key}) + yield comm.write({"op": "ping", "data": key}) if delay: yield gen.sleep(delay) msg = yield comm.read() - assert msg == {'op': 'pong', 'data': key} + assert msg == {"op": "pong", "data": key} l.append(key) with pytest.raises(CommClosedError): yield comm.read() @@ -399,8 +423,7 @@ def run_coro_in_thread(func, *args, **kwargs): def run(): thread_loop = ioloop.IOLoop() # need fresh IO loop for run_sync() try: - res = thread_loop.run_sync(partial(func, *args, **kwargs), - timeout=10) + res = thread_loop.run_sync(partial(func, *args, **kwargs), timeout=10) except Exception: main_loop.add_callback(fut.set_exc_info, sys.exc_info()) else: @@ -427,30 +450,37 @@ def test_inproc_specific_different_threads(): # Test communications through the abstract API # + @gen.coroutine -def check_client_server(addr, check_listen_addr=None, check_contact_addr=None, - listen_args=None, connect_args=None): +def check_client_server( + addr, + check_listen_addr=None, + check_contact_addr=None, + listen_args=None, + connect_args=None, +): """ Abstract client / server test. 
""" + @gen.coroutine def handle_comm(comm): scheme, loc = parse_address(comm.peer_address) assert scheme == bound_scheme msg = yield comm.read() - assert msg['op'] == 'ping' - msg['op'] = 'pong' + assert msg["op"] == "ping" + msg["op"] = "pong" yield comm.write(msg) msg = yield comm.read() - assert msg['op'] == 'foobar' + assert msg["op"] == "foobar" yield comm.close() # Arbitrary connection args should be ignored - listen_args = listen_args or {'xxx': 'bar'} - connect_args = connect_args or {'xxx': 'foo'} + listen_args = listen_args or {"xxx": "bar"} + connect_args = connect_args or {"xxx": "foo"} listener = listen(addr, handle_comm, connection_args=listen_args) listener.start() @@ -458,7 +488,7 @@ def handle_comm(comm): # Check listener properties bound_addr = listener.listen_address bound_scheme, bound_loc = parse_address(bound_addr) - assert bound_scheme in ('inproc', 'tcp', 'tls') + assert bound_scheme in ("inproc", "tcp", "tls") assert bound_scheme == parse_address(addr)[0] if check_listen_addr is not None: @@ -478,16 +508,15 @@ def handle_comm(comm): @gen.coroutine def client_communicate(key, delay=0): - comm = yield connect(listener.contact_address, - connection_args=connect_args) + comm = yield connect(listener.contact_address, connection_args=connect_args) assert comm.peer_address == listener.contact_address - yield comm.write({'op': 'ping', 'data': key}) - yield comm.write({'op': 'foobar'}) + yield comm.write({"op": "ping", "data": key}) + yield comm.write({"op": "foobar"}) if delay: yield gen.sleep(delay) msg = yield comm.read() - assert msg == {'op': 'pong', 'data': key} + assert msg == {"op": "pong", "data": key} l.append(key) yield comm.close() @@ -521,7 +550,7 @@ def inproc_check(): expected_pid = os.getpid() def checker(loc): - ip, pid, suffix = loc.split('/') + ip, pid, suffix = loc.split("/") assert ip == expected_ip assert int(pid) == expected_pid @@ -531,70 +560,75 @@ def checker(loc): @gen_test() def test_default_client_server_ipv4(): # Default scheme is (currently) TCP - yield check_client_server('127.0.0.1', tcp_eq('127.0.0.1')) - yield check_client_server('127.0.0.1:3201', tcp_eq('127.0.0.1', 3201)) - yield check_client_server('0.0.0.0', - tcp_eq('0.0.0.0'), tcp_eq(EXTERNAL_IP4)) - yield check_client_server('0.0.0.0:3202', - tcp_eq('0.0.0.0', 3202), tcp_eq(EXTERNAL_IP4, 3202)) + yield check_client_server("127.0.0.1", tcp_eq("127.0.0.1")) + yield check_client_server("127.0.0.1:3201", tcp_eq("127.0.0.1", 3201)) + yield check_client_server("0.0.0.0", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + yield check_client_server( + "0.0.0.0:3202", tcp_eq("0.0.0.0", 3202), tcp_eq(EXTERNAL_IP4, 3202) + ) # IPv4 is preferred for the bound address - yield check_client_server('', - tcp_eq('0.0.0.0'), tcp_eq(EXTERNAL_IP4)) - yield check_client_server(':3203', - tcp_eq('0.0.0.0', 3203), tcp_eq(EXTERNAL_IP4, 3203)) + yield check_client_server("", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + yield check_client_server( + ":3203", tcp_eq("0.0.0.0", 3203), tcp_eq(EXTERNAL_IP4, 3203) + ) @requires_ipv6 @gen_test() def test_default_client_server_ipv6(): - yield check_client_server('[::1]', tcp_eq('::1')) - yield check_client_server('[::1]:3211', tcp_eq('::1', 3211)) - yield check_client_server('[::]', tcp_eq('::'), tcp_eq(EXTERNAL_IP6)) - yield check_client_server('[::]:3212', tcp_eq('::', 3212), tcp_eq(EXTERNAL_IP6, 3212)) + yield check_client_server("[::1]", tcp_eq("::1")) + yield check_client_server("[::1]:3211", tcp_eq("::1", 3211)) + yield check_client_server("[::]", tcp_eq("::"), 
tcp_eq(EXTERNAL_IP6)) + yield check_client_server( + "[::]:3212", tcp_eq("::", 3212), tcp_eq(EXTERNAL_IP6, 3212) + ) @gen_test() def test_tcp_client_server_ipv4(): - yield check_client_server('tcp://127.0.0.1', tcp_eq('127.0.0.1')) - yield check_client_server('tcp://127.0.0.1:3221', tcp_eq('127.0.0.1', 3221)) - yield check_client_server('tcp://0.0.0.0', - tcp_eq('0.0.0.0'), tcp_eq(EXTERNAL_IP4)) - yield check_client_server('tcp://0.0.0.0:3222', - tcp_eq('0.0.0.0', 3222), tcp_eq(EXTERNAL_IP4, 3222)) - yield check_client_server('tcp://', - tcp_eq('0.0.0.0'), tcp_eq(EXTERNAL_IP4)) - yield check_client_server('tcp://:3223', - tcp_eq('0.0.0.0', 3223), tcp_eq(EXTERNAL_IP4, 3223)) + yield check_client_server("tcp://127.0.0.1", tcp_eq("127.0.0.1")) + yield check_client_server("tcp://127.0.0.1:3221", tcp_eq("127.0.0.1", 3221)) + yield check_client_server("tcp://0.0.0.0", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + yield check_client_server( + "tcp://0.0.0.0:3222", tcp_eq("0.0.0.0", 3222), tcp_eq(EXTERNAL_IP4, 3222) + ) + yield check_client_server("tcp://", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + yield check_client_server( + "tcp://:3223", tcp_eq("0.0.0.0", 3223), tcp_eq(EXTERNAL_IP4, 3223) + ) @requires_ipv6 @gen_test() def test_tcp_client_server_ipv6(): - yield check_client_server('tcp://[::1]', tcp_eq('::1')) - yield check_client_server('tcp://[::1]:3231', tcp_eq('::1', 3231)) - yield check_client_server('tcp://[::]', - tcp_eq('::'), tcp_eq(EXTERNAL_IP6)) - yield check_client_server('tcp://[::]:3232', - tcp_eq('::', 3232), tcp_eq(EXTERNAL_IP6, 3232)) + yield check_client_server("tcp://[::1]", tcp_eq("::1")) + yield check_client_server("tcp://[::1]:3231", tcp_eq("::1", 3231)) + yield check_client_server("tcp://[::]", tcp_eq("::"), tcp_eq(EXTERNAL_IP6)) + yield check_client_server( + "tcp://[::]:3232", tcp_eq("::", 3232), tcp_eq(EXTERNAL_IP6, 3232) + ) @gen_test() def test_tls_client_server_ipv4(): - yield check_client_server('tls://127.0.0.1', tls_eq('127.0.0.1'), **tls_kwargs) - yield check_client_server('tls://127.0.0.1:3221', tls_eq('127.0.0.1', 3221), **tls_kwargs) - yield check_client_server('tls://', tls_eq('0.0.0.0'), - tls_eq(EXTERNAL_IP4), **tls_kwargs) + yield check_client_server("tls://127.0.0.1", tls_eq("127.0.0.1"), **tls_kwargs) + yield check_client_server( + "tls://127.0.0.1:3221", tls_eq("127.0.0.1", 3221), **tls_kwargs + ) + yield check_client_server( + "tls://", tls_eq("0.0.0.0"), tls_eq(EXTERNAL_IP4), **tls_kwargs + ) @requires_ipv6 @gen_test() def test_tls_client_server_ipv6(): - yield check_client_server('tls://[::1]', tls_eq('::1'), **tls_kwargs) + yield check_client_server("tls://[::1]", tls_eq("::1"), **tls_kwargs) @gen_test() def test_inproc_client_server(): - yield check_client_server('inproc://', inproc_check()) + yield check_client_server("inproc://", inproc_check()) yield check_client_server(inproc.new_address(), inproc_check()) @@ -602,56 +636,66 @@ def test_inproc_client_server(): # TLS certificate handling # + @gen_test() def test_tls_reject_certificate(): cli_ctx = get_client_ssl_context() serv_ctx = get_server_ssl_context() # These certs are not signed by our test CA - bad_cert_key = ('tls-self-signed-cert.pem', 'tls-self-signed-key.pem') + bad_cert_key = ("tls-self-signed-cert.pem", "tls-self-signed-key.pem") bad_cli_ctx = get_client_ssl_context(*bad_cert_key) bad_serv_ctx = get_server_ssl_context(*bad_cert_key) @gen.coroutine def handle_comm(comm): scheme, loc = parse_address(comm.peer_address) - assert scheme == 'tls' + assert scheme == "tls" yield 
comm.close() # Listener refuses a connector not signed by the CA - listener = listen('tls://', handle_comm, - connection_args={'ssl_context': serv_ctx}) + listener = listen("tls://", handle_comm, connection_args={"ssl_context": serv_ctx}) listener.start() with pytest.raises(EnvironmentError) as excinfo: - comm = yield connect(listener.contact_address, timeout=0.5, - connection_args={'ssl_context': bad_cli_ctx}) - yield comm.write({'x': 'foo'}) # TODO: why is this necessary in Tornado 6 ? + comm = yield connect( + listener.contact_address, + timeout=0.5, + connection_args={"ssl_context": bad_cli_ctx}, + ) + yield comm.write({"x": "foo"}) # TODO: why is this necessary in Tornado 6 ? # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028 - if sys.version_info >= (3,) and os.name != 'nt': + if sys.version_info >= (3,) and os.name != "nt": try: # See https://serverfault.com/questions/793260/what-does-tlsv1-alert-unknown-ca-mean assert "unknown ca" in str(excinfo.value) except AssertionError: - if os.name == 'nt': - assert "An existing connection was forcibly closed" in str(excinfo.value) + if os.name == "nt": + assert "An existing connection was forcibly closed" in str( + excinfo.value + ) else: raise # Sanity check - comm = yield connect(listener.contact_address, timeout=0.5, - connection_args={'ssl_context': cli_ctx}) + comm = yield connect( + listener.contact_address, timeout=0.5, connection_args={"ssl_context": cli_ctx} + ) yield comm.close() # Connector refuses a listener not signed by the CA - listener = listen('tls://', handle_comm, - connection_args={'ssl_context': bad_serv_ctx}) + listener = listen( + "tls://", handle_comm, connection_args={"ssl_context": bad_serv_ctx} + ) listener.start() with pytest.raises(EnvironmentError) as excinfo: - yield connect(listener.contact_address, timeout=0.5, - connection_args={'ssl_context': cli_ctx}) + yield connect( + listener.contact_address, + timeout=0.5, + connection_args={"ssl_context": cli_ctx}, + ) # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028 if sys.version_info >= (3,): assert "certificate verify failed" in str(excinfo.value) @@ -661,9 +705,9 @@ def handle_comm(comm): # Test communication closing # + @gen.coroutine -def check_comm_closed_implicit(addr, delay=None, listen_args=None, - connect_args=None): +def check_comm_closed_implicit(addr, delay=None, listen_args=None, connect_args=None): @gen.coroutine def handle_comm(comm): yield comm.close() @@ -683,12 +727,12 @@ def handle_comm(comm): @gen_test() def test_tcp_comm_closed_implicit(): - yield check_comm_closed_implicit('tcp://127.0.0.1') + yield check_comm_closed_implicit("tcp://127.0.0.1") @gen_test() def test_tls_comm_closed_implicit(): - yield check_comm_closed_implicit('tls://127.0.0.1', **tls_kwargs) + yield check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs) @gen_test() @@ -722,12 +766,12 @@ def check_comm_closed_explicit(addr, listen_args=None, connect_args=None): @gen_test() def test_tcp_comm_closed_explicit(): - yield check_comm_closed_explicit('tcp://127.0.0.1') + yield check_comm_closed_explicit("tcp://127.0.0.1") @gen_test() def test_tls_comm_closed_explicit(): - yield check_comm_closed_explicit('tls://127.0.0.1', **tls_kwargs) + yield check_comm_closed_explicit("tls://127.0.0.1", **tls_kwargs) @gen_test() @@ -750,7 +794,7 @@ def handle_comm(comm): else: comm.close() - listener = listen('inproc://', handle_comm) + listener = listen("inproc://", handle_comm) listener.start() 
contact_addr = listener.contact_address @@ -792,6 +836,7 @@ def handle_comm(comm): # Various stress tests # + @gen.coroutine def check_connect_timeout(addr): t1 = time() @@ -803,7 +848,7 @@ def check_connect_timeout(addr): @gen_test() def test_tcp_connect_timeout(): - yield check_connect_timeout('tcp://127.0.0.1:44444') + yield check_connect_timeout("tcp://127.0.0.1:44444") @gen_test() @@ -833,20 +878,21 @@ def handle_comm(comm): @gen_test() def test_tcp_many_listeners(): - check_many_listeners('tcp://127.0.0.1') - check_many_listeners('tcp://0.0.0.0') - check_many_listeners('tcp://') + check_many_listeners("tcp://127.0.0.1") + check_many_listeners("tcp://0.0.0.0") + check_many_listeners("tcp://") @gen_test() def test_inproc_many_listeners(): - check_many_listeners('inproc://') + check_many_listeners("inproc://") # # Test deserialization # + @gen.coroutine def check_listener_deserialize(addr, deserialize, in_value, check_out): q = queues.Queue() @@ -893,21 +939,22 @@ def check_deserialize(addr): """ # Test with Serialize and Serialized objects - msg = {'op': 'update', - 'x': b'abc', - 'to_ser': [to_serialize(123)], - 'ser': Serialized(*serialize(456)), - } + msg = { + "op": "update", + "x": b"abc", + "to_ser": [to_serialize(123)], + "ser": Serialized(*serialize(456)), + } msg_orig = msg.copy() def check_out_false(out_value): # Check output with deserialize=False out_value = out_value.copy() # in case transport passed the object as-is - to_ser = out_value.pop('to_ser') - ser = out_value.pop('ser') + to_ser = out_value.pop("to_ser") + ser = out_value.pop("ser") expected_msg = msg_orig.copy() - del expected_msg['ser'] - del expected_msg['to_ser'] + del expected_msg["ser"] + del expected_msg["to_ser"] assert out_value == expected_msg assert isinstance(ser, Serialized) @@ -925,8 +972,8 @@ def check_out_false(out_value): def check_out_true(out_value): # Check output with deserialize=True expected_msg = msg.copy() - expected_msg['ser'] = 456 - expected_msg['to_ser'] = [123] + expected_msg["ser"] = 456 + expected_msg["to_ser"] = [123] assert out_value == expected_msg yield check_listener_deserialize(addr, False, msg, check_out_false) @@ -940,22 +987,23 @@ def check_out_true(out_value): _uncompressible = os.urandom(1024 ** 2) * 4 # end size: 8 MB - msg = {'op': 'update', - 'x': _uncompressible, - 'to_ser': [to_serialize(_uncompressible)], - 'ser': Serialized(*serialize(_uncompressible)), - } + msg = { + "op": "update", + "x": _uncompressible, + "to_ser": [to_serialize(_uncompressible)], + "ser": Serialized(*serialize(_uncompressible)), + } msg_orig = msg.copy() def check_out(deserialize_flag, out_value): # Check output with deserialize=False assert sorted(out_value) == sorted(msg_orig) out_value = out_value.copy() # in case transport passed the object as-is - to_ser = out_value.pop('to_ser') - ser = out_value.pop('ser') + to_ser = out_value.pop("to_ser") + ser = out_value.pop("ser") expected_msg = msg_orig.copy() - del expected_msg['ser'] - del expected_msg['to_ser'] + del expected_msg["ser"] + del expected_msg["to_ser"] assert out_value == expected_msg if deserialize_flag: @@ -980,15 +1028,15 @@ def check_out(deserialize_flag, out_value): yield check_connector_deserialize(addr, True, msg, partial(check_out, True)) -@pytest.mark.xfail(reason='intermittent failure on windows') +@pytest.mark.xfail(reason="intermittent failure on windows") @gen_test() def test_tcp_deserialize(): - yield check_deserialize('tcp://') + yield check_deserialize("tcp://") @gen_test() def test_inproc_deserialize(): - yield 
check_deserialize('inproc://') + yield check_deserialize("inproc://") @gen.coroutine @@ -1000,11 +1048,12 @@ def check_deserialize_roundtrip(addr): # as a separate payload _uncompressible = os.urandom(1024 ** 2) * 4 # end size: 4 MB - msg = {'op': 'update', - 'x': _uncompressible, - 'to_ser': [to_serialize(_uncompressible)], - 'ser': Serialized(*serialize(_uncompressible)), - } + msg = { + "op": "update", + "x": _uncompressible, + "to_ser": [to_serialize(_uncompressible)], + "ser": Serialized(*serialize(_uncompressible)), + } for should_deserialize in (True, False): a, b = yield get_comm_pair(addr, deserialize=should_deserialize) @@ -1014,24 +1063,24 @@ def check_deserialize_roundtrip(addr): got = yield a.read() assert sorted(got) == sorted(msg) - for k in ('op', 'x'): + for k in ("op", "x"): assert got[k] == msg[k] if should_deserialize: - assert isinstance(got['to_ser'][0], (bytes, bytearray)) - assert isinstance(got['ser'], (bytes, bytearray)) + assert isinstance(got["to_ser"][0], (bytes, bytearray)) + assert isinstance(got["ser"], (bytes, bytearray)) else: - assert isinstance(got['to_ser'][0], (to_serialize, Serialized)) - assert isinstance(got['ser'], Serialized) + assert isinstance(got["to_ser"][0], (to_serialize, Serialized)) + assert isinstance(got["ser"], Serialized) @gen_test() def test_inproc_deserialize_roundtrip(): - yield check_deserialize_roundtrip('inproc://') + yield check_deserialize_roundtrip("inproc://") @gen_test() def test_tcp_deserialize_roundtrip(): - yield check_deserialize_roundtrip('tcp://') + yield check_deserialize_roundtrip("tcp://") def _raise_eoferror(): @@ -1048,9 +1097,10 @@ def check_deserialize_eoferror(addr): """ EOFError when deserializing should close the comm. """ + @gen.coroutine def handle_comm(comm): - yield comm.write({'data': to_serialize(_EOFRaising())}) + yield comm.write({"data": to_serialize(_EOFRaising())}) with pytest.raises(CommClosedError): yield comm.read() @@ -1062,21 +1112,22 @@ def handle_comm(comm): @gen_test() def test_tcp_deserialize_eoferror(): - yield check_deserialize_eoferror('tcp://') + yield check_deserialize_eoferror("tcp://") # # Test various properties # + @gen.coroutine def check_repr(a, b): - assert 'closed' not in repr(a) - assert 'closed' not in repr(b) + assert "closed" not in repr(a) + assert "closed" not in repr(b) yield a.close() - assert 'closed' in repr(a) + assert "closed" in repr(a) yield b.close() - assert 'closed' in repr(b) + assert "closed" in repr(b) @gen_test() diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 6c9a99b8a8d..bb6621e2021 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -17,10 +17,12 @@ # Offload (de)serializing large frames to improve event loop responsiveness. # We use at most 4 threads to allow for parallel processing of large messages. -FRAME_OFFLOAD_THRESHOLD = 10 * 1024 ** 2 # 10 MB +FRAME_OFFLOAD_THRESHOLD = 10 * 1024 ** 2 # 10 MB try: - _offload_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix='Dask-Offload') + _offload_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="Dask-Offload" + ) except TypeError: _offload_executor = ThreadPoolExecutor(max_workers=1) finalize(_offload_executor, _offload_executor.shutdown) @@ -31,16 +33,18 @@ def offload(fn, *args, **kwargs): @gen.coroutine -def to_frames(msg, serializers=None, on_error='message', context=None): +def to_frames(msg, serializers=None, on_error="message", context=None): """ Serialize a message into a list of Distributed protocol frames. 
""" + def _to_frames(): try: - return list(protocol.dumps(msg, - serializers=serializers, - on_error=on_error, - context=context)) + return list( + protocol.dumps( + msg, serializers=serializers, on_error=on_error, context=context + ) + ) except Exception as e: logger.info("Unserializable Message: %s", msg) logger.exception(e) @@ -63,17 +67,16 @@ def from_frames(frames, deserialize=True, deserializers=None): def _from_frames(): try: - return protocol.loads(frames, - deserialize=deserialize, - deserializers=deserializers) + return protocol.loads( + frames, deserialize=deserialize, deserializers=deserializers + ) except EOFError: if size > 1000: datastr = "[too large to display]" else: datastr = frames # Aid diagnosing - logger.error("truncated data stream (%d bytes): %s", size, - datastr) + logger.error("truncated data stream (%d bytes): %s", size, datastr) raise if deserialize and size > FRAME_OFFLOAD_THRESHOLD: @@ -114,9 +117,9 @@ def ensure_concrete_host(host): Ensure the given host string (or IP) denotes a concrete host, not a wildcard listening address. """ - if host in ('0.0.0.0', ''): + if host in ("0.0.0.0", ""): return get_ip() - elif host == '::': + elif host == "::": return get_ipv6() else: return host diff --git a/distributed/compatibility.py b/distributed/compatibility.py index fbf8f86df5d..f3a85973802 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -20,6 +20,7 @@ PY3 = False ConnectionRefusedError = OSError FileExistsError = OSError + class StopAsyncIteration(Exception): pass @@ -33,7 +34,7 @@ def gzip_decompress(b): def gzip_compress(b): bio = BytesIO() - f = gzip.GzipFile(fileobj=bio, mode='w') + f = gzip.GzipFile(fileobj=bio, mode="w") f.write(b) f.close() bio.seek(0) @@ -41,23 +42,25 @@ def gzip_compress(b): return result def isqueue(o): - return (hasattr(o, 'queue') and - hasattr(o, '__module__') and - o.__module__ == 'Queue') + return ( + hasattr(o, "queue") and hasattr(o, "__module__") and o.__module__ == "Queue" + ) def invalidate_caches(): pass def cache_from_source(path): import os + name, ext = os.path.splitext(path) - return name + '.pyc' + return name + ".pyc" logging_names = logging._levelNames def iscoroutinefunction(func): return False + if sys.version_info[0] == 3: from asyncio import iscoroutinefunction from collections.abc import Iterator, Mapping, Set, MutableMapping @@ -75,6 +78,7 @@ def iscoroutinefunction(func): unicode = str from gzip import decompress as gzip_decompress from gzip import compress as gzip_compress + ConnectionRefusedError = ConnectionRefusedError FileExistsError = FileExistsError StopAsyncIteration = StopAsyncIteration @@ -87,8 +91,9 @@ def isqueue(o): import platform -PYPY = platform.python_implementation().lower() == 'pypy' -WINDOWS = sys.platform.startswith('win') + +PYPY = platform.python_implementation().lower() == "pypy" +WINDOWS = sys.platform.startswith("win") try: @@ -141,6 +146,7 @@ def __init__(self, obj, func, *args, **kwargs): # We may register the exit function more than once because # of a thread race, but that is harmless import atexit + atexit.register(self._exitfunc) finalize._registered_with_atexit = True info = self._Info() @@ -197,10 +203,14 @@ def __repr__(self): info = self._registry.get(self) obj = info and info.weakref() if obj is None: - return '<%s object at %#x; dead>' % (type(self).__name__, id(self)) + return "<%s object at %#x; dead>" % (type(self).__name__, id(self)) else: - return '<%s object at %#x; for %r at %#x>' % \ - (type(self).__name__, id(self), type(obj).__name__, 
id(obj)) + return "<%s object at %#x; for %r at %#x>" % ( + type(self).__name__, + id(self), + type(obj).__name__, + id(obj), + ) @classmethod def _select_for_exit(cls): @@ -218,6 +228,7 @@ def _exitfunc(cls): try: if cls._registry: import gc + if gc.isenabled(): reenable_gc = True gc.disable() diff --git a/distributed/config.py b/distributed/config.py index d2b27397393..4b7b589d58f 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -13,7 +13,7 @@ config = dask.config.config -fn = os.path.join(os.path.dirname(__file__), 'distributed.yaml') +fn = os.path.join(os.path.dirname(__file__), "distributed.yaml") dask.config.ensure_file(source=fn) with open(fn) as f: @@ -22,39 +22,34 @@ dask.config.update_defaults(defaults) aliases = { - 'allowed-failures': 'distributed.scheduler.allowed-failures', - 'bandwidth': 'distributed.scheduler.bandwidth', - 'default-data-size': 'distributed.scheduler.default-data-size', - 'transition-log-length': 'distributed.scheduler.transition-log-length', - 'work-stealing': 'distributed.scheduler.work-stealing', - 'worker-ttl': 'distributed.scheduler.worker-ttl', - - 'multiprocessing-method': 'distributed.worker.multiprocessing-method', - 'use-file-locking': 'distributed.worker.use-file-locking', - 'profile-interval': 'distributed.worker.profile.interval', - 'profile-cycle-interval': 'distributed.worker.profile.cycle', - 'worker-memory-target': 'distributed.worker.memory.target', - 'worker-memory-spill': 'distributed.worker.memory.spill', - 'worker-memory-pause': 'distributed.worker.memory.pause', - 'worker-memory-terminate': 'distributed.worker.memory.terminate', - - 'heartbeat-interval': 'distributed.client.heartbeat', - - 'compression': 'distributed.comm.compression', - 'connect-timeout': 'distributed.comm.timeouts.connect', - 'tcp-timeout': 'distributed.comm.timeouts.tcp', - 'default-scheme': 'distributed.comm.default-scheme', - 'socket-backlog': 'distributed.comm.socket-backlog', - 'recent-messages-log-length': 'distributed.comm.recent-messages-log-length', - - 'diagnostics-link': 'distributed.dashboard.link', - 'bokeh-export-tool': 'distributed.dashboard.export-tool', - - 'tick-time': 'distributed.admin.tick.interval', - 'tick-maximum-delay': 'distributed.admin.tick.limit', - 'log-length': 'distributed.admin.log-length', - 'log-format': 'distributed.admin.log-format', - 'pdb-on-err': 'distributed.admin.pdb-on-err', + "allowed-failures": "distributed.scheduler.allowed-failures", + "bandwidth": "distributed.scheduler.bandwidth", + "default-data-size": "distributed.scheduler.default-data-size", + "transition-log-length": "distributed.scheduler.transition-log-length", + "work-stealing": "distributed.scheduler.work-stealing", + "worker-ttl": "distributed.scheduler.worker-ttl", + "multiprocessing-method": "distributed.worker.multiprocessing-method", + "use-file-locking": "distributed.worker.use-file-locking", + "profile-interval": "distributed.worker.profile.interval", + "profile-cycle-interval": "distributed.worker.profile.cycle", + "worker-memory-target": "distributed.worker.memory.target", + "worker-memory-spill": "distributed.worker.memory.spill", + "worker-memory-pause": "distributed.worker.memory.pause", + "worker-memory-terminate": "distributed.worker.memory.terminate", + "heartbeat-interval": "distributed.client.heartbeat", + "compression": "distributed.comm.compression", + "connect-timeout": "distributed.comm.timeouts.connect", + "tcp-timeout": "distributed.comm.timeouts.tcp", + "default-scheme": "distributed.comm.default-scheme", + 
"socket-backlog": "distributed.comm.socket-backlog", + "recent-messages-log-length": "distributed.comm.recent-messages-log-length", + "diagnostics-link": "distributed.dashboard.link", + "bokeh-export-tool": "distributed.dashboard.export-tool", + "tick-time": "distributed.admin.tick.interval", + "tick-maximum-delay": "distributed.admin.tick.limit", + "log-length": "distributed.admin.log-length", + "log-format": "distributed.admin.log-format", + "pdb-on-err": "distributed.admin.pdb-on-err", } dask.config.rename(aliases) @@ -81,17 +76,20 @@ def _initialize_logging_old_style(config): } """ loggers = { # default values - 'distributed': 'info', - 'distributed.client': 'warning', - 'bokeh': 'critical', - 'tornado': 'critical', - 'tornado.application': 'error', + "distributed": "info", + "distributed.client": "warning", + "bokeh": "critical", + "tornado": "critical", + "tornado.application": "error", } - loggers.update(config.get('logging', {})) + loggers.update(config.get("logging", {})) handler = logging.StreamHandler(sys.stderr) - handler.setFormatter(logging.Formatter(dask.config.get('distributed.admin.log-format', - config=config))) + handler.setFormatter( + logging.Formatter( + dask.config.get("distributed.admin.log-format", config=config) + ) + ) for name, level in loggers.items(): if isinstance(level, str): level = logging_names[level.upper()] @@ -107,7 +105,7 @@ def _initialize_logging_new_style(config): Initialize logging using logging's "Configuration dictionary schema". (ref.: https://docs.python.org/2/library/logging.config.html#logging-config-dictschema) """ - logging.config.dictConfig(config.get('logging')) + logging.config.dictConfig(config.get("logging")) def _initialize_logging_file_config(config): @@ -115,19 +113,23 @@ def _initialize_logging_file_config(config): Initialize logging using logging's "Configuration file format". (ref.: https://docs.python.org/2/library/logging.config.html#configuration-file-format) """ - logging.config.fileConfig(config.get('logging-file-config'), disable_existing_loggers=False) + logging.config.fileConfig( + config.get("logging-file-config"), disable_existing_loggers=False + ) def initialize_logging(config): - if 'logging-file-config' in config: - if 'logging' in config: - raise RuntimeError("Config options 'logging-file-config' and 'logging' are mutually exclusive.") + if "logging-file-config" in config: + if "logging" in config: + raise RuntimeError( + "Config options 'logging-file-config' and 'logging' are mutually exclusive." + ) _initialize_logging_file_config(config) else: - log_config = config.get('logging', {}) - if 'version' in log_config: + log_config = config.get("logging", {}) + if "version" in log_config: # logging module mandates version to be an int - log_config['version'] = int(log_config['version']) + log_config["version"] = int(log_config["version"]) _initialize_logging_new_style(config) else: _initialize_logging_old_style(config) diff --git a/distributed/core.py b/distributed/core.py index a9883801549..e074fa68148 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -17,14 +17,26 @@ from tornado.locks import Event from .compatibility import get_thread_identity -from .comm import (connect, listen, CommClosedError, - normalize_address, - unparse_host_port, get_address_host_port) +from .comm import ( + connect, + listen, + CommClosedError, + normalize_address, + unparse_host_port, + get_address_host_port, +) from .metrics import time from . 
import profile from .system_monitor import SystemMonitor -from .utils import (get_traceback, truncate_exception, ignoring, shutting_down, - PeriodicCallback, parse_timedelta, has_keyword) +from .utils import ( + get_traceback, + truncate_exception, + ignoring, + shutting_down, + PeriodicCallback, + parse_timedelta, + has_keyword, +) from . import protocol @@ -38,6 +50,7 @@ class RPCClosed(IOError): def get_total_physical_memory(): try: import psutil + return psutil.virtual_memory().total / 2 except ImportError: return 2e9 @@ -46,14 +59,17 @@ def get_total_physical_memory(): def raise_later(exc): def _raise(*args, **kwargs): raise exc + return _raise MAX_BUFFER_SIZE = get_total_physical_memory() -tick_maximum_delay = parse_timedelta(dask.config.get('distributed.admin.tick.limit'), default='ms') +tick_maximum_delay = parse_timedelta( + dask.config.get("distributed.admin.tick.limit"), default="ms" +) -LOG_PDB = dask.config.get('distributed.admin.pdb-on-err') +LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") class Server(object): @@ -92,23 +108,33 @@ class Server(object): * ``{'op': 'add', 'x': 10, 'y': 20}`` """ - default_ip = '' + + default_ip = "" default_port = 0 - def __init__(self, handlers, blocked_handlers=None, stream_handlers=None, connection_limit=512, - deserialize=True, io_loop=None): + def __init__( + self, + handlers, + blocked_handlers=None, + stream_handlers=None, + connection_limit=512, + deserialize=True, + io_loop=None, + ): self.handlers = { - 'identity': self.identity, - 'connection_stream': self.handle_stream, + "identity": self.identity, + "connection_stream": self.handle_stream, } self.handlers.update(handlers) if blocked_handlers is None: - blocked_handlers = dask.config.get('distributed.%s.blocked-handlers' % type(self).__name__.lower(), []) + blocked_handlers = dask.config.get( + "distributed.%s.blocked-handlers" % type(self).__name__.lower(), [] + ) self.blocked_handlers = blocked_handlers self.stream_handlers = {} self.stream_handlers.update(stream_handlers or {}) - self.id = type(self).__name__ + '-' + str(uuid.uuid4()) + self.id = type(self).__name__ + "-" + str(uuid.uuid4()) self._address = None self._listen_address = None self._port = None @@ -125,31 +151,36 @@ def __init__(self, handlers, blocked_handlers=None, stream_handlers=None, connec self.io_loop = io_loop or IOLoop.current() self.loop = self.io_loop - if not hasattr(self.io_loop, 'profile'): + if not hasattr(self.io_loop, "profile"): ref = weakref.ref(self.io_loop) - if hasattr(self.io_loop, 'closing'): + if hasattr(self.io_loop, "closing"): + def stop(): loop = ref() return loop is None or loop.closing + else: + def stop(): loop = ref() return loop is None or loop._closing self.io_loop.profile = profile.watch( - omit=('profile.py', 'selectors.py'), - interval=dask.config.get('distributed.worker.profile.interval'), - cycle=dask.config.get('distributed.worker.profile.cycle'), - stop=stop, + omit=("profile.py", "selectors.py"), + interval=dask.config.get("distributed.worker.profile.interval"), + cycle=dask.config.get("distributed.worker.profile.cycle"), + stop=stop, ) # Statistics counters for various events with ignoring(ImportError): from .counter import Digest + self.digests = defaultdict(partial(Digest, loop=self.io_loop)) from .counter import Counter + self.counters = defaultdict(partial(Counter, loop=self.io_loop)) self.events = defaultdict(lambda: deque(maxlen=10000)) self.event_counts = defaultdict(lambda: 0) @@ -157,15 +188,18 @@ def stop(): self.periodic_callbacks = dict() pc = 
PeriodicCallback(self.monitor.update, 500, io_loop=self.io_loop) - self.periodic_callbacks['monitor'] = pc + self.periodic_callbacks["monitor"] = pc self._last_tick = time() pc = PeriodicCallback( - self._measure_tick, - parse_timedelta(dask.config.get('distributed.admin.tick.interval'), default='ms') * 1000, - io_loop=self.io_loop + self._measure_tick, + parse_timedelta( + dask.config.get("distributed.admin.tick.interval"), default="ms" + ) + * 1000, + io_loop=self.io_loop, ) - self.periodic_callbacks['tick'] = pc + self.periodic_callbacks["tick"] = pc self.thread_id = 0 @@ -189,6 +223,7 @@ def start_pcs(): for pc in self.periodic_callbacks.values(): if not pc.is_running(): pc.start() + self.io_loop.add_callback(start_pcs) def stop(self): @@ -209,16 +244,19 @@ def _measure_tick(self): diff = now - self._last_tick self._last_tick = now if diff > tick_maximum_delay: - logger.info("Event loop was unresponsive in %s for %.2fs. " - "This is often caused by long-running GIL-holding " - "functions or moving large chunks of data. " - "This can cause timeouts and instability.", - type(self).__name__, diff) + logger.info( + "Event loop was unresponsive in %s for %.2fs. " + "This is often caused by long-running GIL-holding " + "functions or moving large chunks of data. " + "This can cause timeouts and instability.", + type(self).__name__, + diff, + ) if self.digests is not None: - self.digests['tick-duration'].add(diff) + self.digests["tick-duration"].add(diff) def log_event(self, name, msg): - msg['time'] = time() + msg["time"] = time() if isinstance(name, list): for n in name: self.events[n].append(msg) @@ -263,7 +301,7 @@ def port(self): return self._port def identity(self, comm=None): - return {'type': type(self).__name__, 'id': self.id} + return {"type": type(self).__name__, "id": self.id} def listen(self, port_or_addr=None, listen_args=None): if port_or_addr is None: @@ -275,9 +313,12 @@ def listen(self, port_or_addr=None, listen_args=None): else: addr = port_or_addr assert isinstance(addr, string_types) - self.listener = listen(addr, self.handle_comm, - deserialize=self.deserialize, - connection_args=listen_args) + self.listener = listen( + addr, + self.handle_comm, + deserialize=self.deserialize, + connection_args=listen_args, + ) self.listener.start() @gen.coroutine @@ -307,51 +348,61 @@ def handle_comm(self, comm, shutting_down=shutting_down): logger.debug("Message from %r: %s", address, msg) except EnvironmentError as e: if not shutting_down(): - logger.debug("Lost connection to %r while reading message: %s." - " Last operation: %s", - address, e, op) + logger.debug( + "Lost connection to %r while reading message: %s." + " Last operation: %s", + address, + e, + op, + ) break except Exception as e: logger.exception(e) - yield comm.write(error_message(e, status='uncaught-error')) + yield comm.write(error_message(e, status="uncaught-error")) continue if not isinstance(msg, dict): - raise TypeError("Bad message type. Expected dict, got\n " - + str(msg)) + raise TypeError( + "Bad message type. 
Expected dict, got\n " + str(msg) + ) try: - op = msg.pop('op') + op = msg.pop("op") except KeyError: raise ValueError( - "Received unexpected message without 'op' key: " % - str(msg) + "Received unexpected message without 'op' key: " % str(msg) ) if self.counters is not None: - self.counters['op'].add(op) + self.counters["op"].add(op) self._comms[comm] = op - serializers = msg.pop('serializers', None) - close_desired = msg.pop('close', False) - reply = msg.pop('reply', True) - if op == 'close': + serializers = msg.pop("serializers", None) + close_desired = msg.pop("close", False) + reply = msg.pop("reply", True) + if op == "close": if reply: - yield comm.write('OK') + yield comm.write("OK") break result = None try: if op in self.blocked_handlers: - _msg = ("The '{op}' handler has been explicitly disallowed " - "in {obj}, possibly due to security concerns.") + _msg = ( + "The '{op}' handler has been explicitly disallowed " + "in {obj}, possibly due to security concerns." + ) exc = ValueError(_msg.format(op=op, obj=type(self).__name__)) handler = raise_later(exc) else: handler = self.handlers[op] except KeyError: - logger.warning("No handler %s found in %s", op, - type(self).__name__, exc_info=True) + logger.warning( + "No handler %s found in %s", + op, + type(self).__name__, + exc_info=True, + ) else: - if serializers is not None and has_keyword(handler, 'serializers'): - msg['serializers'] = serializers # add back in + if serializers is not None and has_keyword(handler, "serializers"): + msg["serializers"] = serializers # add back in logger.debug("Calling into handler %s", handler.__name__) try: @@ -360,19 +411,23 @@ def handle_comm(self, comm, shutting_down=shutting_down): self._ongoing_coroutines.add(result) result = yield result except (CommClosedError, CancelledError) as e: - if self.status == 'running': + if self.status == "running": logger.info("Lost connection to %r: %s", address, e) break except Exception as e: logger.exception(e) - result = error_message(e, status='uncaught-error') + result = error_message(e, status="uncaught-error") - if reply and result != 'dont-reply': + if reply and result != "dont-reply": try: yield comm.write(result, serializers=serializers) except (EnvironmentError, TypeError) as e: - logger.debug("Lost connection to %r while sending result for op %r: %s", - address, op, e) + logger.debug( + "Lost connection to %r while sending result for op %r: %s", + address, + op, + e, + ) break msg = result = None if close_desired: @@ -386,8 +441,9 @@ def handle_comm(self, comm, shutting_down=shutting_down): try: comm.abort() except Exception as e: - logger.error("Failed while closing connection to %r: %s", - address, e) + logger.error( + "Failed while closing connection to %r: %s", address, e + ) @gen.coroutine def handle_stream(self, comm, extra=None, every_cycle=[]): @@ -404,11 +460,11 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): if not comm.closed(): for msg in msgs: - if msg == 'OK': # from close + if msg == "OK": # from close break - op = msg.pop('op') + op = msg.pop("op") if op: - if op == 'close-stream': + if op == "close-stream": closed = True break handler = self.stream_handlers[op] @@ -424,6 +480,7 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise finally: @@ -445,12 +502,11 @@ def close(self): def pingpong(comm): - return b'pong' + return b"pong" @gen.coroutine -def send_recv(comm, reply=True, serializers=None, deserializers=None, - **kwargs): +def 
send_recv(comm, reply=True, serializers=None, deserializers=None, **kwargs): """ Send and recv with a Comm. Keyword arguments turn into the message @@ -458,16 +514,16 @@ def send_recv(comm, reply=True, serializers=None, deserializers=None, response = yield send_recv(comm, op='ping', reply=True) """ msg = kwargs - msg['reply'] = reply - please_close = kwargs.get('close') + msg["reply"] = reply + please_close = kwargs.get("close") force_close = False if deserializers is None: deserializers = serializers if deserializers is not None: - msg['serializers'] = deserializers + msg["serializers"] = deserializers try: - yield comm.write(msg, serializers=serializers, on_error='raise') + yield comm.write(msg, serializers=serializers, on_error="raise") if reply: response = yield comm.read(deserializers=deserializers) else: @@ -482,11 +538,11 @@ def send_recv(comm, reply=True, serializers=None, deserializers=None, elif force_close: comm.abort() - if isinstance(response, dict) and response.get('status') == 'uncaught-error': + if isinstance(response, dict) and response.get("status") == "uncaught-error": if comm.deserialize: six.reraise(*clean_exception(**response)) else: - raise Exception(response['text']) + raise Exception(response["text"]) raise gen.Return(response) @@ -514,16 +570,25 @@ class rpc(object): >>> remote.close_comms() # doctest: +SKIP """ + active = weakref.WeakSet() comms = () address = None - def __init__(self, arg=None, comm=None, deserialize=True, timeout=None, - connection_args=None, serializers=None, deserializers=None): + def __init__( + self, + arg=None, + comm=None, + deserialize=True, + timeout=None, + connection_args=None, + serializers=None, + deserializers=None, + ): self.comms = {} self.address = coerce_to_address(arg) self.timeout = timeout - self.status = 'running' + self.status = "running" self.deserialize = deserialize self.serializers = serializers self.deserializers = deserializers if deserializers is not None else serializers @@ -549,7 +614,7 @@ def live_comm(self): As is done in __getattr__ below. 
""" - if self.status == 'closed': + if self.status == "closed": raise RPCClosed("RPC Closed") to_clear = set() open = False @@ -561,19 +626,21 @@ def live_comm(self): for s in to_clear: del self.comms[s] if not open or comm.closed(): - comm = yield connect(self.address, self.timeout, - deserialize=self.deserialize, - connection_args=self.connection_args) - self.comms[comm] = False # mark as taken + comm = yield connect( + self.address, + self.timeout, + deserialize=self.deserialize, + connection_args=self.connection_args, + ) + self.comms[comm] = False # mark as taken raise gen.Return(comm) def close_comms(self): - @gen.coroutine def _close_comm(comm): # Make sure we tell the peer to close try: - yield comm.write({'op': 'close', 'reply': False}) + yield comm.write({"op": "close", "reply": False}) yield comm.close() except EnvironmentError: comm.abort() @@ -586,25 +653,27 @@ def _close_comm(comm): def __getattr__(self, key): @gen.coroutine def send_recv_from_rpc(**kwargs): - if self.serializers is not None and kwargs.get('serializers') is None: - kwargs['serializers'] = self.serializers - if self.deserializers is not None and kwargs.get('deserializers') is None: - kwargs['deserializers'] = self.deserializers + if self.serializers is not None and kwargs.get("serializers") is None: + kwargs["serializers"] = self.serializers + if self.deserializers is not None and kwargs.get("deserializers") is None: + kwargs["deserializers"] = self.deserializers try: comm = yield self.live_comm() result = yield send_recv(comm=comm, op=key, **kwargs) except (RPCClosed, CommClosedError) as e: - raise e.__class__("%s: while trying to call remote method %r" - % (e, key,)) + raise e.__class__( + "%s: while trying to call remote method %r" % (e, key) + ) self.comms[comm] = True # mark as open raise gen.Return(result) + return send_recv_from_rpc def close_rpc(self): - if self.status != 'closed': + if self.status != "closed": rpc.active.discard(self) - self.status = 'closed' + self.status = "closed" self.close_comms() def __enter__(self): @@ -614,13 +683,14 @@ def __exit__(self, *args): self.close_rpc() def __del__(self): - if self.status != 'closed': + if self.status != "closed": rpc.active.discard(self) - self.status = 'closed' + self.status = "closed" still_open = [comm for comm in self.comms if not comm.closed()] if still_open: - logger.warning("rpc object %s deleted with %d open comms", - self, len(still_open)) + logger.warning( + "rpc object %s deleted with %d open comms", self, len(still_open) + ) for comm in still_open: comm.abort() @@ -648,10 +718,10 @@ def address(self): def __getattr__(self, key): @gen.coroutine def send_recv_from_rpc(**kwargs): - if self.serializers is not None and kwargs.get('serializers') is None: - kwargs['serializers'] = self.serializers - if self.deserializers is not None and kwargs.get('deserializers') is None: - kwargs['deserializers'] = self.deserializers + if self.serializers is not None and kwargs.get("serializers") is None: + kwargs["serializers"] = self.serializers + if self.deserializers is not None and kwargs.get("deserializers") is None: + kwargs["deserializers"] = self.deserializers comm = yield self.pool.connect(self.addr) try: result = yield send_recv(comm=comm, op=key, **kwargs) @@ -659,6 +729,7 @@ def send_recv_from_rpc(**kwargs): self.pool.reuse(self.addr, comm) raise gen.Return(result) + return send_recv_from_rpc def close_rpc(self): @@ -709,12 +780,15 @@ class ConnectionPool(object): Whether or not to deserialize data by default or pass it through """ - def 
__init__(self, limit=512, - deserialize=True, - serializers=None, - deserializers=None, - connection_args=None): - self.limit = limit # Max number of open comms + def __init__( + self, + limit=512, + deserialize=True, + serializers=None, + deserializers=None, + connection_args=None, + ): + self.limit = limit # Max number of open comms # Invariant: len(available) == open - active self.available = defaultdict(set) # Invariant: len(occupied) == active @@ -734,15 +808,14 @@ def open(self): return self.active + sum(map(len, self.available.values())) def __repr__(self): - return "" % (self.open, - self.active) + return "" % (self.open, self.active) def __call__(self, addr=None, ip=None, port=None): """ Cached rpc objects """ addr = addr_from_args(addr=addr, ip=ip, port=port) - return PooledRPCCall(addr, self, - serializers=self.serializers, - deserializers=self.deserializers) + return PooledRPCCall( + addr, self, serializers=self.serializers, deserializers=self.deserializers + ) @gen.coroutine def connect(self, addr, timeout=None): @@ -763,9 +836,12 @@ def connect(self, addr, timeout=None): yield self.event.wait() try: - comm = yield connect(addr, timeout=timeout, - deserialize=self.deserialize, - connection_args=self.connection_args) + comm = yield connect( + addr, + timeout=timeout, + deserialize=self.deserialize, + connection_args=self.connection_args, + ) except Exception: raise occupied.add(comm) @@ -794,8 +870,9 @@ def collect(self): """ Collect open but unused communications, to allow opening other ones. """ - logger.info("Collecting unused comms. open: %d, active: %d", - self.open, self.active) + logger.info( + "Collecting unused comms. open: %d, active: %d", self.open, self.active + ) for addr, comms in self.available.items(): for comm in comms: comm.close() @@ -838,7 +915,7 @@ def coerce_to_address(o): return normalize_address(o) -def error_message(e, status='error'): +def error_message(e, status="error"): """ Produce message to send back given an exception has occurred This does the following: @@ -865,17 +942,14 @@ def error_message(e, status='error'): try: tb2 = protocol.pickle.dumps(tb) except Exception: - tb = tb2 = ''.join(traceback.format_tb(tb)) + tb = tb2 = "".join(traceback.format_tb(tb)) if len(tb2) > 10000: tb_result = None else: tb_result = protocol.to_serialize(tb) - return {'status': status, - 'exception': e4, - 'traceback': tb_result, - 'text': str(e2)} + return {"status": status, "exception": e4, "traceback": tb_result, "text": str(e2)} def clean_exception(exception, traceback, **kwargs): diff --git a/distributed/counter.py b/distributed/counter.py index 8d76def7189..d5a3181b112 100644 --- a/distributed/counter.py +++ b/distributed/counter.py @@ -12,6 +12,7 @@ except ImportError: pass else: + class Digest(object): def __init__(self, loop=None, intervals=(5, 60, 3600)): self.intervals = intervals diff --git a/distributed/deploy/__init__.py b/distributed/deploy/__init__.py index c283cde8394..35abf0a6439 100644 --- a/distributed/deploy/__init__.py +++ b/distributed/deploy/__init__.py @@ -5,5 +5,6 @@ from .cluster import Cluster from .local import LocalCluster from .adaptive import Adaptive + with ignoring(ImportError): from .ssh import SSHCluster diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 62d308c6e22..890e30c027f 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -86,18 +86,30 @@ class Adaptive(object): the cluster's ``scale_up`` method. 
''' - def __init__(self, scheduler, cluster=None, interval='1s', startup_cost='1s', - scale_factor=2, minimum=0, maximum=None, wait_count=3, - target_duration='5s', worker_key=lambda x: x, **kwargs): - interval = parse_timedelta(interval, default='ms') + def __init__( + self, + scheduler, + cluster=None, + interval="1s", + startup_cost="1s", + scale_factor=2, + minimum=0, + maximum=None, + wait_count=3, + target_duration="5s", + worker_key=lambda x: x, + **kwargs + ): + interval = parse_timedelta(interval, default="ms") self.worker_key = worker_key self.scheduler = scheduler self.cluster = cluster - self.startup_cost = parse_timedelta(startup_cost, default='s') + self.startup_cost = parse_timedelta(startup_cost, default="s") self.scale_factor = scale_factor if self.cluster: - self._adapt_callback = PeriodicCallback(self._adapt, interval * 1000, - io_loop=scheduler.loop) + self._adapt_callback = PeriodicCallback( + self._adapt, interval * 1000, io_loop=scheduler.loop + ) self.scheduler.loop.add_callback(self._adapt_callback.start) self._adapting = False self._workers_to_close_kwargs = kwargs @@ -108,7 +120,7 @@ def __init__(self, scheduler, cluster=None, interval='1s', startup_cost='1s', self.wait_count = wait_count self.target_duration = parse_timedelta(target_duration) - self.scheduler.handlers['adaptive_recommendations'] = self.recommendations + self.scheduler.handlers["adaptive_recommendations"] = self.recommendations def stop(self): if self.cluster: @@ -129,8 +141,11 @@ def needs_cpu(self): total_cores = sum([ws.ncores for ws in self.scheduler.workers.values()]) if total_occupancy / (total_cores + 1e-9) > self.startup_cost * 2: - logger.info("CPU limit exceeded [%d occupancy / %d cores]", - total_occupancy, total_cores) + logger.info( + "CPU limit exceeded [%d occupancy / %d cores]", + total_occupancy, + total_cores, + ) return True else: return False @@ -144,8 +159,9 @@ def needs_memory(self): Returns ``True`` if the required bytes in distributed memory is some factor larger than the actual distributed memory available. 
""" - limit_bytes = {addr: ws.memory_limit - for addr, ws in self.scheduler.workers.items()} + limit_bytes = { + addr: ws.memory_limit for addr, ws in self.scheduler.workers.items() + } worker_bytes = [ws.nbytes for ws in self.scheduler.workers.values()] limit = sum(limit_bytes.values()) @@ -221,25 +237,24 @@ def workers_to_close(self, **kwargs): kw.update(kwargs) if self.maximum is not None and len(self.scheduler.workers) > self.maximum: - kw['n'] = len(self.scheduler.workers) - self.maximum + kw["n"] = len(self.scheduler.workers) - self.maximum L = self.scheduler.workers_to_close(**kw) if len(self.scheduler.workers) - len(L) < self.minimum: - L = L[:len(self.scheduler.workers) - self.minimum] + L = L[: len(self.scheduler.workers) - self.minimum] return L @gen.coroutine def _retire_workers(self, workers=None): if workers is None: - workers = self.workers_to_close(key=self.worker_key, - minimum=self.minimum) + workers = self.workers_to_close(key=self.worker_key, minimum=self.minimum) if not workers: raise gen.Return(workers) with log_errors(): - yield self.scheduler.retire_workers(workers=workers, - remove=True, - close_workers=True) + yield self.scheduler.retire_workers( + workers=workers, remove=True, close_workers=True + ) logger.info("Retiring workers %s", workers) f = self.cluster.scale_down(workers) @@ -263,33 +278,32 @@ def get_scale_up_kwargs(self): -------- LocalCluster.scale_up """ - target = math.ceil(self.scheduler.total_occupancy / - self.target_duration) - instances = max(1, - len(self.scheduler.workers) * self.scale_factor, - target, - self.minimum) + target = math.ceil(self.scheduler.total_occupancy / self.target_duration) + instances = max( + 1, len(self.scheduler.workers) * self.scale_factor, target, self.minimum + ) if self.maximum: instances = min(self.maximum, instances) instances = int(instances) logger.info("Scaling up to %d workers", instances) - return {'n': instances} + return {"n": instances} def recommendations(self, comm=None): should_scale_up = self.should_scale_up() - workers = set(self.workers_to_close(key=self.worker_key, - minimum=self.minimum)) + workers = set(self.workers_to_close(key=self.worker_key, minimum=self.minimum)) if should_scale_up and workers: logger.info("Attempting to scale up and scale down simultaneously.") self.close_counts.clear() - return {'status': 'error', - 'msg': 'Trying to scale up and down simultaneously'} + return { + "status": "error", + "msg": "Trying to scale up and down simultaneously", + } elif should_scale_up: self.close_counts.clear() - return toolz.merge({'status': 'up'}, self.get_scale_up_kwargs()) + return toolz.merge({"status": "up"}, self.get_scale_up_kwargs()) elif workers: d = {} @@ -307,7 +321,7 @@ def recommendations(self, comm=None): self.close_counts = d if to_close: - return {'status': 'down', 'workers': to_close} + return {"status": "down", "workers": to_close} else: self.close_counts.clear() return None @@ -322,16 +336,16 @@ def _adapt(self): recommendations = self.recommendations() if not recommendations: return - status = recommendations.pop('status') - if status == 'up': + status = recommendations.pop("status") + if status == "up": f = self.cluster.scale_up(**recommendations) - self.log.append((time(), 'up', recommendations)) + self.log.append((time(), "up", recommendations)) if gen.is_future(f): yield f - elif status == 'down': - self.log.append((time(), 'down', recommendations['workers'])) - workers = yield self._retire_workers(workers=recommendations['workers']) + elif status == "down": + 
self.log.append((time(), "down", recommendations["workers"])) + workers = yield self._retire_workers(workers=recommendations["workers"]) finally: self._adapting = False diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 3bc2b2d9124..f170d4ea5ad 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -51,6 +51,7 @@ def scale_down(self, workers: List[str]): -------- LocalCluster: a simple implementation with local workers """ + def adapt(self, **kwargs): """ Turn on adaptivity @@ -62,7 +63,7 @@ def adapt(self, **kwargs): """ with ignoring(AttributeError): self._adaptive.stop() - if not hasattr(self, '_adaptive_options'): + if not hasattr(self, "_adaptive_options"): self._adaptive_options = {} self._adaptive_options.update(kwargs) self._adaptive = Adaptive(self.scheduler, self, **self._adaptive_options) @@ -74,9 +75,9 @@ def scheduler_address(self): @property def dashboard_link(self): - template = dask.config.get('distributed.dashboard.link') - host = self.scheduler.address.split('://')[1].split(':')[0] - port = self.scheduler.services['bokeh'].port + template = dask.config.get("distributed.dashboard.link") + host = self.scheduler.address.split("://")[1].split(":")[0] + port = self.scheduler.services["bokeh"].port return template.format(host=host, port=port, **os.environ) def scale(self, n): @@ -101,9 +102,12 @@ def scale(self, n): self.scheduler.loop.add_callback(self.scale_up, n) else: to_close = self.scheduler.workers_to_close( - n=len(self.scheduler.workers) - n) + n=len(self.scheduler.workers) - n + ) logger.debug("Closing workers: %s", to_close) - self.scheduler.loop.add_callback(self.scheduler.retire_workers, workers=to_close) + self.scheduler.loop.add_callback( + self.scheduler.retire_workers, workers=to_close + ) self.scheduler.loop.add_callback(self.scale_down, to_close) def _widget_status(self): @@ -132,7 +136,11 @@ def _widget_status(self): Memory %s -""" % (workers, cores, memory) +""" % ( + workers, + cores, + memory, + ) return text def _widget(self): @@ -144,38 +152,39 @@ def _widget(self): from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion - layout = Layout(width='150px') + layout = Layout(width="150px") - if 'bokeh' in self.scheduler.services: + if "bokeh" in self.scheduler.services: link = self.dashboard_link - link = '
<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>\n' % (link, link)
+            link = '<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>\n' % (
+                link,
+                link,
+            )
         else:
-            link = ''
+            link = ""

-        title = '<h2>%s</h2>' % type(self).__name__
+        title = "<h2>%s</h2>
        " % type(self).__name__ title = HTML(title) dashboard = HTML(link) - status = HTML(self._widget_status(), layout=Layout(min_width='150px')) + status = HTML(self._widget_status(), layout=Layout(min_width="150px")) - request = IntText(0, description='Workers', layout=layout) - scale = Button(description='Scale', layout=layout) + request = IntText(0, description="Workers", layout=layout) + scale = Button(description="Scale", layout=layout) - minimum = IntText(0, description='Minimum', layout=layout) - maximum = IntText(0, description='Maximum', layout=layout) - adapt = Button(description='Adapt', layout=layout) + minimum = IntText(0, description="Minimum", layout=layout) + maximum = IntText(0, description="Maximum", layout=layout) + adapt = Button(description="Adapt", layout=layout) - accordion = Accordion([HBox([request, scale]), - HBox([minimum, maximum, adapt])], - layout=Layout(min_width='500px')) + accordion = Accordion( + [HBox([request, scale]), HBox([minimum, maximum, adapt])], + layout=Layout(min_width="500px"), + ) accordion.selected_index = None - accordion.set_title(0, 'Manual Scaling') - accordion.set_title(1, 'Adaptive Scaling') + accordion.set_title(0, "Manual Scaling") + accordion.set_title(1, "Adaptive Scaling") - box = VBox([title, - HBox([status, - accordion]), - dashboard]) + box = VBox([title, HBox([status, accordion]), dashboard]) self._cached_widget = box @@ -199,7 +208,7 @@ def update(): status.value = self._widget_status() pc = PeriodicCallback(update, 500, io_loop=self.scheduler.loop) - self.scheduler.periodic_callbacks['cluster-repr'] = pc + self.scheduler.periodic_callbacks["cluster-repr"] = pc pc.start() return box diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index ececc9fe12c..22a796bbb78 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -14,8 +14,16 @@ from .cluster import Cluster from ..compatibility import get_thread_identity from ..core import CommClosedError -from ..utils import (sync, ignoring, All, silence_logging, LoopRunner, - log_errors, thread_state, parse_timedelta) +from ..utils import ( + sync, + ignoring, + All, + silence_logging, + LoopRunner, + log_errors, + thread_state, + parse_timedelta, +) from ..nanny import Nanny from ..scheduler import Scheduler from ..worker import Worker, parse_memory_limit, _ncores @@ -87,18 +95,35 @@ class LocalCluster(Cluster): >>> LocalCluster(service_kwargs={'bokeh': {'prefix': '/foo'}}) # doctest: +SKIP """ - def __init__(self, n_workers=None, threads_per_worker=None, processes=True, - loop=None, start=None, ip=None, scheduler_port=0, - silence_logs=logging.WARN, dashboard_address=':8787', - diagnostics_port=None, - services=None, worker_services=None, service_kwargs=None, - asynchronous=False, security=None, protocol=None, - blocked_handlers=None, **worker_kwargs): + + def __init__( + self, + n_workers=None, + threads_per_worker=None, + processes=True, + loop=None, + start=None, + ip=None, + scheduler_port=0, + silence_logs=logging.WARN, + dashboard_address=":8787", + diagnostics_port=None, + services=None, + worker_services=None, + service_kwargs=None, + asynchronous=False, + security=None, + protocol=None, + blocked_handlers=None, + **worker_kwargs + ): if start is not None: - msg = ("The start= parameter is deprecated. " - "LocalCluster always starts. " - "For asynchronous operation use the following: \n\n" - " cluster = yield LocalCluster(asynchronous=True)") + msg = ( + "The start= parameter is deprecated. " + "LocalCluster always starts. 
" + "For asynchronous operation use the following: \n\n" + " cluster = yield LocalCluster(asynchronous=True)" + ) raise ValueError(msg) if diagnostics_port is not None: @@ -112,16 +137,16 @@ def __init__(self, n_workers=None, threads_per_worker=None, processes=True, self.processes = processes if protocol is None: - if ip and '://' in ip: - protocol = ip.split('://')[0] + if ip and "://" in ip: + protocol = ip.split("://")[0] elif security: - protocol = 'tls://' + protocol = "tls://" elif not self.processes and not scheduler_port: - protocol = 'inproc://' + protocol = "inproc://" else: - protocol = 'tcp://' - if not protocol.endswith('://'): - protocol = protocol + '://' + protocol = "tcp://" + if not protocol.endswith("://"): + protocol = protocol + "://" self.protocol = protocol self.silence_logs = silence_logs @@ -142,13 +167,12 @@ def __init__(self, n_workers=None, threads_per_worker=None, processes=True, if n_workers and threads_per_worker is None: # Overcommit threads per worker, rather than undercommit threads_per_worker = max(1, int(math.ceil(_ncores / n_workers))) - if n_workers and 'memory_limit' not in worker_kwargs: - worker_kwargs['memory_limit'] = parse_memory_limit('auto', 1, n_workers) + if n_workers and "memory_limit" not in worker_kwargs: + worker_kwargs["memory_limit"] = parse_memory_limit("auto", 1, n_workers) - worker_kwargs.update({ - 'ncores': threads_per_worker, - 'services': worker_services, - }) + worker_kwargs.update( + {"ncores": threads_per_worker, "services": worker_services} + ) self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop @@ -160,29 +184,35 @@ def __init__(self, n_workers=None, threads_per_worker=None, processes=True, except ImportError: logger.debug("To start diagnostics web server please install Bokeh") else: - services[('bokeh', dashboard_address)] = (BokehScheduler, (service_kwargs or {}).get('bokeh', {})) - worker_services[('bokeh', 0)] = BokehWorker + services[("bokeh", dashboard_address)] = ( + BokehScheduler, + (service_kwargs or {}).get("bokeh", {}), + ) + worker_services[("bokeh", 0)] = BokehWorker - self.scheduler = Scheduler(loop=self.loop, - services=services, - security=security, - blocked_handlers=blocked_handlers) + self.scheduler = Scheduler( + loop=self.loop, + services=services, + security=security, + blocked_handlers=blocked_handlers, + ) self.scheduler_port = scheduler_port self.workers = [] self.worker_kwargs = worker_kwargs if security: - self.worker_kwargs['security'] = security + self.worker_kwargs["security"] = security self.start(ip=ip, n_workers=n_workers) clusters_to_close.add(self) def __repr__(self): - return ('LocalCluster(%r, workers=%d, ncores=%d)' % - (self.scheduler_address, len(self.workers), - sum(w.ncores for w in self.workers)) - ) + return "LocalCluster(%r, workers=%d, ncores=%d)" % ( + self.scheduler_address, + len(self.workers), + sum(w.ncores for w in self.workers), + ) def __await__(self): return self._started.__await__() @@ -190,18 +220,18 @@ def __await__(self): @property def asynchronous(self): return ( - self._asynchronous or - getattr(thread_state, 'asynchronous', False) or - hasattr(self.loop, '_thread_identity') and self.loop._thread_identity == get_thread_identity() + self._asynchronous + or getattr(thread_state, "asynchronous", False) + or hasattr(self.loop, "_thread_identity") + and self.loop._thread_identity == get_thread_identity() ) def sync(self, func, *args, **kwargs): - if kwargs.pop('asynchronous', None) or self.asynchronous: - 
callback_timeout = kwargs.pop('callback_timeout', None) + if kwargs.pop("asynchronous", None) or self.asynchronous: + callback_timeout = kwargs.pop("callback_timeout", None) future = func(*args, **kwargs) if callback_timeout is not None: - future = gen.with_timeout(timedelta(seconds=callback_timeout), - future) + future = gen.with_timeout(timedelta(seconds=callback_timeout), future) return future else: return sync(self.loop, func, *args, **kwargs) @@ -218,52 +248,56 @@ def _start(self, ip=None, n_workers=0): """ Start all cluster services. """ - if self.status == 'running': + if self.status == "running": return - if self.protocol == 'inproc://': + if self.protocol == "inproc://": address = self.protocol else: if ip is None: - ip = '127.0.0.1' + ip = "127.0.0.1" - if '://' in ip: + if "://" in ip: address = ip else: address = self.protocol + ip if self.scheduler_port: - address += ':' + str(self.scheduler_port) + address += ":" + str(self.scheduler_port) self.scheduler.start(address) yield [self._start_worker(**self.worker_kwargs) for i in range(n_workers)] - self.status = 'running' + self.status = "running" raise gen.Return(self) @gen.coroutine def _start_worker(self, death_timeout=60, **kwargs): - if self.status and self.status.startswith('clos'): + if self.status and self.status.startswith("clos"): warnings.warn("Tried to start a worker while status=='%s'" % self.status) return if self.processes: W = Nanny - kwargs['quiet'] = True + kwargs["quiet"] = True else: W = Worker - w = yield W(self.scheduler.address, loop=self.loop, - death_timeout=death_timeout, - silence_logs=self.silence_logs, **kwargs) + w = yield W( + self.scheduler.address, + loop=self.loop, + death_timeout=death_timeout, + silence_logs=self.silence_logs, + **kwargs + ) self.workers.append(w) - while w.status != 'closed' and w.worker_address not in self.scheduler.workers: + while w.status != "closed" and w.worker_address not in self.scheduler.workers: yield gen.sleep(0.01) - if w.status == 'closed' and self.scheduler.status == 'running': + if w.status == "closed" and self.scheduler.status == "running": self.workers.remove(w) raise gen.TimeoutError("Worker failed to start") @@ -308,11 +342,11 @@ def stop_worker(self, w): self.sync(self._stop_worker, w) @gen.coroutine - def _close(self, timeout='2s'): + def _close(self, timeout="2s"): # Can be 'closing' as we're called by close() below - if self.status == 'closed': + if self.status == "closed": return - self.status = 'closing' + self.status = "closing" self.scheduler.clear_task_state() @@ -328,11 +362,11 @@ def _close(self, timeout='2s'): yield self.scheduler.close(fast=True) del self.workers[:] finally: - self.status = 'closed' + self.status = "closed" def close(self, timeout=20): """ Close the cluster """ - if self.status == 'closed': + if self.status == "closed": return try: @@ -340,9 +374,11 @@ def close(self, timeout=20): except RuntimeError: # IOLoop is closed result = None - if hasattr(self, '_old_logging_level'): + if hasattr(self, "_old_logging_level"): if self.asynchronous: - result.add_done_callback(lambda _: silence_logging(self._old_logging_level)) + result.add_done_callback( + lambda _: silence_logging(self._old_logging_level) + ) else: silence_logging(self._old_logging_level) @@ -362,11 +398,13 @@ def scale_up(self, n, **kwargs): """ with log_errors(): kwargs2 = toolz.merge(self.worker_kwargs, kwargs) - yield [self._start_worker(**kwargs2) - for i in range(n - len(self.scheduler.workers))] + yield [ + self._start_worker(**kwargs2) + for i in range(n - 
len(self.scheduler.workers)) + ] # clean up any closed worker - self.workers = [w for w in self.workers if w.status != 'closed'] + self.workers = [w for w in self.workers if w.status != "closed"] @gen.coroutine def scale_down(self, workers): @@ -380,7 +418,7 @@ def scale_down(self, workers): """ with log_errors(): # clean up any closed worker - self.workers = [w for w in self.workers if w.status != 'closed'] + self.workers = [w for w in self.workers if w.status != "closed"] workers = set(workers) # we might be given addresses @@ -413,7 +451,7 @@ def scheduler_address(self): try: return self.scheduler.address except ValueError: - return '' + return "" def nprocesses_nthreads(n): diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 20888471ff8..ba8ed01d1c7 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -25,20 +25,21 @@ # These are handy for creating colorful terminal output to enhance readability # of the output generated by dask-ssh. class bcolors: - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" def async_ssh(cmd_dict): import paramiko from paramiko.buffered_pipe import PipeTimeout - from paramiko.ssh_exception import (SSHException, PasswordRequiredException) + from paramiko.ssh_exception import SSHException, PasswordRequiredException + ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -46,28 +47,40 @@ def async_ssh(cmd_dict): while True: # Be robust to transient SSH failures. try: # Set paramiko logging to WARN or higher to squelch INFO messages. - logging.getLogger('paramiko').setLevel(logging.WARN) - - ssh.connect(hostname=cmd_dict['address'], - username=cmd_dict['ssh_username'], - port=cmd_dict['ssh_port'], - key_filename=cmd_dict['ssh_private_key'], - compress=True, - timeout=20, - banner_timeout=20) # Helps prevent timeouts when many concurrent ssh connections are opened. + logging.getLogger("paramiko").setLevel(logging.WARN) + + ssh.connect( + hostname=cmd_dict["address"], + username=cmd_dict["ssh_username"], + port=cmd_dict["ssh_port"], + key_filename=cmd_dict["ssh_private_key"], + compress=True, + timeout=20, + banner_timeout=20, + ) # Helps prevent timeouts when many concurrent ssh connections are opened. # Connection successful, break out of while loop break - except (SSHException, - PasswordRequiredException) as e: - - print('[ dask-ssh ] : ' + bcolors.FAIL + - 'SSH connection error when connecting to {addr}:{port}' - 'to run \'{cmd}\''.format(addr=cmd_dict['address'], - port=cmd_dict['ssh_port'], - cmd=cmd_dict['cmd']) + bcolors.ENDC) - - print(bcolors.FAIL + ' SSH reported this exception: ' + str(e) + bcolors.ENDC) + except (SSHException, PasswordRequiredException) as e: + + print( + "[ dask-ssh ] : " + + bcolors.FAIL + + "SSH connection error when connecting to {addr}:{port}" + "to run '{cmd}'".format( + addr=cmd_dict["address"], + port=cmd_dict["ssh_port"], + cmd=cmd_dict["cmd"], + ) + + bcolors.ENDC + ) + + print( + bcolors.FAIL + + " SSH reported this exception: " + + str(e) + + bcolors.ENDC + ) # Print an exception traceback traceback.print_exc() @@ -77,18 +90,23 @@ def async_ssh(cmd_dict): # attempts to retry. 
retries += 1 if retries >= 3: - print('[ dask-ssh ] : ' - + bcolors.FAIL - + 'SSH connection failed after 3 retries. Exiting.' - + bcolors.ENDC) + print( + "[ dask-ssh ] : " + + bcolors.FAIL + + "SSH connection failed after 3 retries. Exiting." + + bcolors.ENDC + ) # Connection failed after multiple attempts. Terminate this thread. os._exit(1) # Wait a moment before retrying - print(' ' + bcolors.FAIL + - 'Retrying... (attempt {n}/{total})'.format(n=retries, total=3) + - bcolors.ENDC) + print( + " " + + bcolors.FAIL + + "Retrying... (attempt {n}/{total})".format(n=retries, total=3) + + bcolors.ENDC + ) time.sleep(1) @@ -99,9 +117,10 @@ def async_ssh(cmd_dict): # before the command is run. This should help to ensure that important # aspects of the environment like PATH and PYTHONPATH are configured. - print('[ {label} ] : {cmd}'.format(label=cmd_dict['label'], - cmd=cmd_dict['cmd'])) - stdin, stdout, stderr = ssh.exec_command('$SHELL -i -c \'' + cmd_dict['cmd'] + '\'', get_pty=True) + print("[ {label} ] : {cmd}".format(label=cmd_dict["label"], cmd=cmd_dict["cmd"])) + stdin, stdout, stderr = ssh.exec_command( + "$SHELL -i -c '" + cmd_dict["cmd"] + "'", get_pty=True + ) # Set up channel timeout (which we rely on below to make readline() non-blocking) channel = stdout.channel @@ -113,11 +132,14 @@ def read_from_stdout(): """ try: line = stdout.readline() - while len(line) > 0: # Loops until a timeout exception occurs + while len(line) > 0: # Loops until a timeout exception occurs line = line.rstrip() - logger.debug('stdout from ssh channel: %s', line) - cmd_dict['output_queue'].put('[ {label} ] : {output}'.format(label=cmd_dict['label'], - output=line)) + logger.debug("stdout from ssh channel: %s", line) + cmd_dict["output_queue"].put( + "[ {label} ] : {output}".format( + label=cmd_dict["label"], output=line + ) + ) line = stdout.readline() except (PipeTimeout, socket.timeout): pass @@ -130,9 +152,13 @@ def read_from_stderr(): line = stderr.readline() while len(line) > 0: line = line.rstrip() - logger.debug('stderr from ssh channel: %s', line) - cmd_dict['output_queue'].put('[ {label} ] : '.format(label=cmd_dict['label']) + - bcolors.FAIL + '{output}'.format(output=line) + bcolors.ENDC) + logger.debug("stderr from ssh channel: %s", line) + cmd_dict["output_queue"].put( + "[ {label} ] : ".format(label=cmd_dict["label"]) + + bcolors.FAIL + + "{output}".format(output=line) + + bcolors.ENDC + ) line = stderr.readline() except (PipeTimeout, socket.timeout): pass @@ -149,15 +175,18 @@ def communicate(): # terminate. if channel.exit_status_ready(): exit_status = channel.recv_exit_status() - cmd_dict['output_queue'].put('[ {label} ] : '.format(label=cmd_dict['label']) + - bcolors.FAIL + - "remote process exited with exit status " + - str(exit_status) + bcolors.ENDC) + cmd_dict["output_queue"].put( + "[ {label} ] : ".format(label=cmd_dict["label"]) + + bcolors.FAIL + + "remote process exited with exit status " + + str(exit_status) + + bcolors.ENDC + ) return True # Wait for a message on the input_queue. Any message received signals this # thread to shut itself down. - while cmd_dict['input_queue'].empty(): + while cmd_dict["input_queue"].empty(): # Kill some time so that this thread does not hog the CPU. 
time.sleep(1.0) if communicate(): @@ -166,7 +195,7 @@ def communicate(): # Ctrl-C the executing command and wait a bit for command to end cleanly start = time.time() while time.time() < start + 5.0: - channel.send(b'\x03') # Ctrl-C + channel.send(b"\x03") # Ctrl-C if communicate(): break time.sleep(1.0) @@ -176,63 +205,87 @@ def communicate(): ssh.close() -def start_scheduler(logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None): - cmd = '{python} -m distributed.cli.dask_scheduler --port {port}'.format( - python=remote_python or sys.executable, port=port, logdir=logdir) +def start_scheduler( + logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None +): + cmd = "{python} -m distributed.cli.dask_scheduler --port {port}".format( + python=remote_python or sys.executable, port=port, logdir=logdir + ) # Optionally re-direct stdout and stderr to a logfile if logdir is not None: - cmd = 'mkdir -p {logdir} && '.format(logdir=logdir) + cmd - cmd += '&> {logdir}/dask_scheduler_{addr}:{port}.log'.format(addr=addr, - port=port, logdir=logdir) + cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd + cmd += "&> {logdir}/dask_scheduler_{addr}:{port}.log".format( + addr=addr, port=port, logdir=logdir + ) # Format output labels we can prepend to each line of output, and create # a 'status' key to keep track of jobs that terminate prematurely. - label = (bcolors.BOLD + - 'scheduler {addr}:{port}'.format(addr=addr, port=port) + - bcolors.ENDC) + label = ( + bcolors.BOLD + + "scheduler {addr}:{port}".format(addr=addr, port=port) + + bcolors.ENDC + ) # Create a command dictionary, which contains everything we need to run and # interact with this command. input_queue = Queue() output_queue = Queue() - cmd_dict = {'cmd': cmd, 'label': label, 'address': addr, 'port': port, - 'input_queue': input_queue, 'output_queue': output_queue, - 'ssh_username': ssh_username, 'ssh_port': ssh_port, - 'ssh_private_key': ssh_private_key} + cmd_dict = { + "cmd": cmd, + "label": label, + "address": addr, + "port": port, + "input_queue": input_queue, + "output_queue": output_queue, + "ssh_username": ssh_username, + "ssh_port": ssh_port, + "ssh_private_key": ssh_private_key, + } # Start the thread thread = Thread(target=async_ssh, args=[cmd_dict]) thread.daemon = True thread.start() - return merge(cmd_dict, {'thread': thread}) - - -def start_worker(logdir, scheduler_addr, scheduler_port, worker_addr, nthreads, nprocs, - ssh_username, ssh_port, ssh_private_key, nohost, - memory_limit, - worker_port, - nanny_port, - remote_python=None, - remote_dask_worker='distributed.cli.dask_worker'): - - cmd = ('{python} -m {remote_dask_worker} ' - '{scheduler_addr}:{scheduler_port} ' - '--nthreads {nthreads}' - + (' --nprocs {nprocs}' if nprocs != 1 else '')) + return merge(cmd_dict, {"thread": thread}) + + +def start_worker( + logdir, + scheduler_addr, + scheduler_port, + worker_addr, + nthreads, + nprocs, + ssh_username, + ssh_port, + ssh_private_key, + nohost, + memory_limit, + worker_port, + nanny_port, + remote_python=None, + remote_dask_worker="distributed.cli.dask_worker", +): + + cmd = ( + "{python} -m {remote_dask_worker} " + "{scheduler_addr}:{scheduler_port} " + "--nthreads {nthreads}" + (" --nprocs {nprocs}" if nprocs != 1 else "") + ) if not nohost: - cmd += ' --host {worker_addr}' + cmd += " --host {worker_addr}" if memory_limit: - cmd += ' --memory-limit {memory_limit}' + cmd += " --memory-limit {memory_limit}" if worker_port: - cmd += ' --worker-port {worker_port}' + cmd += " 
--worker-port {worker_port}" if nanny_port: - cmd += ' --nanny-port {nanny_port}' + cmd += " --nanny-port {nanny_port}" cmd = cmd.format( python=remote_python or sys.executable, @@ -244,40 +297,60 @@ def start_worker(logdir, scheduler_addr, scheduler_port, worker_addr, nthreads, nprocs=nprocs, memory_limit=memory_limit, worker_port=worker_port, - nanny_port=nanny_port) + nanny_port=nanny_port, + ) # Optionally redirect stdout and stderr to a logfile if logdir is not None: - cmd = 'mkdir -p {logdir} && '.format(logdir=logdir) + cmd - cmd += '&> {logdir}/dask_scheduler_{addr}.log'.format( - addr=worker_addr, logdir=logdir) + cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd + cmd += "&> {logdir}/dask_scheduler_{addr}.log".format( + addr=worker_addr, logdir=logdir + ) - label = 'worker {addr}'.format(addr=worker_addr) + label = "worker {addr}".format(addr=worker_addr) # Create a command dictionary, which contains everything we need to run and # interact with this command. input_queue = Queue() output_queue = Queue() - cmd_dict = {'cmd': cmd, 'label': label, 'address': worker_addr, - 'input_queue': input_queue, 'output_queue': output_queue, - 'ssh_username': ssh_username, 'ssh_port': ssh_port, - 'ssh_private_key': ssh_private_key} + cmd_dict = { + "cmd": cmd, + "label": label, + "address": worker_addr, + "input_queue": input_queue, + "output_queue": output_queue, + "ssh_username": ssh_username, + "ssh_port": ssh_port, + "ssh_private_key": ssh_private_key, + } # Start the thread thread = Thread(target=async_ssh, args=[cmd_dict]) thread.daemon = True thread.start() - return merge(cmd_dict, {'thread': thread}) + return merge(cmd_dict, {"thread": thread}) class SSHCluster(object): - - def __init__(self, scheduler_addr, scheduler_port, worker_addrs, nthreads=0, nprocs=1, - ssh_username=None, ssh_port=22, ssh_private_key=None, - nohost=False, logdir=None, remote_python=None, - memory_limit=None, worker_port=None, nanny_port=None, - remote_dask_worker='distributed.cli.dask_worker'): + def __init__( + self, + scheduler_addr, + scheduler_port, + worker_addrs, + nthreads=0, + nprocs=1, + ssh_username=None, + ssh_port=22, + ssh_private_key=None, + nohost=False, + logdir=None, + remote_python=None, + memory_limit=None, + worker_port=None, + nanny_port=None, + remote_dask_worker="distributed.cli.dask_worker", + ): self.scheduler_addr = scheduler_addr self.scheduler_port = scheduler_port @@ -299,20 +372,34 @@ def __init__(self, scheduler_addr, scheduler_port, worker_addrs, nthreads=0, npr # Generate a universal timestamp to use for log files import datetime + if logdir is not None: - logdir = os.path.join(logdir, "dask-ssh_" + datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")) - print(bcolors.WARNING + 'Output will be redirected to logfiles ' - 'stored locally on individual worker nodes under "{logdir}".'.format(logdir=logdir) - + bcolors.ENDC) + logdir = os.path.join( + logdir, + "dask-ssh_" + datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), + ) + print( + bcolors.WARNING + "Output will be redirected to logfiles " + 'stored locally on individual worker nodes under "{logdir}".'.format( + logdir=logdir + ) + + bcolors.ENDC + ) self.logdir = logdir # Keep track of all running threads self.threads = [] # Start the scheduler node - self.scheduler = start_scheduler(logdir, scheduler_addr, - scheduler_port, ssh_username, ssh_port, - ssh_private_key, remote_python) + self.scheduler = start_scheduler( + logdir, + scheduler_addr, + scheduler_port, + ssh_username, + ssh_port, + ssh_private_key, + 
remote_python, + ) # Start worker nodes self.workers = [] @@ -325,7 +412,7 @@ def _start(self): @property def scheduler_address(self): - return '%s:%d' % (self.scheduler_addr, self.scheduler_port) + return "%s:%d" % (self.scheduler_addr, self.scheduler_port) def monitor_remote_processes(self): @@ -335,8 +422,8 @@ def monitor_remote_processes(self): try: while True: for process in all_processes: - while not process['output_queue'].empty(): - print(process['output_queue'].get()) + while not process["output_queue"].empty(): + print(process["output_queue"].get()) # Kill some time and free up CPU before starting the next sweep # through the processes. @@ -345,26 +432,35 @@ def monitor_remote_processes(self): # end while true except KeyboardInterrupt: - pass # Return execution to the calling process + pass # Return execution to the calling process def add_worker(self, address): - self.workers.append(start_worker(self.logdir, self.scheduler_addr, - self.scheduler_port, address, - self.nthreads, self.nprocs, - self.ssh_username, self.ssh_port, - self.ssh_private_key, self.nohost, - self.memory_limit, - self.worker_port, - self.nanny_port, - self.remote_python, - self.remote_dask_worker)) + self.workers.append( + start_worker( + self.logdir, + self.scheduler_addr, + self.scheduler_port, + address, + self.nthreads, + self.nprocs, + self.ssh_username, + self.ssh_port, + self.ssh_private_key, + self.nohost, + self.memory_limit, + self.worker_port, + self.nanny_port, + self.remote_python, + self.remote_dask_worker, + ) + ) def shutdown(self): all_processes = [self.scheduler] + self.workers for process in all_processes: - process['input_queue'].put('shutdown') - process['thread'].join() + process["input_queue"].put("shutdown") + process["thread"].join() def __enter__(self): return self diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index dca8653a752..1d8a48bf7fc 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -14,21 +14,21 @@ def test_get_scale_up_kwargs(loop): - with LocalCluster(0, scheduler_port=0, silence_logs=False, - dashboard_address=None, loop=loop) as cluster: + with LocalCluster( + 0, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop + ) as cluster: - alc = Adaptive(cluster.scheduler, cluster, interval=100, - scale_factor=3) - assert alc.get_scale_up_kwargs() == {'n': 1} + alc = Adaptive(cluster.scheduler, cluster, interval=100, scale_factor=3) + assert alc.get_scale_up_kwargs() == {"n": 1} with Client(cluster, loop=loop) as c: future = c.submit(lambda x: x + 1, 1) assert future.result() == 2 assert c.ncores() - assert alc.get_scale_up_kwargs() == {'n': 3} + assert alc.get_scale_up_kwargs() == {"n": 3} -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_simultaneous_scale_up_and_down(c, s, *workers): class TestAdaptive(Adaptive): def get_scale_up_kwargs(self): @@ -46,11 +46,11 @@ def scale_down(self, workers): cluster = TestCluster() - s.task_duration['a'] = 4 - s.task_duration['b'] = 4 - s.task_duration['c'] = 1 + s.task_duration["a"] = 4 + s.task_duration["b"] = 4 + s.task_duration["c"] = 1 - future = c.map(slowinc, [1, 1, 1], key=['a-4', 'b-4', 'c-1']) + future = c.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) while len(s.rprocessing) < 3: yield gen.sleep(0.001) @@ -61,8 +61,9 @@ def scale_down(self, workers): def test_adaptive_local_cluster(loop): - with LocalCluster(0, 
scheduler_port=0, silence_logs=False, - dashboard_address=None, loop=loop) as cluster: + with LocalCluster( + 0, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop + ) as cluster: alc = Adaptive(cluster.scheduler, cluster, interval=100) with Client(cluster, loop=loop) as c: assert not c.ncores() @@ -86,9 +87,14 @@ def test_adaptive_local_cluster(loop): @nodebug @gen_test(timeout=30) def test_adaptive_local_cluster_multi_workers(): - cluster = yield LocalCluster(0, scheduler_port=0, silence_logs=False, - processes=False, dashboard_address=None, - asynchronous=True) + cluster = yield LocalCluster( + 0, + scheduler_port=0, + silence_logs=False, + processes=False, + dashboard_address=None, + asynchronous=True, + ) try: cluster.scheduler.allowed_failures = 1000 alc = cluster.adapt(interval=100) @@ -124,7 +130,7 @@ def test_adaptive_local_cluster_multi_workers(): yield cluster.close() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10, active_rpc_timeout=10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10, active_rpc_timeout=10) def test_adaptive_scale_down_override(c, s, *workers): class TestAdaptive(Adaptive): def __init__(self, *args, **kwargs): @@ -135,7 +141,7 @@ def workers_to_close(self, **kwargs): num_workers = len(self.scheduler.workers) to_close = self.scheduler.workers_to_close(**kwargs) if num_workers - len(to_close) < self.min_size: - to_close = to_close[:num_workers - self.min_size] + to_close = to_close[: num_workers - self.min_size] return to_close @@ -151,7 +157,7 @@ def scale_down(self, workers): # Assert that adaptive cycle does not reduce cluster below minimum size # as determined via override. cluster = TestCluster() - ta = TestAdaptive(s, cluster, min_size=2, interval=.1, scale_factor=2) + ta = TestAdaptive(s, cluster, min_size=2, interval=0.1, scale_factor=2) yield gen.sleep(0.3) assert len(s.workers) == 2 @@ -160,13 +166,25 @@ def scale_down(self, workers): @gen_test(timeout=30) def test_min_max(): loop = IOLoop.current() - cluster = yield LocalCluster(0, scheduler_port=0, silence_logs=False, - processes=False, dashboard_address=None, - loop=loop, asynchronous=True) + cluster = yield LocalCluster( + 0, + scheduler_port=0, + silence_logs=False, + processes=False, + dashboard_address=None, + loop=loop, + asynchronous=True, + ) yield cluster._start() try: - adapt = Adaptive(cluster.scheduler, cluster, minimum=1, maximum=2, - interval='20 ms', wait_count=10) + adapt = Adaptive( + cluster.scheduler, + cluster, + minimum=1, + maximum=2, + interval="20 ms", + wait_count=10, + ) c = yield Client(cluster, asynchronous=True, loop=loop) start = time() @@ -176,7 +194,7 @@ def test_min_max(): yield gen.sleep(0.2) assert len(cluster.scheduler.workers) == 1 - assert frequencies(pluck(1, adapt.log)) == {'up': 1} + assert frequencies(pluck(1, adapt.log)) == {"up": 1} futures = c.map(slowinc, range(100), delay=0.1) @@ -189,7 +207,7 @@ def test_min_max(): yield gen.sleep(0.5) assert len(cluster.scheduler.workers) == 2 assert len(cluster.workers) == 2 - assert frequencies(pluck(1, adapt.log)) == {'up': 2} + assert frequencies(pluck(1, adapt.log)) == {"up": 2} del futures @@ -197,7 +215,7 @@ def test_min_max(): while len(cluster.scheduler.workers) != 1: yield gen.sleep(0.01) assert time() < start + 2 - assert frequencies(pluck(1, adapt.log)) == {'up': 2, 'down': 1} + assert frequencies(pluck(1, adapt.log)) == {"up": 2, "down": 1} finally: yield c.close() yield cluster.close() @@ -210,18 +228,23 @@ def test_avoid_churn(): Instead we want to wait a few 
beats before removing a worker in case the user is taking a brief pause between work """ - cluster = yield LocalCluster(0, asynchronous=True, processes=False, - scheduler_port=0, silence_logs=False, - dashboard_address=None) + cluster = yield LocalCluster( + 0, + asynchronous=True, + processes=False, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) client = yield Client(cluster, asynchronous=True) try: - adapt = Adaptive(cluster.scheduler, cluster, interval='20 ms', wait_count=5) + adapt = Adaptive(cluster.scheduler, cluster, interval="20 ms", wait_count=5) for i in range(10): yield client.submit(slowinc, i, delay=0.040) yield gen.sleep(0.040) - assert frequencies(pluck(1, adapt.log)) == {'up': 1} + assert frequencies(pluck(1, adapt.log)) == {"up": 1} finally: yield client.close() yield cluster.close() @@ -234,12 +257,16 @@ def test_adapt_quickly(): Instead we want to wait a few beats before removing a worker in case the user is taking a brief pause between work """ - cluster = yield LocalCluster(0, asynchronous=True, processes=False, - scheduler_port=0, silence_logs=False, - dashboard_address=None) + cluster = yield LocalCluster( + 0, + asynchronous=True, + processes=False, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) client = yield Client(cluster, asynchronous=True) - adapt = Adaptive(cluster.scheduler, cluster, interval=20, wait_count=5, - maximum=10) + adapt = Adaptive(cluster.scheduler, cluster, interval=20, wait_count=5, maximum=10) try: future = client.submit(slowinc, 1, delay=0.100) yield wait(future) @@ -247,12 +274,12 @@ def test_adapt_quickly(): # Scale up when there is plenty of available work futures = client.map(slowinc, range(1000), delay=0.100) - while frequencies(pluck(1, adapt.log)) == {'up': 1}: + while frequencies(pluck(1, adapt.log)) == {"up": 1}: yield gen.sleep(0.01) assert len(adapt.log) == 2 - assert 'up' in adapt.log[-1] + assert "up" in adapt.log[-1] d = [x for x in adapt.log[-1] if isinstance(x, dict)][0] - assert 2 < d['n'] <= adapt.maximum + assert 2 < d["n"] <= adapt.maximum while len(cluster.scheduler.workers) < adapt.maximum: yield gen.sleep(0.01) @@ -277,11 +304,16 @@ def test_adapt_quickly(): @gen_test(timeout=None) def test_adapt_down(): """ Ensure that redefining adapt with a lower maximum removes workers """ - cluster = yield LocalCluster(0, asynchronous=True, processes=False, - scheduler_port=0, silence_logs=False, - dashboard_address=None) + cluster = yield LocalCluster( + 0, + asynchronous=True, + processes=False, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) client = yield Client(cluster, asynchronous=True) - cluster.adapt(interval='20ms', maximum=5) + cluster.adapt(interval="20ms", maximum=5) try: futures = client.map(slowinc, range(1000), delay=0.1) @@ -303,15 +335,22 @@ def test_adapt_down(): @gen_test(timeout=30) def test_no_more_workers_than_tasks(): loop = IOLoop.current() - cluster = yield LocalCluster(0, scheduler_port=0, silence_logs=False, - processes=False, dashboard_address=None, - loop=loop, asynchronous=True) + cluster = yield LocalCluster( + 0, + scheduler_port=0, + silence_logs=False, + processes=False, + dashboard_address=None, + loop=loop, + asynchronous=True, + ) yield cluster._start() try: - adapt = Adaptive(cluster.scheduler, cluster, minimum=0, maximum=4, - interval='10 ms') + adapt = Adaptive( + cluster.scheduler, cluster, minimum=0, maximum=4, interval="10 ms" + ) client = yield Client(cluster, asynchronous=True, loop=loop) - 
cluster.scheduler.task_duration['slowinc'] = 1000 + cluster.scheduler.task_duration["slowinc"] = 1000 yield client.submit(slowinc, 1, delay=0.100) @@ -323,8 +362,9 @@ def test_no_more_workers_than_tasks(): def test_basic_no_loop(): try: - with LocalCluster(0, scheduler_port=0, silence_logs=False, - dashboard_address=None) as cluster: + with LocalCluster( + 0, scheduler_port=0, silence_logs=False, dashboard_address=None + ) as cluster: with Client(cluster) as client: cluster.adapt() future = client.submit(lambda x: x + 1, 1) @@ -337,13 +377,18 @@ def test_basic_no_loop(): @gen_test(timeout=None) def test_target_duration(): """ Ensure that redefining adapt with a lower maximum removes workers """ - cluster = yield LocalCluster(0, asynchronous=True, processes=False, - scheduler_port=0, silence_logs=False, - dashboard_address=None) + cluster = yield LocalCluster( + 0, + asynchronous=True, + processes=False, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) client = yield Client(cluster, asynchronous=True) - adaptive = cluster.adapt(interval='20ms', minimum=2, target_duration='5s') + adaptive = cluster.adapt(interval="20ms", minimum=2, target_duration="5s") - cluster.scheduler.task_duration['slowinc'] = 1 + cluster.scheduler.task_duration["slowinc"] = 1 try: while len(cluster.scheduler.workers) < 2: @@ -354,8 +399,8 @@ def test_target_duration(): while len(adaptive.log) < 2: yield gen.sleep(0.01) - assert adaptive.log[0][1:] == ('up', {'n': 2}) - assert adaptive.log[1][1:] == ('up', {'n': 20}) + assert adaptive.log[0][1:] == ("up", {"n": 2}) + assert adaptive.log[1][1:] == ("up", {"n": 20}) finally: yield client.close() @@ -365,22 +410,30 @@ def test_target_duration(): @gen_test(timeout=None) def test_worker_keys(): """ Ensure that redefining adapt with a lower maximum removes workers """ - cluster = yield LocalCluster(0, asynchronous=True, processes=False, - scheduler_port=0, silence_logs=False, - dashboard_address=None) + cluster = yield LocalCluster( + 0, + asynchronous=True, + processes=False, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) try: - yield [cluster.start_worker(name='a-1'), - cluster.start_worker(name='a-2'), - cluster.start_worker(name='b-1'), - cluster.start_worker(name='b-2')] + yield [ + cluster.start_worker(name="a-1"), + cluster.start_worker(name="a-2"), + cluster.start_worker(name="b-1"), + cluster.start_worker(name="b-2"), + ] while len(cluster.scheduler.workers) != 4: yield gen.sleep(0.01) def key(ws): - return ws.name.split('-')[0] - cluster._adaptive_options = {'worker_key': key} + return ws.name.split("-")[0] + + cluster._adaptive_options = {"worker_key": key} adaptive = cluster.adapt(minimum=1) yield adaptive._adapt() @@ -389,7 +442,7 @@ def key(ws): yield gen.sleep(0.01) names = {ws.name for ws in cluster.scheduler.workers.values()} - assert names == {'a-1', 'a-2'} or names == {'b-1', 'b-2'} + assert names == {"a-1", "a-2"} or names == {"b-1", "b-2"} finally: yield cluster.close() @@ -403,4 +456,4 @@ def test_without_cluster(c, s): yield gen.sleep(0.01) response = yield c.scheduler.adaptive_recommendations() - assert response['status'] == 'up' + assert response["status"] == "up" diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 95b4bb42fd4..ee2d48c2df3 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -16,12 +16,16 @@ from distributed import Client, Worker, Nanny from distributed.deploy.local import LocalCluster, 
nprocesses_nthreads from distributed.metrics import time -from distributed.utils_test import (inc, gen_test, slowinc, - assert_cannot_connect, - assert_can_connect_locally_4, - assert_can_connect_from_everywhere_4, - assert_can_connect_from_everywhere_4_6, - captured_logger) +from distributed.utils_test import ( + inc, + gen_test, + slowinc, + assert_cannot_connect, + assert_can_connect_locally_4, + assert_can_connect_from_everywhere_4, + assert_can_connect_from_everywhere_4_6, + captured_logger, +) from distributed.utils_test import loop # noqa: F401 from distributed.utils import sync from distributed.worker import TOTAL_MEMORY @@ -30,8 +34,14 @@ def test_simple(loop): - with LocalCluster(4, scheduler_port=0, processes=False, silence_logs=False, - dashboard_address=None, loop=loop) as c: + with LocalCluster( + 4, + scheduler_port=0, + processes=False, + silence_logs=False, + dashboard_address=None, + loop=loop, + ) as c: with Client(c) as e: x = e.submit(inc, 1) x.result() @@ -42,21 +52,23 @@ def test_simple(loop): def test_local_cluster_supports_blocked_handlers(loop): - with LocalCluster(blocked_handlers=['run_function'], loop=loop) as c: + with LocalCluster(blocked_handlers=["run_function"], loop=loop) as c: with Client(c) as client: with pytest.raises(ValueError) as exc: client.run_on_scheduler(lambda x: x, 42) - assert "'run_function' handler has been explicitly disallowed in Scheduler" in str(exc.value) + assert "'run_function' handler has been explicitly disallowed in Scheduler" in str( + exc.value + ) -@pytest.mark.skipif('sys.version_info[0] == 2', reason='fork issues') +@pytest.mark.skipif("sys.version_info[0] == 2", reason="fork issues") def test_close_twice(): with LocalCluster() as cluster: with Client(cluster.scheduler_address) as client: f = client.map(inc, range(100)) client.gather(f) - with captured_logger('tornado.application') as log: + with captured_logger("tornado.application") as log: cluster.close() cluster.close() sleep(0.5) @@ -64,10 +76,16 @@ def test_close_twice(): assert not log -@pytest.mark.skipif('sys.version_info[0] == 2', reason='multi-loop') +@pytest.mark.skipif("sys.version_info[0] == 2", reason="multi-loop") def test_procs(): - with LocalCluster(2, scheduler_port=0, processes=False, threads_per_worker=3, - dashboard_address=None, silence_logs=False) as c: + with LocalCluster( + 2, + scheduler_port=0, + processes=False, + threads_per_worker=3, + dashboard_address=None, + silence_logs=False, + ) as c: assert len(c.workers) == 2 assert all(isinstance(w, Worker) for w in c.workers) with Client(c.scheduler.address) as e: @@ -75,8 +93,14 @@ def test_procs(): assert all(isinstance(w, Worker) for w in c.workers) repr(c) - with LocalCluster(2, scheduler_port=0, processes=True, threads_per_worker=3, - dashboard_address=None, silence_logs=False) as c: + with LocalCluster( + 2, + scheduler_port=0, + processes=True, + threads_per_worker=3, + dashboard_address=None, + silence_logs=False, + ) as c: assert len(c.workers) == 2 assert all(isinstance(w, Nanny) for w in c.workers) with Client(c.scheduler.address) as e: @@ -92,10 +116,11 @@ def test_move_unserializable_data(): Test that unserializable data is still fine to transfer over inproc transports. 
""" - with LocalCluster(processes=False, silence_logs=False, - dashboard_address=None) as cluster: - assert cluster.scheduler_address.startswith('inproc://') - assert cluster.workers[0].address.startswith('inproc://') + with LocalCluster( + processes=False, silence_logs=False, dashboard_address=None + ) as cluster: + assert cluster.scheduler_address.startswith("inproc://") + assert cluster.workers[0].address.startswith("inproc://") with Client(cluster) as client: lock = Lock() x = client.scatter(lock) @@ -107,41 +132,49 @@ def test_transports(): """ Test the transport chosen by LocalCluster depending on arguments. """ - with LocalCluster(1, processes=False, silence_logs=False, - dashboard_address=None) as c: - assert c.scheduler_address.startswith('inproc://') - assert c.workers[0].address.startswith('inproc://') + with LocalCluster( + 1, processes=False, silence_logs=False, dashboard_address=None + ) as c: + assert c.scheduler_address.startswith("inproc://") + assert c.workers[0].address.startswith("inproc://") with Client(c.scheduler.address) as e: assert e.submit(inc, 4).result() == 5 # Have nannies => need TCP - with LocalCluster(1, processes=True, silence_logs=False, - dashboard_address=None) as c: - assert c.scheduler_address.startswith('tcp://') - assert c.workers[0].address.startswith('tcp://') + with LocalCluster( + 1, processes=True, silence_logs=False, dashboard_address=None + ) as c: + assert c.scheduler_address.startswith("tcp://") + assert c.workers[0].address.startswith("tcp://") with Client(c.scheduler.address) as e: assert e.submit(inc, 4).result() == 5 # Scheduler port specified => need TCP - with LocalCluster(1, processes=False, scheduler_port=8786, silence_logs=False, - dashboard_address=None) as c: - - assert c.scheduler_address == 'tcp://127.0.0.1:8786' - assert c.workers[0].address.startswith('tcp://') + with LocalCluster( + 1, + processes=False, + scheduler_port=8786, + silence_logs=False, + dashboard_address=None, + ) as c: + + assert c.scheduler_address == "tcp://127.0.0.1:8786" + assert c.workers[0].address.startswith("tcp://") with Client(c.scheduler.address) as e: assert e.submit(inc, 4).result() == 5 -@pytest.mark.skipif('sys.version_info[0] == 2', reason='') +@pytest.mark.skipif("sys.version_info[0] == 2", reason="") class LocalTest(ClusterTest, unittest.TestCase): Cluster = partial(LocalCluster, silence_logs=False, dashboard_address=None) - kwargs = {'dashboard_address': None} + kwargs = {"dashboard_address": None} -@pytest.mark.skipif('sys.version_info[0] == 2', reason='') +@pytest.mark.skipif("sys.version_info[0] == 2", reason="") def test_Client_with_local(loop): - with LocalCluster(1, scheduler_port=0, silence_logs=False, - dashboard_address=None, loop=loop) as c: + with LocalCluster( + 1, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop + ) as c: with Client(c) as e: assert len(e.ncores()) == len(c.workers) assert c.scheduler_address in repr(c) @@ -150,22 +183,26 @@ def test_Client_with_local(loop): def test_Client_solo(loop): with Client(loop=loop, silence_logs=False) as c: pass - assert c.cluster.status == 'closed' + assert c.cluster.status == "closed" @gen_test() def test_duplicate_clients(): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") c1 = yield Client(processes=False, silence_logs=False, dashboard_address=9876) with pytest.warns(Exception) as info: c2 = yield Client(processes=False, silence_logs=False, dashboard_address=9876) - assert 'bokeh' in c1.cluster.scheduler.services - assert 'bokeh' in 
c2.cluster.scheduler.services + assert "bokeh" in c1.cluster.scheduler.services + assert "bokeh" in c2.cluster.scheduler.services - assert any(all(word in str(msg.message).lower() - for word in ['9876', 'running', 'already in use']) - for msg in info.list) + assert any( + all( + word in str(msg.message).lower() + for word in ["9876", "running", "already in use"] + ) + for msg in info.list + ) yield c1.close() @@ -173,7 +210,7 @@ def test_Client_kwargs(loop): with Client(loop=loop, processes=False, n_workers=2, silence_logs=False) as c: assert len(c.cluster.workers) == 2 assert all(isinstance(w, Worker) for w in c.cluster.workers) - assert c.cluster.status == 'closed' + assert c.cluster.status == "closed" def test_Client_twice(loop): @@ -182,23 +219,26 @@ def test_Client_twice(loop): assert c.cluster.scheduler.port != f.cluster.scheduler.port -@pytest.mark.skipif('sys.version_info[0] == 2', reason='fork issues') +@pytest.mark.skipif("sys.version_info[0] == 2", reason="fork issues") def test_defaults(): from distributed.worker import _ncores - with LocalCluster(scheduler_port=0, silence_logs=False, - dashboard_address=None) as c: + with LocalCluster( + scheduler_port=0, silence_logs=False, dashboard_address=None + ) as c: assert sum(w.ncores for w in c.workers) == _ncores assert all(isinstance(w, Nanny) for w in c.workers) - with LocalCluster(processes=False, scheduler_port=0, silence_logs=False, - dashboard_address=None) as c: + with LocalCluster( + processes=False, scheduler_port=0, silence_logs=False, dashboard_address=None + ) as c: assert sum(w.ncores for w in c.workers) == _ncores assert all(isinstance(w, Worker) for w in c.workers) assert len(c.workers) == 1 - with LocalCluster(n_workers=2, scheduler_port=0, silence_logs=False, - dashboard_address=None) as c: + with LocalCluster( + n_workers=2, scheduler_port=0, silence_logs=False, dashboard_address=None + ) as c: if _ncores % 2 == 0: expected_total_threads = max(2, _ncores) else: @@ -206,60 +246,91 @@ def test_defaults(): expected_total_threads = max(2, _ncores + 1) assert sum(w.ncores for w in c.workers) == expected_total_threads - with LocalCluster(threads_per_worker=_ncores * 2, scheduler_port=0, - silence_logs=False, dashboard_address=None) as c: + with LocalCluster( + threads_per_worker=_ncores * 2, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) as c: assert len(c.workers) == 1 - with LocalCluster(n_workers=_ncores * 2, scheduler_port=0, - silence_logs=False, dashboard_address=None) as c: + with LocalCluster( + n_workers=_ncores * 2, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) as c: assert all(w.ncores == 1 for w in c.workers) - with LocalCluster(threads_per_worker=2, n_workers=3, scheduler_port=0, - silence_logs=False, dashboard_address=None) as c: + with LocalCluster( + threads_per_worker=2, + n_workers=3, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) as c: assert len(c.workers) == 3 assert all(w.ncores == 2 for w in c.workers) def test_worker_params(): - with LocalCluster(n_workers=2, scheduler_port=0, silence_logs=False, - dashboard_address=None, memory_limit=500) as c: + with LocalCluster( + n_workers=2, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + memory_limit=500, + ) as c: assert [w.memory_limit for w in c.workers] == [500] * 2 def test_memory_limit_none(): - with LocalCluster(n_workers=2, scheduler_port=0, silence_logs=False, - processes=False, dashboard_address=None, memory_limit=None) as c: + with LocalCluster( + 
n_workers=2, + scheduler_port=0, + silence_logs=False, + processes=False, + dashboard_address=None, + memory_limit=None, + ) as c: w = c.workers[0] assert type(w.data) is dict assert w.memory_limit is None def test_cleanup(): - c = LocalCluster(2, scheduler_port=0, silence_logs=False, - dashboard_address=None) + c = LocalCluster(2, scheduler_port=0, silence_logs=False, dashboard_address=None) port = c.scheduler.port c.close() - c2 = LocalCluster(2, scheduler_port=port, silence_logs=False, - dashboard_address=None) + c2 = LocalCluster( + 2, scheduler_port=port, silence_logs=False, dashboard_address=None + ) c.close() def test_repeated(): - with LocalCluster(scheduler_port=8448, silence_logs=False, - dashboard_address=None) as c: + with LocalCluster( + scheduler_port=8448, silence_logs=False, dashboard_address=None + ) as c: pass - with LocalCluster(scheduler_port=8448, silence_logs=False, - dashboard_address=None) as c: + with LocalCluster( + scheduler_port=8448, silence_logs=False, dashboard_address=None + ) as c: pass -@pytest.mark.parametrize('processes', [True, False]) +@pytest.mark.parametrize("processes", [True, False]) def test_bokeh(loop, processes): - pytest.importorskip('bokeh') - requests = pytest.importorskip('requests') - with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - processes=processes, dashboard_address=0) as c: - bokeh_port = c.scheduler.services['bokeh'].port - url = 'http://127.0.0.1:%d/status/' % bokeh_port + pytest.importorskip("bokeh") + requests = pytest.importorskip("requests") + with LocalCluster( + scheduler_port=0, + silence_logs=False, + loop=loop, + processes=processes, + dashboard_address=0, + ) as c: + bokeh_port = c.scheduler.services["bokeh"].port + url = "http://127.0.0.1:%d/status/" % bokeh_port start = time() while True: response = requests.get(url) @@ -268,14 +339,14 @@ def test_bokeh(loop, processes): assert time() < start + 20 sleep(0.01) # 'localhost' also works - response = requests.get('http://localhost:%d/status/' % bokeh_port) + response = requests.get("http://localhost:%d/status/" % bokeh_port) assert response.ok with pytest.raises(requests.RequestException): requests.get(url, timeout=0.2) -@pytest.mark.skipif(sys.version_info < (3, 6), reason='Unknown') +@pytest.mark.skipif(sys.version_info < (3, 6), reason="Unknown") def test_blocks_until_full(loop): with Client(loop=loop) as c: assert len(c.ncores()) > 0 @@ -284,9 +355,15 @@ def test_blocks_until_full(loop): @gen_test() def test_scale_up_and_down(): loop = IOLoop.current() - cluster = yield LocalCluster(0, scheduler_port=0, processes=False, - silence_logs=False, dashboard_address=None, - loop=loop, asynchronous=True) + cluster = yield LocalCluster( + 0, + scheduler_port=0, + processes=False, + silence_logs=False, + dashboard_address=None, + loop=loop, + asynchronous=True, + ) c = yield Client(cluster, asynchronous=True) assert not cluster.workers @@ -314,8 +391,9 @@ def test_silent_startup(): sleep(1.5) """ - out = subprocess.check_output([sys.executable, "-Wi", "-c", code], - stderr=subprocess.STDOUT) + out = subprocess.check_output( + [sys.executable, "-Wi", "-c", code], stderr=subprocess.STDOUT + ) out = out.decode() try: assert not out @@ -326,51 +404,74 @@ def test_silent_startup(): def test_only_local_access(loop): - with LocalCluster(scheduler_port=0, silence_logs=False, - dashboard_address=None, loop=loop) as c: + with LocalCluster( + scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop + ) as c: sync(loop, assert_can_connect_locally_4, 
c.scheduler.port) def test_remote_access(loop): - with LocalCluster(scheduler_port=0, silence_logs=False, - dashboard_address=None, ip='', loop=loop) as c: + with LocalCluster( + scheduler_port=0, silence_logs=False, dashboard_address=None, ip="", loop=loop + ) as c: sync(loop, assert_can_connect_from_everywhere_4_6, c.scheduler.port) -@pytest.mark.parametrize('n_workers', [None, 3]) +@pytest.mark.parametrize("n_workers", [None, 3]) def test_memory(loop, n_workers): - with LocalCluster(n_workers=n_workers, scheduler_port=0, processes=False, - silence_logs=False, dashboard_address=None, loop=loop) as cluster: + with LocalCluster( + n_workers=n_workers, + scheduler_port=0, + processes=False, + silence_logs=False, + dashboard_address=None, + loop=loop, + ) as cluster: assert sum(w.memory_limit for w in cluster.workers) <= TOTAL_MEMORY -@pytest.mark.parametrize('n_workers', [None, 3]) +@pytest.mark.parametrize("n_workers", [None, 3]) def test_memory_nanny(loop, n_workers): - with LocalCluster(n_workers=n_workers, scheduler_port=0, processes=True, - silence_logs=False, dashboard_address=None, loop=loop) as cluster: + with LocalCluster( + n_workers=n_workers, + scheduler_port=0, + processes=True, + silence_logs=False, + dashboard_address=None, + loop=loop, + ) as cluster: with Client(cluster.scheduler_address, loop=loop) as c: info = c.scheduler_info() - assert (sum(w['memory_limit'] for w in info['workers'].values()) - <= TOTAL_MEMORY) + assert ( + sum(w["memory_limit"] for w in info["workers"].values()) <= TOTAL_MEMORY + ) def test_death_timeout_raises(loop): with pytest.raises(gen.TimeoutError): - with LocalCluster(scheduler_port=0, silence_logs=False, - death_timeout=1e-10, dashboard_address=None, - loop=loop) as cluster: + with LocalCluster( + scheduler_port=0, + silence_logs=False, + death_timeout=1e-10, + dashboard_address=None, + loop=loop, + ) as cluster: pass -@pytest.mark.skipif(sys.version_info < (3, 6), reason='Unknown') +@pytest.mark.skipif(sys.version_info < (3, 6), reason="Unknown") def test_bokeh_kwargs(loop): - pytest.importorskip('bokeh') - with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - dashboard_address=0, - service_kwargs={'bokeh': {'prefix': '/foo'}}) as c: + pytest.importorskip("bokeh") + with LocalCluster( + scheduler_port=0, + silence_logs=False, + loop=loop, + dashboard_address=0, + service_kwargs={"bokeh": {"prefix": "/foo"}}, + ) as c: - bs = c.scheduler.services['bokeh'] - assert bs.prefix == '/foo' + bs = c.scheduler.services["bokeh"] + assert bs.prefix == "/foo" def test_io_loop_periodic_callbacks(loop): @@ -393,9 +494,14 @@ def test_logging(): def test_ipywidgets(loop): - ipywidgets = pytest.importorskip('ipywidgets') - with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - dashboard_address=False, processes=False) as cluster: + ipywidgets = pytest.importorskip("ipywidgets") + with LocalCluster( + scheduler_port=0, + silence_logs=False, + loop=loop, + dashboard_address=False, + processes=False, + ) as cluster: cluster._ipython_display_() box = cluster._cached_widget assert isinstance(box, ipywidgets.Widget) @@ -403,8 +509,14 @@ def test_ipywidgets(loop): def test_scale(loop): """ Directly calling scale both up and down works as expected """ - with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - dashboard_address=False, processes=False, n_workers=0) as cluster: + with LocalCluster( + scheduler_port=0, + silence_logs=False, + loop=loop, + dashboard_address=False, + processes=False, + n_workers=0, + ) as cluster: 
assert not cluster.scheduler.workers cluster.scale(3) @@ -424,14 +536,20 @@ def test_scale(loop): def test_adapt(loop): - with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - dashboard_address=False, processes=False, n_workers=0) as cluster: - cluster.adapt(minimum=0, maximum=2, interval='10ms') + with LocalCluster( + scheduler_port=0, + silence_logs=False, + loop=loop, + dashboard_address=False, + processes=False, + n_workers=0, + ) as cluster: + cluster.adapt(minimum=0, maximum=2, interval="10ms") assert cluster._adaptive.minimum == 0 assert cluster._adaptive.maximum == 2 ref = weakref.ref(cluster._adaptive) - cluster.adapt(minimum=1, maximum=2, interval='10ms') + cluster.adapt(minimum=1, maximum=2, interval="10ms") assert cluster._adaptive.minimum == 1 gc.collect() @@ -450,10 +568,16 @@ def test_adapt(loop): def test_adapt_then_manual(loop): """ We can revert from adaptive, back to manual """ - with LocalCluster(scheduler_port=0, silence_logs=False, loop=loop, - dashboard_address=False, processes=False, n_workers=8) as cluster: + with LocalCluster( + scheduler_port=0, + silence_logs=False, + loop=loop, + dashboard_address=False, + processes=False, + n_workers=8, + ) as cluster: sleep(0.1) - cluster.adapt(minimum=0, maximum=4, interval='10ms') + cluster.adapt(minimum=0, maximum=4, interval="10ms") start = time() while cluster.scheduler.workers or cluster.workers: @@ -480,19 +604,33 @@ def test_adapt_then_manual(loop): def test_local_tls(loop): from distributed.utils_test import tls_only_security + security = tls_only_security() - with LocalCluster(scheduler_port=8786, silence_logs=False, security=security, - dashboard_address=False, ip='tls://0.0.0.0', loop=loop) as c: - sync(loop, assert_can_connect_from_everywhere_4, c.scheduler.port, - connection_args=security.get_connection_args('client'), - protocol='tls', timeout=3) + with LocalCluster( + scheduler_port=8786, + silence_logs=False, + security=security, + dashboard_address=False, + ip="tls://0.0.0.0", + loop=loop, + ) as c: + sync( + loop, + assert_can_connect_from_everywhere_4, + c.scheduler.port, + connection_args=security.get_connection_args("client"), + protocol="tls", + timeout=3, + ) # If we connect to a TLS localculster without ssl information we should fail - sync(loop, assert_cannot_connect, - addr='tcp://127.0.0.1:%d' % c.scheduler.port, - connection_args=security.get_connection_args('client'), - exception_class=RuntimeError, - ) + sync( + loop, + assert_cannot_connect, + addr="tcp://127.0.0.1:%d" % c.scheduler.port, + connection_args=security.get_connection_args("client"), + exception_class=RuntimeError, + ) @gen_test() @@ -502,9 +640,15 @@ def scale_down(self, *args, **kwargs): pass loop = IOLoop.current() - cluster = yield MyCluster(0, scheduler_port=0, processes=False, - silence_logs=False, dashboard_address=None, - loop=loop, asynchronous=True) + cluster = yield MyCluster( + 0, + scheduler_port=0, + processes=False, + silence_logs=False, + dashboard_address=None, + loop=loop, + asynchronous=True, + ) c = yield Client(cluster, asynchronous=True) assert not cluster.workers @@ -529,15 +673,23 @@ def scale_down(self, *args, **kwargs): def test_local_tls_restart(loop): from distributed.utils_test import tls_only_security + security = tls_only_security() - with LocalCluster(n_workers=1, scheduler_port=8786, silence_logs=False, security=security, - dashboard_address=False, ip='tls://0.0.0.0', loop=loop) as c: + with LocalCluster( + n_workers=1, + scheduler_port=8786, + silence_logs=False, + security=security, 
+ dashboard_address=False, + ip="tls://0.0.0.0", + loop=loop, + ) as c: with Client(c.scheduler.address, loop=loop, security=security) as client: print(c.workers, c.workers[0].address) - workers_before = set(client.scheduler_info()['workers']) + workers_before = set(client.scheduler_info()["workers"]) assert client.submit(inc, 1).result() == 2 client.restart() - workers_after = set(client.scheduler_info()['workers']) + workers_after = set(client.scheduler_info()["workers"]) assert client.submit(inc, 2).result() == 3 assert workers_before != workers_after @@ -556,8 +708,14 @@ def test_default_process_thread_breakdown(): def test_asynchronous_property(loop): - with LocalCluster(4, scheduler_port=0, processes=False, silence_logs=False, - dashboard_address=None, loop=loop) as cluster: + with LocalCluster( + 4, + scheduler_port=0, + processes=False, + silence_logs=False, + dashboard_address=None, + loop=loop, + ) as cluster: @gen.coroutine def _(): @@ -567,20 +725,21 @@ def _(): def test_protocol_inproc(loop): - with LocalCluster(protocol='inproc://', loop=loop, processes=False) as cluster: - assert cluster.scheduler.address.startswith('inproc://') + with LocalCluster(protocol="inproc://", loop=loop, processes=False) as cluster: + assert cluster.scheduler.address.startswith("inproc://") def test_protocol_tcp(loop): - with LocalCluster(protocol='tcp', loop=loop, processes=False) as cluster: - assert cluster.scheduler.address.startswith('tcp://') + with LocalCluster(protocol="tcp", loop=loop, processes=False) as cluster: + assert cluster.scheduler.address.startswith("tcp://") -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) def test_protocol_ip(loop): - with LocalCluster(ip='tcp://127.0.0.2', loop=loop, processes=False) as cluster: - assert cluster.scheduler.address.startswith('tcp://127.0.0.2') + with LocalCluster(ip="tcp://127.0.0.2", loop=loop, processes=False) as cluster: + assert cluster.scheduler.address.startswith("tcp://127.0.0.2") if sys.version_info >= (3, 5): diff --git a/distributed/deploy/tests/test_ssh.py b/distributed/deploy/tests/test_ssh.py index 5c6c76ea3c8..a86a8ddd280 100644 --- a/distributed/deploy/tests/test_ssh.py +++ b/distributed/deploy/tests/test_ssh.py @@ -3,7 +3,8 @@ from time import sleep import pytest -pytest.importorskip('paramiko') + +pytest.importorskip("paramiko") from distributed import Client from distributed.deploy.ssh import SSHCluster @@ -13,16 +14,18 @@ @pytest.mark.avoid_travis def test_cluster(loop): - with SSHCluster(scheduler_addr='127.0.0.1', - scheduler_port=7437, - worker_addrs=['127.0.0.1', '127.0.0.1']) as c: + with SSHCluster( + scheduler_addr="127.0.0.1", + scheduler_port=7437, + worker_addrs=["127.0.0.1", "127.0.0.1"], + ) as c: with Client(c, loop=loop) as e: start = time() while len(e.ncores()) != 2: sleep(0.01) assert time() < start + 5 - c.add_worker('127.0.0.1') + c.add_worker("127.0.0.1") start = time() while len(e.ncores()) != 3: diff --git a/distributed/diagnostics/__init__.py b/distributed/diagnostics/__init__.py index ab9f6d7a9dd..9469c3855d1 100644 --- a/distributed/diagnostics/__init__.py +++ b/distributed/diagnostics/__init__.py @@ -2,6 +2,7 @@ from ..utils import ignoring from .graph_layout import GraphLayout + with ignoring(ImportError): from .progressbar import progress with ignoring(ImportError): diff --git a/distributed/diagnostics/eventstream.py 
b/distributed/diagnostics/eventstream.py index 1eabf0ea4dc..a4eb0830534 100644 --- a/distributed/diagnostics/eventstream.py +++ b/distributed/diagnostics/eventstream.py @@ -22,9 +22,9 @@ def __init__(self, scheduler=None): scheduler.add_plugin(self) def transition(self, key, start, finish, *args, **kwargs): - if start == 'processing': - kwargs['key'] = key - if finish == 'memory' or finish == 'erred': + if start == "processing": + kwargs["key"] = key + if finish == "memory" or finish == "erred": self.buffer.append(kwargs) @@ -66,9 +66,13 @@ def eventstream(address, interval): """ address = coerce_to_address(address) comm = yield connect(address) - yield comm.write({'op': 'feed', - 'setup': dumps_function(EventStream), - 'function': dumps_function(swap_buffer), - 'interval': interval, - 'teardown': dumps_function(teardown)}) + yield comm.write( + { + "op": "feed", + "setup": dumps_function(EventStream), + "function": dumps_function(swap_buffer), + "interval": interval, + "teardown": dumps_function(teardown), + } + ) raise gen.Return(comm) diff --git a/distributed/diagnostics/graph_layout.py b/distributed/diagnostics/graph_layout.py index 5c29fe28b83..62e115a9ad4 100644 --- a/distributed/diagnostics/graph_layout.py +++ b/distributed/diagnostics/graph_layout.py @@ -12,6 +12,7 @@ class GraphLayout(SchedulerPlugin): It is commonly used with distributed/bokeh/scheduler.py::GraphPlot, which is rendered at /graph on the diagnostic dashboard. """ + def __init__(self, scheduler): self.x = {} self.y = {} @@ -31,14 +32,16 @@ def __init__(self, scheduler): scheduler.add_plugin(self) if self.scheduler.tasks: - dependencies = {k: [ds.key for ds in ts.dependencies] - for k, ts in scheduler.tasks.items()} + dependencies = { + k: [ds.key for ds in ts.dependencies] + for k, ts in scheduler.tasks.items() + } priority = {k: ts.priority for k, ts in scheduler.tasks.items()} - self.update_graph(self.scheduler, dependencies=dependencies, - priority=priority) + self.update_graph( + self.scheduler, dependencies=dependencies, priority=priority + ) - def update_graph(self, scheduler, dependencies=None, priority=None, - **kwargs): + def update_graph(self, scheduler, dependencies=None, priority=None, **kwargs): stack = sorted(dependencies, key=lambda k: priority.get(k, 0), reverse=True) while stack: key = stack.pop() @@ -48,15 +51,18 @@ def update_graph(self, scheduler, dependencies=None, priority=None, if deps: if not all(dep in self.y for dep in deps): stack.append(key) - stack.extend(sorted(deps, key=lambda k: priority.get(k, 0), - reverse=True)) + stack.extend( + sorted(deps, key=lambda k: priority.get(k, 0), reverse=True) + ) continue else: - total_deps = sum(len(scheduler.tasks[dep].dependents) - for dep in deps) - y = sum(self.y[dep] * len(scheduler.tasks[dep].dependents) - / total_deps - for dep in deps) + total_deps = sum( + len(scheduler.tasks[dep].dependents) for dep in deps + ) + y = sum( + self.y[dep] * len(scheduler.tasks[dep].dependents) / total_deps + for dep in deps + ) x = max(self.x[dep] for dep in deps) + 1 else: x = 0 @@ -83,16 +89,20 @@ def update_graph(self, scheduler, dependencies=None, priority=None, self.new_edges.append(edge) def transition(self, key, start, finish, *args, **kwargs): - if finish != 'forgotten': + if finish != "forgotten": self.state_updates.append((self.index[key], finish)) else: - self.visible_updates.append((self.index[key], 'False')) + self.visible_updates.append((self.index[key], "False")) task = self.scheduler.tasks[key] for dep in task.dependents: edge = (key, dep.key) 
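An aside on the GraphLayout hunk above: the reformatted update_graph keeps the original placement rule, where a task's x coordinate is one column to the right of its furthest dependency and its y coordinate is the mean of the dependencies' y values weighted by how many dependents each of those dependencies has (tasks with no dependencies get x = 0). The snippet below is a minimal standalone sketch of that rule, not code from the patch; the toy coordinates and dependent counts are invented for illustration.

    # Sketch of the GraphLayout placement rule (illustrative, not library code).
    def place(deps, x, y, n_dependents):
        # deps: dependency keys of the task being placed
        # x, y: coordinates already assigned to those dependencies
        # n_dependents: number of dependents of each dependency
        if not deps:
            return 0, 0.0  # the plugin starts dependency-free tasks at x = 0
        total = sum(n_dependents[d] for d in deps)
        new_y = sum(y[d] * n_dependents[d] / total for d in deps)
        new_x = max(x[d] for d in deps) + 1
        return new_x, new_y

    x = {"a": 0, "b": 0}
    y = {"a": 0.0, "b": 1.0}
    n_dependents = {"a": 3, "b": 1}
    print(place(["a", "b"], x, y, n_dependents))  # -> (1, 0.25)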
- self.visible_edge_updates.append((self.index_edge.pop((key, dep.key)), 'False')) + self.visible_edge_updates.append( + (self.index_edge.pop((key, dep.key)), "False") + ) for dep in task.dependencies: - self.visible_edge_updates.append((self.index_edge.pop((dep.key, key)), 'False')) + self.visible_edge_updates.append( + (self.index_edge.pop((dep.key, key)), "False") + ) try: del self.collision[(self.x[key], self.y[key])] diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index f5898e7f4e0..e1da4378fd4 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -39,8 +39,7 @@ class SchedulerPlugin(object): >>> scheduler.add_plugin(c) # doctest: +SKIP """ - def update_graph(self, scheduler, dsk=None, keys=None, - restrictions=None, **kwargs): + def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs): """ Run when a new graph / tasks enter the scheduler """ def restart(self, scheduler, **kwargs): diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 5a9fe7f083e..38638a248dd 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -62,7 +62,7 @@ class Progress(SchedulerPlugin): """ def __init__(self, keys, scheduler, minimum=0, dt=0.1, complete=False): - self.keys = {k.key if hasattr(k, 'key') else k for k in keys} + self.keys = {k.key if hasattr(k, "key") else k for k in keys} self.keys = {tokey(k) for k in self.keys} self.scheduler = scheduler self.complete = complete @@ -100,21 +100,21 @@ def setup(self): logger.debug("Set up Progress keys") for k in errors: - self.transition(k, None, 'erred', exception=True) + self.transition(k, None, "erred", exception=True) def transition(self, key, start, finish, *args, **kwargs): - if key in self.keys and start == 'processing' and finish == 'memory': + if key in self.keys and start == "processing" and finish == "memory": logger.debug("Progress sees key %s", key) self.keys.remove(key) if not self.keys: self.stop() - if key in self.all_keys and finish == 'erred': + if key in self.all_keys and finish == "erred": logger.debug("Progress sees task erred") - self.stop(exception=kwargs['exception'], key=key) + self.stop(exception=kwargs["exception"], key=key) - if key in self.keys and finish == 'forgotten': + if key in self.keys and finish == "forgotten": logger.debug("A task was cancelled (%s), stopping progress", key) self.stop(exception=True, key=key) @@ -125,11 +125,10 @@ def stop(self, exception=None, key=None): if self in self.scheduler.plugins: self.scheduler.plugins.remove(self) if exception: - self.status = 'error' - self.extra.update({'exception': self.scheduler.exceptions[key], - 'key': key}) + self.status = "error" + self.extra.update({"exception": self.scheduler.exceptions[key], "key": key}) else: - self.status = 'finished' + self.status = "finished" logger.debug("Remove Progress plugin") @@ -156,11 +155,13 @@ class MultiProgress(Progress): 'y': {'y-1', 'y-2'}} """ - def __init__(self, keys, scheduler=None, func=key_split, minimum=0, dt=0.1, - complete=False): + def __init__( + self, keys, scheduler=None, func=key_split, minimum=0, dt=0.1, complete=False + ): self.func = func - Progress.__init__(self, keys, scheduler, minimum=minimum, dt=dt, - complete=complete) + Progress.__init__( + self, keys, scheduler, minimum=minimum, dt=dt, complete=complete + ) @gen.coroutine def setup(self): @@ -193,11 +194,11 @@ def setup(self): self.keys[k] = set() for k in errors: - self.transition(k, 
None, 'erred', exception=True) + self.transition(k, None, "erred", exception=True) logger.debug("Set up Progress keys") def transition(self, key, start, finish, *args, **kwargs): - if start == 'processing' and finish == 'memory': + if start == "processing" and finish == "memory": s = self.keys.get(self.func(key), None) if s and key in s: s.remove(key) @@ -205,13 +206,13 @@ def transition(self, key, start, finish, *args, **kwargs): if not self.keys or not any(self.keys.values()): self.stop() - if finish == 'erred': + if finish == "erred": logger.debug("Progress sees task erred") k = self.func(key) - if (k in self.all_keys and key in self.all_keys[k]): - self.stop(exception=kwargs.get('exception'), key=key) + if k in self.all_keys and key in self.all_keys[k]: + self.stop(exception=kwargs.get("exception"), key=key) - if finish == 'forgotten': + if finish == "forgotten": k = self.func(key) if k in self.all_keys and key in self.all_keys[k]: logger.debug("A task was cancelled (%s), stopping progress", key) @@ -231,11 +232,11 @@ def format_time(t): m, s = divmod(t, 60) h, m = divmod(m, 60) if h: - return '{0:2.0f}hr {1:2.0f}min {2:4.1f}s'.format(h, m, s) + return "{0:2.0f}hr {1:2.0f}min {2:4.1f}s".format(h, m, s) elif m: - return '{0:2.0f}min {1:4.1f}s'.format(m, s) + return "{0:2.0f}min {1:4.1f}s".format(m, s) else: - return '{0:4.1f}s'.format(s) + return "{0:4.1f}s".format(s) class AllProgress(SchedulerPlugin): @@ -266,13 +267,13 @@ def transition(self, key, start, finish, *args, **kwargs): except KeyError: # TODO: remove me once we have a new or clean state pass - if start == 'memory': + if start == "memory": # XXX why not respect DEFAULT_DATA_SIZE? self.nbytes[prefix] -= ts.nbytes or 0 - if finish == 'memory': + if finish == "memory": self.nbytes[prefix] += ts.nbytes or 0 - if finish != 'forgotten': + if finish != "forgotten": self.state[finish][prefix].add(key) else: s = self.all[prefix] @@ -290,6 +291,7 @@ def restart(self, scheduler): class GroupProgress(SchedulerPlugin): """ Keep track of all keys, grouped by key_split """ + def __init__(self, scheduler): self.scheduler = scheduler self.keys = dict() @@ -305,7 +307,7 @@ def __init__(self, scheduler): self.create(key, k) self.keys[k].add(key) self.groups[k][ts.state] += 1 - if ts.state == 'memory' and ts.nbytes is not None: + if ts.state == "memory" and ts.nbytes is not None: self.nbytes[k] += ts.nbytes scheduler.add_plugin(self) @@ -313,14 +315,12 @@ def __init__(self, scheduler): def create(self, key, k): with log_errors(): ts = self.scheduler.tasks[key] - g = {'memory': 0, 'erred': 0, 'waiting': 0, - 'released': 0, 'processing': 0} + g = {"memory": 0, "erred": 0, "waiting": 0, "released": 0, "processing": 0} self.keys[k] = set() self.groups[k] = g self.nbytes[k] = 0 self.durations[k] = 0 - self.dependents[k] = {key_split_group(dts.key) - for dts in ts.dependents} + self.dependents[k] = {key_split_group(dts.key) for dts in ts.dependents} for dts in ts.dependencies: d = key_split_group(dts.key) self.dependents[d].add(k) @@ -340,7 +340,7 @@ def transition(self, key, start, finish, *args, **kwargs): else: g[start] -= 1 - if finish != 'forgotten': + if finish != "forgotten": g[finish] += 1 else: self.keys[k].remove(key) @@ -350,9 +350,9 @@ def transition(self, key, start, finish, *args, **kwargs): for dep in self.dependencies.pop(k): self.dependents[key_split_group(dep)].remove(k) - if start == 'memory' and ts.nbytes is not None: + if start == "memory" and ts.nbytes is not None: self.nbytes[k] -= ts.nbytes - if finish == 'memory' and 
ts.nbytes is not None: + if finish == "memory" and ts.nbytes is not None: self.nbytes[k] += ts.nbytes def restart(self, scheduler): diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index 60704a5670a..1630251658a 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -17,10 +17,13 @@ def counts(scheduler, allprogress): - return merge({'all': valmap(len, allprogress.all), - 'nbytes': allprogress.nbytes}, - {state: valmap(len, allprogress.state[state]) - for state in ['memory', 'erred', 'released', 'processing']}) + return merge( + {"all": valmap(len, allprogress.all), "nbytes": allprogress.nbytes}, + { + state: valmap(len, allprogress.state[state]) + for state in ["memory", "erred", "released", "processing"] + }, + ) @gen.coroutine @@ -44,11 +47,15 @@ def progress_stream(address, interval): """ address = coerce_to_address(address) comm = yield connect(address) - yield comm.write({'op': 'feed', - 'setup': dumps_function(AllProgress), - 'function': dumps_function(counts), - 'interval': interval, - 'teardown': dumps_function(Scheduler.remove_plugin)}) + yield comm.write( + { + "op": "feed", + "setup": dumps_function(AllProgress), + "function": dumps_function(counts), + "interval": interval, + "teardown": dumps_function(Scheduler.remove_plugin), + } + ) raise gen.Return(comm) @@ -64,14 +71,16 @@ def nbytes_bar(nbytes): total = sum(nbytes.values()) names = sorted(nbytes) - d = {'name': [], - 'text': [], - 'left': [], - 'right': [], - 'center': [], - 'color': [], - 'percent': [], - 'MB': []} + d = { + "name": [], + "text": [], + "left": [], + "right": [], + "center": [], + "color": [], + "percent": [], + "MB": [], + } if not total: return d @@ -81,17 +90,17 @@ def nbytes_bar(nbytes): left = right right = nbytes[name] / total + left center = (right + left) / 2 - d['MB'].append(nbytes[name] / 1000000) - d['percent'].append(round(nbytes[name] / total * 100, 2)) - d['left'].append(left) - d['right'].append(right) - d['center'].append(center) - d['color'].append(color_of(name)) - d['name'].append(name) + d["MB"].append(nbytes[name] / 1000000) + d["percent"].append(round(nbytes[name] / total * 100, 2)) + d["left"].append(left) + d["right"].append(right) + d["center"].append(center) + d["color"].append(color_of(name)) + d["name"].append(name) if right - left > 0.1: - d['text'].append(name) + d["text"].append(name) else: - d['text'].append('') + d["text"].append("") return d @@ -122,92 +131,98 @@ def progress_quads(msg, nrows=8, ncols=3): 'processing-loc': [4 / 5, 1 / 1, 1]}} """ width = 0.9 - names = sorted(msg['all'], key=msg['all'].get, reverse=True) - names = names[:nrows * ncols] + names = sorted(msg["all"], key=msg["all"].get, reverse=True) + names = names[: nrows * ncols] n = len(names) d = {k: [v.get(name, 0) for name in names] for k, v in msg.items()} - d['name'] = names - d['show-name'] = [name if len(name) <= 15 else name[:12] + '...' 
- for name in names] - d['left'] = [i // nrows for i in range(n)] - d['right'] = [i // nrows + width for i in range(n)] - d['top'] = [-(i % nrows) for i in range(n)] - d['bottom'] = [-(i % nrows) - 0.8 for i in range(n)] - d['color'] = [color_of(name) for name in names] - - d['released-loc'] = [] - d['memory-loc'] = [] - d['erred-loc'] = [] - d['processing-loc'] = [] - d['done'] = [] - for r, m, e, p, a, l in zip(d['released'], d['memory'], d['erred'], - d['processing'], d['all'], d['left']): + d["name"] = names + d["show-name"] = [name if len(name) <= 15 else name[:12] + "..." for name in names] + d["left"] = [i // nrows for i in range(n)] + d["right"] = [i // nrows + width for i in range(n)] + d["top"] = [-(i % nrows) for i in range(n)] + d["bottom"] = [-(i % nrows) - 0.8 for i in range(n)] + d["color"] = [color_of(name) for name in names] + + d["released-loc"] = [] + d["memory-loc"] = [] + d["erred-loc"] = [] + d["processing-loc"] = [] + d["done"] = [] + for r, m, e, p, a, l in zip( + d["released"], d["memory"], d["erred"], d["processing"], d["all"], d["left"] + ): rl = width * r / a + l ml = width * (r + m) / a + l el = width * (r + m + e) / a + l pl = width * (p + r + m + e) / a + l - done = '%d / %d' % (r + m + e, a) - d['released-loc'].append(rl) - d['memory-loc'].append(ml) - d['erred-loc'].append(el) - d['processing-loc'].append(pl) - d['done'].append(done) + done = "%d / %d" % (r + m + e, a) + d["released-loc"].append(rl) + d["memory-loc"].append(ml) + d["erred-loc"].append(el) + d["processing-loc"].append(pl) + d["done"].append(done) return d def color_of_message(msg): - if msg['status'] == 'OK': - split = key_split(msg['key']) + if msg["status"] == "OK": + split = key_split(msg["key"]) return color_of(split) else: - return 'black' + return "black" -colors = {'transfer': 'red', - 'disk-write': 'orange', - 'disk-read': 'orange', - 'deserialize': 'gray', - 'compute': color_of_message} +colors = { + "transfer": "red", + "disk-write": "orange", + "disk-read": "orange", + "deserialize": "gray", + "compute": color_of_message, +} -alphas = {'transfer': 0.4, - 'compute': 1, - 'deserialize': 0.4, - 'disk-write': 0.4, - 'disk-read': 0.4} +alphas = { + "transfer": 0.4, + "compute": 1, + "deserialize": 0.4, + "disk-write": 0.4, + "disk-read": 0.4, +} -prefix = {'transfer': 'transfer-', - 'disk-write': 'disk-write-', - 'disk-read': 'disk-read-', - 'deserialize': 'deserialize-', - 'compute': ''} +prefix = { + "transfer": "transfer-", + "disk-write": "disk-write-", + "disk-read": "disk-read-", + "deserialize": "deserialize-", + "compute": "", +} def task_stream_append(lists, msg, workers): - key = msg['key'] + key = msg["key"] name = key_split(key) - startstops = msg.get('startstops', []) + startstops = msg.get("startstops", []) for action, start, stop in startstops: color = colors[action] if type(color) is not str: color = color(msg) - lists['start'].append((start + stop) / 2 * 1000) - lists['duration'].append(1000 * (stop - start)) - lists['key'].append(key) - lists['name'].append(prefix[action] + name) - lists['color'].append(color) - lists['alpha'].append(alphas[action]) - lists['worker'].append(msg['worker']) + lists["start"].append((start + stop) / 2 * 1000) + lists["duration"].append(1000 * (stop - start)) + lists["key"].append(key) + lists["name"].append(prefix[action] + name) + lists["color"].append(color) + lists["alpha"].append(alphas[action]) + lists["worker"].append(msg["worker"]) - worker_thread = '%s-%d' % (msg['worker'], msg['thread']) - 
lists['worker_thread'].append(worker_thread) + worker_thread = "%s-%d" % (msg["worker"], msg["thread"]) + lists["worker_thread"].append(worker_thread) if worker_thread not in workers: workers[worker_thread] = len(workers) / 2 - lists['y'].append(workers[worker_thread]) + lists["y"].append(workers[worker_thread]) return len(startstops) diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 38f784e9cf6..08ba8f7da63 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -15,8 +15,7 @@ from ..core import connect, coerce_to_address, CommClosedError from ..client import default_client, futures_of from ..protocol.pickle import dumps -from ..utils import (ignoring, key_split, is_kernel, LoopRunner, - parse_timedelta) +from ..utils import ignoring, key_split, is_kernel, LoopRunner, parse_timedelta logger = logging.getLogger(__name__) @@ -29,17 +28,17 @@ def get_scheduler(scheduler): class ProgressBar(object): - def __init__(self, keys, scheduler=None, interval='100ms', complete=True): + def __init__(self, keys, scheduler=None, interval="100ms", complete=True): self.scheduler = get_scheduler(scheduler) self.client = None for key in keys: - if hasattr(key, 'client'): + if hasattr(key, "client"): self.client = weakref.ref(key.client) break - self.keys = {k.key if hasattr(k, 'key') else k for k in keys} - self.interval = parse_timedelta(interval, default='s') + self.keys = {k.key if hasattr(k, "key") else k for k in keys} + self.interval = parse_timedelta(interval, default="s") self.complete = complete self._start_time = default_timer() @@ -59,34 +58,42 @@ def setup(scheduler): raise gen.Return(p) def function(scheduler, p): - result = {'all': len(p.all_keys), - 'remaining': len(p.keys), - 'status': p.status} - if p.status == 'error': + result = { + "all": len(p.all_keys), + "remaining": len(p.keys), + "status": p.status, + } + if p.status == "error": result.update(p.extra) return result - self.comm = yield connect(self.scheduler, - connection_args=self.client().connection_args - if self.client else None) + self.comm = yield connect( + self.scheduler, + connection_args=self.client().connection_args if self.client else None, + ) logger.debug("Progressbar Connected to scheduler") - yield self.comm.write({'op': 'feed', - 'setup': dumps(setup), - 'function': dumps(function), - 'interval': self.interval}, - serializers=self.client()._serializers if self.client else None) + yield self.comm.write( + { + "op": "feed", + "setup": dumps(setup), + "function": dumps(function), + "interval": self.interval, + }, + serializers=self.client()._serializers if self.client else None, + ) while True: try: - response = yield self.comm.read(deserializers=self.client()._deserializers - if self.client else None) + response = yield self.comm.read( + deserializers=self.client()._deserializers if self.client else None + ) except CommClosedError: break self._last_response = response - self.status = response['status'] + self.status = response["status"] self._draw_bar(**response) - if response['status'] in ('error', 'finished'): + if response["status"] in ("error", "finished"): yield self.comm.close() self._draw_stop(**response) break @@ -102,10 +109,17 @@ def __del__(self): class TextProgressBar(ProgressBar): - def __init__(self, keys, scheduler=None, interval='100ms', width=40, - loop=None, complete=True, start=True): - super(TextProgressBar, self).__init__(keys, scheduler, interval, - complete) + def __init__( + self, + keys, + 
scheduler=None, + interval="100ms", + width=40, + loop=None, + complete=True, + start=True, + ): + super(TextProgressBar, self).__init__(keys, scheduler, interval, complete) self.width = width self.loop = loop or IOLoop() @@ -115,17 +129,18 @@ def __init__(self, keys, scheduler=None, interval='100ms', width=40, def _draw_bar(self, remaining, all, **kwargs): frac = (1 - remaining / all) if all else 1.0 - bar = '#' * int(self.width * frac) + bar = "#" * int(self.width * frac) percent = int(100 * frac) elapsed = format_time(self.elapsed) - msg = '\r[{0:<{1}}] | {2}% Completed | {3}'.format(bar, self.width, - percent, elapsed) + msg = "\r[{0:<{1}}] | {2}% Completed | {3}".format( + bar, self.width, percent, elapsed + ) with ignoring(ValueError): sys.stdout.write(msg) sys.stdout.flush() def _draw_stop(self, **kwargs): - sys.stdout.write('\r') + sys.stdout.write("\r") sys.stdout.flush() @@ -138,15 +153,16 @@ class ProgressWidget(ProgressBar): TextProgressBar: Text version suitable for the console """ - def __init__(self, keys, scheduler=None, interval='100ms', - complete=False, loop=None): - super(ProgressWidget, self).__init__(keys, scheduler, interval, - complete) + def __init__( + self, keys, scheduler=None, interval="100ms", complete=False, loop=None + ): + super(ProgressWidget, self).__init__(keys, scheduler, interval, complete) from ipywidgets import FloatProgress, HBox, VBox, HTML - self.elapsed_time = HTML('') - self.bar = FloatProgress(min=0, max=1, description='') - self.bar_text = HTML('') + + self.elapsed_time = HTML("") + self.bar = FloatProgress(min=0, max=1, description="") + self.bar_text = HTML("") self.bar_widget = HBox([self.bar_text, self.bar]) self.widget = VBox([self.elapsed_time, self.bar_widget]) @@ -156,38 +172,52 @@ def _ipython_display_(self, **kwargs): return self.widget._ipython_display_(**kwargs) def _draw_stop(self, remaining, status, exception=None, **kwargs): - if status == 'error': - self.bar.bar_style = 'danger' + if status == "error": + self.bar.bar_style = "danger" self.elapsed_time.value = ( - '
        Exception ' - '' + repr(exception) + ':' + - format_time(self.elapsed) + ' ' + - '
        ' + '
        Exception ' + "" + + repr(exception) + + ":" + + format_time(self.elapsed) + + " " + + "
        " ) elif not remaining: - self.bar.bar_style = 'success' - self.elapsed_time.value = '
        Finished: ' + \ - format_time(self.elapsed) + '
        ' + self.bar.bar_style = "success" + self.elapsed_time.value = ( + '
        Finished: ' + + format_time(self.elapsed) + + "
        " + ) def _draw_bar(self, remaining, all, **kwargs): ndone = all - remaining - self.elapsed_time.value = '
        Computing: ' + \ - format_time(self.elapsed) + '
        ' + self.elapsed_time.value = ( + '
        Computing: ' + + format_time(self.elapsed) + + "
        " + ) self.bar.value = ndone / all if all else 1.0 - self.bar_text.value = '
        %d / %d
        ' % (ndone, all) + self.bar_text.value = ( + '
        %d / %d
        ' + % (ndone, all) + ) class MultiProgressBar(object): - def __init__(self, keys, scheduler=None, func=key_split, interval='100ms', complete=False): + def __init__( + self, keys, scheduler=None, func=key_split, interval="100ms", complete=False + ): self.scheduler = get_scheduler(scheduler) self.client = None for key in keys: - if hasattr(key, 'client'): + if hasattr(key, "client"): self.client = weakref.ref(key.client) break - self.keys = {k.key if hasattr(k, 'key') else k for k in keys} + self.keys = {k.key if hasattr(k, "key") else k for k in keys} self.func = func self.interval = interval self.complete = complete @@ -210,30 +240,38 @@ def setup(scheduler): raise gen.Return(p) def function(scheduler, p): - result = {'all': valmap(len, p.all_keys), - 'remaining': valmap(len, p.keys), - 'status': p.status} - if p.status == 'error': + result = { + "all": valmap(len, p.all_keys), + "remaining": valmap(len, p.keys), + "status": p.status, + } + if p.status == "error": result.update(p.extra) return result - self.comm = yield connect(self.scheduler, - connection_args=self.client().connection_args - if self.client else None) + self.comm = yield connect( + self.scheduler, + connection_args=self.client().connection_args if self.client else None, + ) logger.debug("Progressbar Connected to scheduler") - yield self.comm.write({'op': 'feed', - 'setup': dumps(setup), - 'function': dumps(function), - 'interval': self.interval}) + yield self.comm.write( + { + "op": "feed", + "setup": dumps(setup), + "function": dumps(function), + "interval": self.interval, + } + ) while True: - response = yield self.comm.read(deserializers=self.client()._deserializers if - self.client else None) + response = yield self.comm.read( + deserializers=self.client()._deserializers if self.client else None + ) self._last_response = response - self.status = response['status'] + self.status = response["status"] self._draw_bar(**response) - if response['status'] in ('error', 'finished'): + if response["status"] in ("error", "finished"): yield self.comm.close() self._draw_stop(**response) break @@ -260,26 +298,38 @@ class MultiProgressWidget(MultiProgressBar): ProgressWidget: Single progress bar widget """ - def __init__(self, keys, scheduler=None, minimum=0, interval=0.1, func=key_split, - complete=False): - super(MultiProgressWidget, self).__init__(keys, scheduler, func, interval, complete) + def __init__( + self, + keys, + scheduler=None, + minimum=0, + interval=0.1, + func=key_split, + complete=False, + ): + super(MultiProgressWidget, self).__init__( + keys, scheduler, func, interval, complete + ) from ipywidgets import VBox + self.widget = VBox([]) def make_widget(self, all): from ipywidgets import FloatProgress, HBox, VBox, HTML - self.elapsed_time = HTML('') - self.bars = {key: FloatProgress(min=0, max=1, description='') - for key in all} - self.bar_texts = {key: HTML('') for key in all} - self.bar_labels = {key: HTML('
        ' + - html_escape(key.decode() - if isinstance(key, bytes) - else key) + - '
        ') - for key in all} + + self.elapsed_time = HTML("") + self.bars = {key: FloatProgress(min=0, max=1, description="") for key in all} + self.bar_texts = {key: HTML("") for key in all} + self.bar_labels = { + key: HTML( + '
        ' + + html_escape(key.decode() if isinstance(key, bytes) else key) + + "
        " + ) + for key in all + } def keyfunc(kv): """ Order keys by most numerous, then by string name """ @@ -287,10 +337,12 @@ def keyfunc(kv): key_order = [k for k, v in sorted(all.items(), key=keyfunc, reverse=True)] - self.bar_widgets = VBox([HBox([self.bar_texts[key], - self.bars[key], - self.bar_labels[key]]) - for key in key_order]) + self.bar_widgets = VBox( + [ + HBox([self.bar_texts[key], self.bars[key], self.bar_labels[key]]) + for key in key_order + ] + ) self.widget.children = (self.elapsed_time, self.bar_widgets) def _ipython_display_(self, **kwargs): @@ -300,32 +352,43 @@ def _ipython_display_(self, **kwargs): def _draw_stop(self, remaining, status, exception=None, key=None, **kwargs): for k, v in remaining.items(): if not v: - self.bars[k].bar_style = 'success' + self.bars[k].bar_style = "success" else: - self.bars[k].bar_style = 'danger' + self.bars[k].bar_style = "danger" - if status == 'error': + if status == "error": # self.bars[self.func(key)].bar_style = 'danger' # TODO self.elapsed_time.value = ( - '
        Exception ' + - '' + repr(exception) + ':' + - format_time(self.elapsed) + ' ' + - '
        ' + '
        Exception ' + + "" + + repr(exception) + + ":" + + format_time(self.elapsed) + + " " + + "
        " ) else: - self.elapsed_time.value = '
        Finished: ' + \ - format_time(self.elapsed) + '
        ' + self.elapsed_time.value = ( + '
        Finished: ' + + format_time(self.elapsed) + + "
        " + ) def _draw_bar(self, remaining, all, status, **kwargs): if self.keys and not self.widget.children: self.make_widget(all) for k, ntasks in all.items(): ndone = ntasks - remaining[k] - self.elapsed_time.value = '
        Computing: ' + \ - format_time(self.elapsed) + '
        ' + self.elapsed_time.value = ( + '
        Computing: ' + + format_time(self.elapsed) + + "
        " + ) self.bars[k].value = ndone / ntasks if ntasks else 1.0 - self.bar_texts[k].value = '
        %d / %d
        ' % ( - ndone, ntasks) + self.bar_texts[k].value = ( + '
        %d / %d
        ' + % (ndone, ntasks) + ) def progress(*futures, **kwargs): @@ -359,9 +422,9 @@ def progress(*futures, **kwargs): >>> progress(futures) # doctest: +SKIP [########################################] | 100% Completed | 1.7s """ - notebook = kwargs.pop('notebook', None) - multi = kwargs.pop('multi', True) - complete = kwargs.pop('complete', True) + notebook = kwargs.pop("notebook", None) + multi = kwargs.pop("multi", True) + complete = kwargs.pop("complete", True) assert not kwargs futures = futures_of(futures) diff --git a/distributed/diagnostics/task_stream.py b/distributed/diagnostics/task_stream.py index fafcbebd5cb..89cacb67c97 100644 --- a/distributed/diagnostics/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -20,11 +20,11 @@ def __init__(self, scheduler, maxlen=100000): self.index = 0 def transition(self, key, start, finish, *args, **kwargs): - if start == 'processing': + if start == "processing": if key not in self.scheduler.tasks: return - kwargs['key'] = key - if finish == 'memory' or finish == 'erred': + kwargs["key"] = key + if finish == "memory" or finish == "erred": self.buffer.append(kwargs) self.index += 1 @@ -34,7 +34,7 @@ def bisect(target, left, right): return left mid = (left + right) // 2 - value = max(stop for _, start, stop in self.buffer[mid]['startstops']) + value = max(stop for _, start, stop in self.buffer[mid]["startstops"]) if value < target: return bisect(target, mid + 1, right) @@ -98,16 +98,15 @@ def rectangles(msgs, workers=None, start_boundary=0): L_y = [] for msg in msgs: - key = msg['key'] + key = msg["key"] name = key_split(key) - startstops = msg.get('startstops', []) + startstops = msg.get("startstops", []) try: - worker_thread = '%s-%d' % (msg['worker'], msg['thread']) + worker_thread = "%s-%d" % (msg["worker"], msg["thread"]) except Exception: continue - logger.warning("Message contained bad information: %s", msg, - exc_info=True) - worker_thread = '' + logger.warning("Message contained bad information: %s", msg, exc_info=True) + worker_thread = "" if worker_thread not in workers: workers[worker_thread] = len(workers) / 2 @@ -126,46 +125,54 @@ def rectangles(msgs, workers=None, start_boundary=0): L_name.append(prefix[action] + name) L_color.append(color) L_alpha.append(alphas[action]) - L_worker.append(msg['worker']) + L_worker.append(msg["worker"]) L_worker_thread.append(worker_thread) L_y.append(workers[worker_thread]) - return {'start': L_start, - 'duration': L_duration, - 'duration_text': L_duration_text, - 'key': L_key, - 'name': L_name, - 'color': L_color, - 'alpha': L_alpha, - 'worker': L_worker, - 'worker_thread': L_worker_thread, - 'y': L_y} + return { + "start": L_start, + "duration": L_duration, + "duration_text": L_duration_text, + "key": L_key, + "name": L_name, + "color": L_color, + "alpha": L_alpha, + "worker": L_worker, + "worker_thread": L_worker_thread, + "y": L_y, + } def color_of_message(msg): - if msg['status'] == 'OK': - split = key_split(msg['key']) + if msg["status"] == "OK": + split = key_split(msg["key"]) return color_of(split) else: - return 'black' - - -colors = {'transfer': 'red', - 'disk-write': 'orange', - 'disk-read': 'orange', - 'deserialize': 'gray', - 'compute': color_of_message} - - -alphas = {'transfer': 0.4, - 'compute': 1, - 'deserialize': 0.4, - 'disk-write': 0.4, - 'disk-read': 0.4} - - -prefix = {'transfer': 'transfer-', - 'disk-write': 'disk-write-', - 'disk-read': 'disk-read-', - 'deserialize': 'deserialize-', - 'compute': ''} + return "black" + + +colors = { + "transfer": "red", + 
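Before the colors, alphas and prefix tables (duplicated here from progress_stream) resume below, a short illustrative aside: both task_stream_append earlier and rectangles here turn each (action, start, stop) entry of a message's startstops into a bar whose x position is the interval midpoint in milliseconds, whose width is the duration in milliseconds, and whose row is keyed by a "worker-thread" string. The sketch below reproduces only that arithmetic on a made-up message; it is not the library implementation.

    # Sketch of the startstops -> task-stream rectangle mapping (illustrative only).
    def to_rects(msg, rows):
        rects = []
        worker_thread = "%s-%d" % (msg["worker"], msg["thread"])
        row = rows.setdefault(worker_thread, len(rows) / 2)
        for action, start, stop in msg.get("startstops", []):
            rects.append({
                "start": (start + stop) / 2 * 1000,  # bar centre, in ms
                "duration": 1000 * (stop - start),   # bar width, in ms
                "name": action,
                "y": row,
            })
        return rects

    msg = {"worker": "tcp://127.0.0.1:1234", "thread": 1,
           "startstops": [("compute", 10.0, 10.5), ("transfer", 9.8, 10.0)]}
    print(to_rects(msg, {}))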
"disk-write": "orange", + "disk-read": "orange", + "deserialize": "gray", + "compute": color_of_message, +} + + +alphas = { + "transfer": 0.4, + "compute": 1, + "deserialize": 0.4, + "disk-write": 0.4, + "disk-read": 0.4, +} + + +prefix = { + "transfer": "transfer-", + "disk-write": "disk-write-", + "disk-read": "disk-read-", + "deserialize": "deserialize-", + "compute": "", +} diff --git a/distributed/diagnostics/tests/test_eventstream.py b/distributed/diagnostics/tests/test_eventstream.py index 7504a59846b..0995d80db26 100644 --- a/distributed/diagnostics/tests/test_eventstream.py +++ b/distributed/diagnostics/tests/test_eventstream.py @@ -11,9 +11,9 @@ from distributed.utils_test import div, gen_cluster -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_eventstream(c, s, *workers): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") es = EventStream() s.add_plugin(es) @@ -28,17 +28,18 @@ def test_eventstream(c, s, *workers): from distributed.bokeh import messages from distributed.diagnostics.progress_stream import task_stream_append - lists = deepcopy(messages['task-events']['rectangles']) + + lists = deepcopy(messages["task-events"]["rectangles"]) workers = dict() for msg in es.buffer: task_stream_append(lists, msg, workers) - assert len([n for n in lists['name'] if n.startswith('transfer')]) == 2 - for name, color in zip(lists['name'], lists['color']): - if name == 'transfer': - assert color == 'red' + assert len([n for n in lists["name"] if n.startswith("transfer")]) == 2 + for name, color in zip(lists["name"], lists["color"]): + if name == "transfer": + assert color == "red" - assert any(c == 'black' for c in lists['color']) + assert any(c == "black" for c in lists["color"]) @gen_cluster(client=True) diff --git a/distributed/diagnostics/tests/test_graph_layout.py b/distributed/diagnostics/tests/test_graph_layout.py index 63ecb0c7008..fc8fba8d028 100644 --- a/distributed/diagnostics/tests/test_graph_layout.py +++ b/distributed/diagnostics/tests/test_graph_layout.py @@ -45,9 +45,9 @@ def test_states(c, s, a, b): yield total updates = {state for idx, state in gl.state_updates} - assert 'memory' in updates - assert 'processing' in updates - assert 'released' in updates + assert "memory" in updates + assert "processing" in updates + assert "released" in updates @gen_cluster(client=True) diff --git a/distributed/diagnostics/tests/test_plugin.py b/distributed/diagnostics/tests/test_plugin.py index afc1f4987d5..b1d5406e052 100644 --- a/distributed/diagnostics/tests/test_plugin.py +++ b/distributed/diagnostics/tests/test_plugin.py @@ -7,7 +7,6 @@ @gen_cluster(client=True) def test_simple(c, s, a, b): - class Counter(SchedulerPlugin): def start(self, scheduler): self.scheduler = scheduler @@ -15,7 +14,7 @@ def start(self, scheduler): self.count = 0 def transition(self, key, start, finish, *args, **kwargs): - if start == 'processing' and finish == 'memory': + if start == "processing" and finish == "memory": self.count += 1 counter = Counter() @@ -42,11 +41,11 @@ def test_add_remove_worker(s): class MyPlugin(SchedulerPlugin): def add_worker(self, worker, scheduler): assert scheduler is s - events.append(('add_worker', worker)) + events.append(("add_worker", worker)) def remove_worker(self, worker, scheduler): assert scheduler is s - events.append(('remove_worker', worker)) + events.append(("remove_worker", worker)) plugin = MyPlugin() s.add_plugin(plugin) @@ -59,11 +58,12 @@ def remove_worker(self, worker, 
scheduler): yield a._close() yield b._close() - assert events == [('add_worker', a.address), - ('add_worker', b.address), - ('remove_worker', a.address), - ('remove_worker', b.address), - ] + assert events == [ + ("add_worker", a.address), + ("add_worker", b.address), + ("remove_worker", a.address), + ("remove_worker", b.address), + ] events[:] = [] s.remove_plugin(plugin) diff --git a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py index 2d88054a34b..d8435cc7ff0 100644 --- a/distributed/diagnostics/tests/test_progress.py +++ b/distributed/diagnostics/tests/test_progress.py @@ -8,8 +8,13 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec, div, nodebug -from distributed.diagnostics.progress import (Progress, SchedulerPlugin, - AllProgress, GroupProgress, MultiProgress) +from distributed.diagnostics.progress import ( + Progress, + SchedulerPlugin, + AllProgress, + GroupProgress, + MultiProgress, +) def f(*args): @@ -37,7 +42,7 @@ def test_many_Progress(c, s, a, b): yield z start = time() - while not all(b.status == 'finished' for b in bars): + while not all(b.status == "finished" for b in bars): yield gen.sleep(0.1) assert time() < start + 5 @@ -53,19 +58,20 @@ def test_multiprogress(c, s, a, b): p = MultiProgress([y2], scheduler=s, complete=True) yield p.setup() - assert p.all_keys == {'f': {f.key for f in [x1, x2, x3]}, - 'g': {f.key for f in [y1, y2]}} + assert p.all_keys == { + "f": {f.key for f in [x1, x2, x3]}, + "g": {f.key for f in [y1, y2]}, + } yield x3 - assert p.keys['f'] == set() + assert p.keys["f"] == set() yield y2 - assert p.keys == {'f': set(), - 'g': set()} + assert p.keys == {"f": set(), "g": set()} - assert p.status == 'finished' + assert p.status == "finished" @gen_cluster(client=True) @@ -85,9 +91,9 @@ def transition(self, key, start, finish, **kwargs): def check_bar_completed(capsys, width=40): out, err = capsys.readouterr() - bar, percent, time = [i.strip() for i in out.split('\r')[-1].split('|')] - assert bar == '[' + '#' * width + ']' - assert percent == '100% Completed' + bar, percent, time = [i.strip() for i in out.split("\r")[-1].split("|")] + assert bar == "[" + "#" * width + "]" + assert percent == "100% Completed" @gen_cluster(client=True, Worker=Nanny, timeout=None) @@ -97,20 +103,20 @@ def test_AllProgress(c, s, a, b): yield wait([x, y, z]) p = AllProgress(s) - assert p.all['inc'] == {x.key, y.key, z.key} - assert p.state['memory']['inc'] == {x.key, y.key, z.key} - assert p.state['released'] == {} - assert p.state['erred'] == {} - assert 'inc' in p.nbytes - assert isinstance(p.nbytes['inc'], int) - assert p.nbytes['inc'] > 0 + assert p.all["inc"] == {x.key, y.key, z.key} + assert p.state["memory"]["inc"] == {x.key, y.key, z.key} + assert p.state["released"] == {} + assert p.state["erred"] == {} + assert "inc" in p.nbytes + assert isinstance(p.nbytes["inc"], int) + assert p.nbytes["inc"] > 0 yield wait([xx, yy, zz]) - assert p.all['dec'] == {xx.key, yy.key, zz.key} - assert p.state['memory']['dec'] == {xx.key, yy.key, zz.key} - assert p.state['released'] == {} - assert p.state['erred'] == {} - assert p.nbytes['inc'] == p.nbytes['dec'] + assert p.all["dec"] == {xx.key, yy.key, zz.key} + assert p.state["memory"]["dec"] == {xx.key, yy.key, zz.key} + assert p.state["released"] == {} + assert p.state["erred"] == {} + assert p.nbytes["inc"] == p.nbytes["dec"] t = c.submit(sum, [x, y, z]) yield t @@ -118,32 +124,34 @@ def 
test_AllProgress(c, s, a, b): keys = {x.key, y.key, z.key} del x, y, z import gc + gc.collect() while any(k in s.who_has for k in keys): yield gen.sleep(0.01) - assert p.state['released']['inc'] == keys - assert p.all['inc'] == keys - assert p.all['dec'] == {xx.key, yy.key, zz.key} - if 'inc' in p.nbytes: - assert p.nbytes['inc'] == 0 + assert p.state["released"]["inc"] == keys + assert p.all["inc"] == keys + assert p.all["dec"] == {xx.key, yy.key, zz.key} + if "inc" in p.nbytes: + assert p.nbytes["inc"] == 0 xxx = c.submit(div, 1, 0) yield wait([xxx]) - assert p.state['erred'] == {'div': {xxx.key}} + assert p.state["erred"] == {"div": {xxx.key}} tkey = t.key del xx, yy, zz, t import gc + gc.collect() while tkey in s.tasks: yield gen.sleep(0.01) for coll in [p.all, p.nbytes] + list(p.state.values()): - assert 'inc' not in coll - assert 'dec' not in coll + assert "inc" not in coll + assert "dec" not in coll def f(x): return x @@ -151,12 +159,13 @@ def f(x): for i in range(4): future = c.submit(f, i) import gc + gc.collect() yield gen.sleep(1) yield wait([future]) - assert p.state['memory'] == {'f': {future.key}} + assert p.state["memory"] == {"f": {future.key}} yield c._restart() @@ -165,8 +174,8 @@ def f(x): x = c.submit(div, 1, 2) yield wait([x]) - assert set(p.all) == {'div'} - assert all(set(d) == {'div'} for d in p.state.values()) + assert set(p.all) == {"div"} + assert all(set(d) == {"div"} for d in p.state.values()) @gen_cluster(client=True, Worker=Nanny) @@ -174,20 +183,20 @@ def test_AllProgress_lost_key(c, s, a, b, timeout=None): p = AllProgress(s) futures = c.map(inc, range(5)) yield wait(futures) - assert len(p.state['memory']['inc']) == 5 + assert len(p.state["memory"]["inc"]) == 5 yield a._close() yield b._close() start = time() - while len(p.state['memory']['inc']) > 0: + while len(p.state["memory"]["inc"]) > 0: yield gen.sleep(0.1) assert time() < start + 5 @gen_cluster(client=True) def test_GroupProgress(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") fp = GroupProgress(s) x = da.ones(100, chunks=10) y = x + 1 diff --git a/distributed/diagnostics/tests/test_progress_stream.py b/distributed/diagnostics/tests/test_progress_stream.py index ce21bb34193..9cf89817f34 100644 --- a/distributed/diagnostics/tests/test_progress_stream.py +++ b/distributed/diagnostics/tests/test_progress_stream.py @@ -2,54 +2,64 @@ import pytest -pytest.importorskip('bokeh') + +pytest.importorskip("bokeh") from dask import delayed from distributed.client import wait -from distributed.diagnostics.progress_stream import (progress_quads, - nbytes_bar, progress_stream) +from distributed.diagnostics.progress_stream import ( + progress_quads, + nbytes_bar, + progress_stream, +) from distributed.utils_test import div, gen_cluster, inc def test_progress_quads(): - msg = {'all': {'inc': 5, 'dec': 1, 'add': 4}, - 'memory': {'inc': 2, 'dec': 0, 'add': 1}, - 'erred': {'inc': 0, 'dec': 1, 'add': 0}, - 'released': {'inc': 1, 'dec': 0, 'add': 1}, - 'processing': {'inc': 1, 'dec': 0, 'add': 2}} + msg = { + "all": {"inc": 5, "dec": 1, "add": 4}, + "memory": {"inc": 2, "dec": 0, "add": 1}, + "erred": {"inc": 0, "dec": 1, "add": 0}, + "released": {"inc": 1, "dec": 0, "add": 1}, + "processing": {"inc": 1, "dec": 0, "add": 2}, + } d = progress_quads(msg, nrows=2) - color = d.pop('color') + color = d.pop("color") assert len(set(color)) == 3 - expected = {'name': ['inc', 'add', 'dec'], - 'show-name': ['inc', 'add', 'dec'], - 'left': [0, 0, 1], - 'right': [0.9, 0.9, 1.9], - 
'top': [0, -1, 0], - 'bottom': [-.8, -1.8, -.8], - 'all': [5, 4, 1], - 'released': [1, 1, 0], - 'memory': [2, 1, 0], - 'erred': [0, 0, 1], - 'processing': [1, 2, 0], - 'done': ['3 / 5', '2 / 4', '1 / 1'], - 'released-loc': [.9 * 1 / 5, .25 * 0.9, 1.0], - 'memory-loc': [.9 * 3 / 5, .5 * 0.9, 1.0], - 'erred-loc': [.9 * 3 / 5, .5 * 0.9, 1.9], - 'processing-loc': [.9 * 4 / 5, 1 * 0.9, 1 * 0.9 + 1]} + expected = { + "name": ["inc", "add", "dec"], + "show-name": ["inc", "add", "dec"], + "left": [0, 0, 1], + "right": [0.9, 0.9, 1.9], + "top": [0, -1, 0], + "bottom": [-0.8, -1.8, -0.8], + "all": [5, 4, 1], + "released": [1, 1, 0], + "memory": [2, 1, 0], + "erred": [0, 0, 1], + "processing": [1, 2, 0], + "done": ["3 / 5", "2 / 4", "1 / 1"], + "released-loc": [0.9 * 1 / 5, 0.25 * 0.9, 1.0], + "memory-loc": [0.9 * 3 / 5, 0.5 * 0.9, 1.0], + "erred-loc": [0.9 * 3 / 5, 0.5 * 0.9, 1.9], + "processing-loc": [0.9 * 4 / 5, 1 * 0.9, 1 * 0.9 + 1], + } assert d == expected def test_progress_quads_too_many(): - keys = ['x-%d' % i for i in range(1000)] - msg = {'all': {k: 1 for k in keys}, - 'memory': {k: 0 for k in keys}, - 'erred': {k: 0 for k in keys}, - 'released': {k: 0 for k in keys}, - 'processing': {k: 0 for k in keys}} + keys = ["x-%d" % i for i in range(1000)] + msg = { + "all": {k: 1 for k in keys}, + "memory": {k: 0 for k in keys}, + "erred": {k: 0 for k in keys}, + "released": {k: 0 for k in keys}, + "processing": {k: 0 for k in keys}, + } d = progress_quads(msg, nrows=6, ncols=3) - assert len(d['name']) == 6 * 3 + assert len(d["name"]) == 6 * 3 @gen_cluster(client=True) @@ -65,13 +75,15 @@ def test_progress_stream(c, s, a, b): comm = yield progress_stream(s.address, interval=0.010) msg = yield comm.read() - nbytes = msg.pop('nbytes') - assert msg == {'all': {'div': 10, 'inc': 5}, - 'erred': {'div': 1}, - 'memory': {'div': 9, 'inc': 1}, - 'released': {'inc': 4}, - 'processing': {}} - assert set(nbytes) == set(msg['all']) + nbytes = msg.pop("nbytes") + assert msg == { + "all": {"div": 10, "inc": 5}, + "erred": {"div": 1}, + "memory": {"div": 9, "inc": 1}, + "released": {"inc": 4}, + "processing": {}, + } + assert set(nbytes) == set(msg["all"]) assert all(v > 0 for v in nbytes.values()) assert progress_quads(msg) @@ -80,29 +92,33 @@ def test_progress_stream(c, s, a, b): def test_nbytes_bar(): - nbytes = {'inc': 1000, 'dec': 3000} - expected = {'name': ['dec', 'inc'], - 'left': [0, 0.75], - 'center': [0.375, 0.875], - 'right': [0.75, 1.0], - 'percent': [75, 25], - 'MB': [0.003, 0.001], - 'text': ['dec', 'inc']} + nbytes = {"inc": 1000, "dec": 3000} + expected = { + "name": ["dec", "inc"], + "left": [0, 0.75], + "center": [0.375, 0.875], + "right": [0.75, 1.0], + "percent": [75, 25], + "MB": [0.003, 0.001], + "text": ["dec", "inc"], + } result = nbytes_bar(nbytes) - color = result.pop('color') + color = result.pop("color") assert len(set(color)) == 2 assert result == expected def test_progress_quads_many_functions(): - funcnames = ['fn%d' % i for i in range(1000)] - msg = {'all': {fn: 1 for fn in funcnames}, - 'memory': {fn: 1 for fn in funcnames}, - 'erred': {fn: 0 for fn in funcnames}, - 'released': {fn: 0 for fn in funcnames}, - 'processing': {fn: 0 for fn in funcnames}} + funcnames = ["fn%d" % i for i in range(1000)] + msg = { + "all": {fn: 1 for fn in funcnames}, + "memory": {fn: 1 for fn in funcnames}, + "erred": {fn: 0 for fn in funcnames}, + "released": {fn: 0 for fn in funcnames}, + "processing": {fn: 0 for fn in funcnames}, + } d = progress_quads(msg, nrows=2) - color = d.pop('color') + color 
= d.pop("color") assert len(set(color)) < 100 diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index 7d75b52eeef..8738cb60e22 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -7,7 +7,7 @@ from distributed import Scheduler, Worker from distributed.diagnostics.progressbar import TextProgressBar, progress from distributed.metrics import time -from distributed.utils_test import (inc, div, gen_cluster) +from distributed.utils_test import inc, div, gen_cluster from distributed.utils_test import client, loop, cluster_fixture # noqa: F401 @@ -17,14 +17,12 @@ def test_text_progressbar(capsys, client): client.gather(futures) start = time() - while p.status != 'finished': + while p.status != "finished": sleep(0.01) assert time() - start < 5 check_bar_completed(capsys) - assert p._last_response == {'all': 10, - 'remaining': 0, - 'status': 'finished'} + assert p._last_response == {"all": 10, "remaining": 0, "status": "finished"} assert p.comm.closed() @@ -32,17 +30,19 @@ def test_text_progressbar(capsys, client): def test_TextProgressBar_error(c, s, a, b): x = c.submit(div, 1, 0) - progress = TextProgressBar([x.key], scheduler=(s.ip, s.port), - start=False, interval=0.01) + progress = TextProgressBar( + [x.key], scheduler=(s.ip, s.port), start=False, interval=0.01 + ) yield progress.listen() - assert progress.status == 'error' + assert progress.status == "error" assert progress.comm.closed() - progress = TextProgressBar([x.key], scheduler=(s.ip, s.port), - start=False, interval=0.01) + progress = TextProgressBar( + [x.key], scheduler=(s.ip, s.port), start=False, interval=0.01 + ) yield progress.listen() - assert progress.status == 'error' + assert progress.status == "error" assert progress.comm.closed() @@ -55,11 +55,12 @@ def f(): b = Worker(s.ip, s.port, loop=loop, ncores=1) yield [a._start(0), b._start(0)] - progress = TextProgressBar([], scheduler=(s.ip, s.port), start=False, - interval=0.01) + progress = TextProgressBar( + [], scheduler=(s.ip, s.port), start=False, interval=0.01 + ) yield progress.listen() - assert progress.status == 'finished' + assert progress.status == "finished" check_bar_completed(capsys) yield [a._close(), b._close()] @@ -72,9 +73,9 @@ def f(): def check_bar_completed(capsys, width=40): out, err = capsys.readouterr() # trailing newline so grab next to last line for final state of bar - bar, percent, time = [i.strip() for i in out.split('\r')[-2].split('|')] - assert bar == '[' + '#' * width + ']' - assert percent == '100% Completed' + bar, percent, time = [i.strip() for i in out.split("\r")[-2].split("|")] + assert bar == "[" + "#" * width + "]" + assert percent == "100% Completed" def test_progress_function(client, capsys): diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index 51bbc9e1021..366de8d79d5 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -14,7 +14,7 @@ from distributed.metrics import time -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_TaskStreamPlugin(c, s, *workers): es = TaskStreamPlugin(s) assert not es.buffer @@ -29,20 +29,21 @@ def test_TaskStreamPlugin(c, s, *workers): rects = es.rectangles(0, 10, workers) assert workers - assert all(n == 'div' for n in rects['name']) - assert all(d > 0 for d in 
rects['duration']) - counts = frequencies(rects['color']) - assert counts['black'] == 1 + assert all(n == "div" for n in rects["name"]) + assert all(d > 0 for d in rects["duration"]) + counts = frequencies(rects["color"]) + assert counts["black"] == 1 assert set(counts.values()) == {9, 1} - assert len(set(rects['y'])) == 3 + assert len(set(rects["y"])) == 3 rects = es.rectangles(2, 5, workers) assert all(len(L) == 3 for L in rects.values()) - starts = sorted(rects['start']) - rects = es.rectangles(2, 5, workers=workers, - start_boundary=(starts[0] + starts[1]) / 2000) - assert set(rects['start']).issubset(set(starts[1:])) + starts = sorted(rects["start"]) + rects = es.rectangles( + 2, 5, workers=workers, start_boundary=(starts[0] + starts[1]) / 2000 + ) + assert set(rects["start"]).issubset(set(starts[1:])) @gen_cluster(client=True) @@ -68,10 +69,10 @@ def test_collect(c, s, a, b): L = tasks.collect(start=start + 0.2) assert 4 <= len(L) <= len(futures) - L = tasks.collect(start='20 s') + L = tasks.collect(start="20 s") assert len(L) == len(futures) - L = tasks.collect(start='500ms') + L = tasks.collect(start="500ms") assert 0 < len(L) <= len(futures) L = tasks.collect(count=3) @@ -107,7 +108,7 @@ def test_client_sync(client): @gen_cluster(client=True) def test_get_task_stream_plot(c, s, a, b): - bokeh = pytest.importorskip('bokeh') + bokeh = pytest.importorskip("bokeh") yield c.get_task_stream() futures = c.map(slowinc, range(10), delay=0.1) @@ -118,15 +119,15 @@ def test_get_task_stream_plot(c, s, a, b): def test_get_task_stream_save(client, tmpdir): - bokeh = pytest.importorskip('bokeh') + bokeh = pytest.importorskip("bokeh") tmpdir = str(tmpdir) - fn = os.path.join(tmpdir, 'foo.html') + fn = os.path.join(tmpdir, "foo.html") - with get_task_stream(plot='save', filename=fn) as ts: + with get_task_stream(plot="save", filename=fn) as ts: wait(client.map(inc, range(10))) with open(fn) as f: data = f.read() - assert 'inc' in data - assert 'bokeh' in data + assert "inc" in data + assert "bokeh" in data assert isinstance(ts.figure, bokeh.plotting.Figure) diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index 6ae29161a9c..033d49251cb 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -1,7 +1,8 @@ from __future__ import print_function, division, absolute_import import pytest -pytest.importorskip('ipywidgets') + +pytest.importorskip("ipywidgets") from ipykernel.comm import Comm from ipywidgets import Widget @@ -17,7 +18,7 @@ class DummyComm(Comm): - comm_id = 'a-b-c-d' + comm_id = "a-b-c-d" def open(self, *args, **kwargs): pass @@ -35,12 +36,13 @@ def close(self, *args, **kwargs): def setup(): - _widget_attrs['_comm_default'] = getattr(Widget, '_comm_default', undefined) + _widget_attrs["_comm_default"] = getattr(Widget, "_comm_default", undefined) Widget._comm_default = lambda self: DummyComm() - _widget_attrs['_ipython_display_'] = Widget._ipython_display_ + _widget_attrs["_ipython_display_"] = Widget._ipython_display_ def raise_not_implemented(*args, **kwargs): raise NotImplementedError() + Widget._ipython_display_ = raise_not_implemented @@ -78,11 +80,13 @@ def record_display(*args): from distributed.client import wait from distributed.worker import dumps_task -from distributed.utils_test import (inc, dec, throws, gen_cluster, - gen_tls_cluster) +from distributed.utils_test import inc, dec, throws, gen_cluster, gen_tls_cluster from distributed.utils_test import client, loop, 
cluster_fixture # noqa: F401 -from distributed.diagnostics.progressbar import (ProgressWidget, - MultiProgressWidget, progress) +from distributed.diagnostics.progressbar import ( + ProgressWidget, + MultiProgressWidget, + progress, +) @gen_cluster(client=True) @@ -96,7 +100,7 @@ def test_progressbar_widget(c, s, a, b): yield progress.listen() assert progress.bar.value == 1.0 - assert '3 / 3' in progress.bar_text.value + assert "3 / 3" in progress.bar_text.value progress = ProgressWidget([z.key], scheduler=s.address) yield progress.listen() @@ -116,49 +120,61 @@ def test_multi_progressbar_widget(c, s, a, b): p = MultiProgressWidget([e.key], scheduler=s.address, complete=True) yield p.listen() - assert p.bars['inc'].value == 1.0 - assert p.bars['dec'].value == 1.0 - assert p.bars['throws'].value == 0.0 - assert '3 / 3' in p.bar_texts['inc'].value - assert '2 / 2' in p.bar_texts['dec'].value - assert '0 / 1' in p.bar_texts['throws'].value + assert p.bars["inc"].value == 1.0 + assert p.bars["dec"].value == 1.0 + assert p.bars["throws"].value == 0.0 + assert "3 / 3" in p.bar_texts["inc"].value + assert "2 / 2" in p.bar_texts["dec"].value + assert "0 / 1" in p.bar_texts["throws"].value - assert p.bars['inc'].bar_style == 'success' - assert p.bars['dec'].bar_style == 'success' - assert p.bars['throws'].bar_style == 'danger' + assert p.bars["inc"].bar_style == "success" + assert p.bars["dec"].bar_style == "success" + assert p.bars["throws"].bar_style == "danger" - assert p.status == 'error' - assert 'Exception' in p.elapsed_time.value + assert p.status == "error" + assert "Exception" in p.elapsed_time.value try: throws(1) except Exception as e: assert repr(e) in p.elapsed_time.value - capacities = [int(re.search(r'\d+ / \d+', row.children[0].value) - .group().split(' / ')[1]) - for row in p.bar_widgets.children] + capacities = [ + int(re.search(r"\d+ / \d+", row.children[0].value).group().split(" / ")[1]) + for row in p.bar_widgets.children + ] assert sorted(capacities, reverse=True) == capacities @gen_cluster() def test_multi_progressbar_widget_after_close(s, a, b): - s.update_graph(tasks=valmap(dumps_task, {'x-1': (inc, 1), - 'x-2': (inc, 'x-1'), - 'x-3': (inc, 'x-2'), - 'y-1': (dec, 'x-3'), - 'y-2': (dec, 'y-1'), - 'e': (throws, 'y-2'), - 'other': (inc, 123)}), - keys=['e'], - dependencies={'x-2': {'x-1'}, 'x-3': {'x-2'}, - 'y-1': {'x-3'}, 'y-2': {'y-1'}, - 'e': {'y-2'}}) - - p = MultiProgressWidget(['x-1', 'x-2', 'x-3'], scheduler=s.address) + s.update_graph( + tasks=valmap( + dumps_task, + { + "x-1": (inc, 1), + "x-2": (inc, "x-1"), + "x-3": (inc, "x-2"), + "y-1": (dec, "x-3"), + "y-2": (dec, "y-1"), + "e": (throws, "y-2"), + "other": (inc, 123), + }, + ), + keys=["e"], + dependencies={ + "x-2": {"x-1"}, + "x-3": {"x-2"}, + "y-1": {"x-3"}, + "y-2": {"y-1"}, + "e": {"y-2"}, + }, + ) + + p = MultiProgressWidget(["x-1", "x-2", "x-3"], scheduler=s.address) yield p.listen() - assert 'x' in p.bars + assert "x" in p.bars def test_values(client): @@ -166,16 +182,16 @@ def test_values(client): wait(L) p = MultiProgressWidget(L) client.sync(p.listen) - assert set(p.bars) == {'inc'} - assert p.status == 'finished' + assert set(p.bars) == {"inc"} + assert p.status == "finished" assert p.comm.closed() - assert '5 / 5' in p.bar_texts['inc'].value - assert p.bars['inc'].value == 1.0 + assert "5 / 5" in p.bar_texts["inc"].value + assert p.bars["inc"].value == 1.0 x = client.submit(throws, 1) p = MultiProgressWidget([x]) client.sync(p.listen) - assert p.status == 'error' + assert p.status == "error" def 
test_progressbar_done(client): @@ -183,20 +199,20 @@ def test_progressbar_done(client): wait(L) p = ProgressWidget(L) client.sync(p.listen) - assert p.status == 'finished' + assert p.status == "finished" assert p.bar.value == 1.0 - assert p.bar.bar_style == 'success' - assert 'Finished' in p.elapsed_time.value + assert p.bar.bar_style == "success" + assert "Finished" in p.elapsed_time.value f = client.submit(throws, L) wait([f]) p = ProgressWidget([f]) client.sync(p.listen) - assert p.status == 'error' + assert p.status == "error" assert p.bar.value == 0.0 - assert p.bar.bar_style == 'danger' - assert 'Exception' in p.elapsed_time.value + assert p.bar.bar_style == "danger" + assert "Exception" in p.elapsed_time.value try: throws(1) @@ -206,36 +222,48 @@ def test_progressbar_done(client): def test_progressbar_cancel(client): import time + L = [client.submit(lambda: time.sleep(0.3), i) for i in range(5)] p = ProgressWidget(L) client.sync(p.listen) L[-1].cancel() wait(L[:-1]) - assert p.status == 'error' + assert p.status == "error" assert p.bar.value == 0 # no tasks finish before cancel is called @gen_cluster() def test_multibar_complete(s, a, b): - s.update_graph(tasks=valmap(dumps_task, {'x-1': (inc, 1), - 'x-2': (inc, 'x-1'), - 'x-3': (inc, 'x-2'), - 'y-1': (dec, 'x-3'), - 'y-2': (dec, 'y-1'), - 'e': (throws, 'y-2'), - 'other': (inc, 123)}), - keys=['e'], - dependencies={'x-2': {'x-1'}, 'x-3': {'x-2'}, - 'y-1': {'x-3'}, 'y-2': {'y-1'}, - 'e': {'y-2'}}) - - p = MultiProgressWidget(['e'], scheduler=s.address, complete=True) + s.update_graph( + tasks=valmap( + dumps_task, + { + "x-1": (inc, 1), + "x-2": (inc, "x-1"), + "x-3": (inc, "x-2"), + "y-1": (dec, "x-3"), + "y-2": (dec, "y-1"), + "e": (throws, "y-2"), + "other": (inc, 123), + }, + ), + keys=["e"], + dependencies={ + "x-2": {"x-1"}, + "x-3": {"x-2"}, + "y-1": {"x-3"}, + "y-2": {"y-1"}, + "e": {"y-2"}, + }, + ) + + p = MultiProgressWidget(["e"], scheduler=s.address, complete=True) yield p.listen() - assert p._last_response['all'] == {'x': 3, 'y': 2, 'e': 1} - assert all(b.value == 1.0 for k, b in p.bars.items() if k != 'e') - assert '3 / 3' in p.bar_texts['x'].value - assert '2 / 2' in p.bar_texts['y'].value + assert p._last_response["all"] == {"x": 3, "y": 2, "e": 1} + assert all(b.value == 1.0 for k, b in p.bars.items() if k != "e") + assert "3 / 3" in p.bar_texts["x"].value + assert "2 / 2" in p.bar_texts["y"].value def test_fast(client): @@ -244,10 +272,10 @@ def test_fast(client): L3 = client.map(add, L, L2) p = progress(L3, multi=True, complete=True, notebook=True) client.sync(p.listen) - assert set(p._last_response['all']) == {'inc', 'dec', 'add'} + assert set(p._last_response["all"]) == {"inc", "dec", "add"} -@gen_cluster(client=True, client_kwargs={'serializers': ['msgpack']}) +@gen_cluster(client=True, client_kwargs={"serializers": ["msgpack"]}) def test_serializers(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, x) @@ -258,7 +286,7 @@ def test_serializers(c, s, a, b): yield progress.listen() assert progress.bar.value == 1.0 - assert '3 / 3' in progress.bar_text.value + assert "3 / 3" in progress.bar_text.value @gen_tls_cluster(client=True) @@ -272,4 +300,4 @@ def test_tls(c, s, a, b): yield progress.listen() assert progress.bar.value == 1.0 - assert '3 / 3' in progress.bar_text.value + assert "3 / 3" in progress.bar_text.value diff --git a/distributed/diskutils.py b/distributed/diskutils.py index ccc3096c038..395f7828505 100644 --- a/distributed/diskutils.py +++ b/distributed/diskutils.py @@ -16,11 +16,11 @@ logger = 
logging.getLogger(__name__) -DIR_LOCK_EXT = '.dirlock' +DIR_LOCK_EXT = ".dirlock" def is_locking_enabled(): - return dask.config.get('distributed.worker.use-file-locking') + return dask.config.get("distributed.worker.use-file-locking") def safe_unlink(path): @@ -58,24 +58,33 @@ def __init__(self, workspace, name=None, prefix=None): self._lock_file = locket.lock_file(self._lock_path) self._lock_file.acquire() except OSError as e: - logger.exception("Could not acquire workspace lock on " - "path: %s ." - "Continuing without lock. " - "This may result in workspaces not being " - "cleaned up", self._lock_path, - exc_info=True) + logger.exception( + "Could not acquire workspace lock on " + "path: %s ." + "Continuing without lock. " + "This may result in workspaces not being " + "cleaned up", + self._lock_path, + exc_info=True, + ) self._lock_file = None except Exception: shutil.rmtree(self.dir_path, ignore_errors=True) raise workspace._known_locks.add(self._lock_path) - self._finalizer = finalize(self, self._finalize, - workspace, self._lock_path, - self._lock_file, self.dir_path) + self._finalizer = finalize( + self, + self._finalize, + workspace, + self._lock_path, + self._lock_file, + self.dir_path, + ) else: - self._finalizer = finalize(self, self._finalize, - workspace, None, None, self.dir_path) + self._finalizer = finalize( + self, self._finalize, workspace, None, None, self.dir_path + ) def release(self): """ @@ -109,8 +118,8 @@ class WorkSpace(object): def __init__(self, base_dir): self.base_dir = os.path.abspath(base_dir) self._init_workspace() - self._global_lock_path = os.path.join(self.base_dir, 'global.lock') - self._purge_lock_path = os.path.join(self.base_dir, 'purge.lock') + self._global_lock_path = os.path.join(self.base_dir, "global.lock") + self._purge_lock_path = os.path.join(self.base_dir, "purge.lock") def _init_workspace(self): try: @@ -165,7 +174,7 @@ def _purge_leftovers(self): return purged def _list_unknown_locks(self): - for p in glob.glob(os.path.join(self.base_dir, '*' + DIR_LOCK_EXT)): + for p in glob.glob(os.path.join(self.base_dir, "*" + DIR_LOCK_EXT)): try: st = os.stat(p) except EnvironmentError: @@ -199,10 +208,9 @@ def _check_lock_or_purge(self, lock_path): return False try: # Lock file is stale, therefore purge corresponding directory - dir_path = lock_path[:-len(DIR_LOCK_EXT)] + dir_path = lock_path[: -len(DIR_LOCK_EXT)] if os.path.exists(dir_path): - logger.info("Found stale lock file and directory %r, purging", - dir_path) + logger.info("Found stale lock file and directory %r, purging", dir_path) self._purge_directory(dir_path) finally: lock.release() @@ -212,8 +220,7 @@ def _check_lock_or_purge(self, lock_path): def _on_remove_error(self, func, path, exc_info): typ, exc, tb = exc_info - logger.error("Failed to remove %r (failed in %r): %s", - path, func, str(exc)) + logger.error("Failed to remove %r (failed in %r): %s", path, func, str(exc)) def new_work_dir(self, **kwargs): """ @@ -231,6 +238,8 @@ def new_work_dir(self, **kwargs): try: self._purge_leftovers() except OSError: - logger.error("Failed to clean up lingering worker directories " - "in path: %s ", exc_info=True) + logger.error( + "Failed to clean up lingering worker directories " "in path: %s ", + exc_info=True, + ) return WorkDir(self, **kwargs) diff --git a/distributed/lock.py b/distributed/lock.py index 9f1c4390653..d12b1c41e15 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -29,10 +29,11 @@ def __init__(self, scheduler): self.events = defaultdict(deque) self.ids = dict() - 
self.scheduler.handlers.update({'lock_acquire': self.acquire, - 'lock_release': self.release}) + self.scheduler.handlers.update( + {"lock_acquire": self.acquire, "lock_release": self.release} + ) - self.scheduler.extensions['locks'] = self + self.scheduler.extensions["locks"] = self @gen.coroutine def acquire(self, stream=None, name=None, id=None, timeout=None): @@ -92,9 +93,10 @@ class Lock(object): >>> # do things with protected resource >>> lock.release() # doctest: +SKIP """ + def __init__(self, name=None, client=None): self.client = client or _get_global_client() or get_worker().client - self.name = name or 'lock-' + uuid.uuid4().hex + self.name = name or "lock-" + uuid.uuid4().hex self.id = uuid.uuid4().hex self._locked = False @@ -121,12 +123,15 @@ def acquire(self, blocking=True, timeout=None): """ if not blocking: if timeout is not None: - raise ValueError( - "can't specify a timeout for a non-blocking call") + raise ValueError("can't specify a timeout for a non-blocking call") timeout = 0 - result = self.client.sync(self.client.scheduler.lock_acquire, - name=self.name, id=self.id, timeout=timeout) + result = self.client.sync( + self.client.scheduler.lock_acquire, + name=self.name, + id=self.id, + timeout=timeout, + ) self._locked = True return result @@ -134,8 +139,9 @@ def release(self): """ Release the lock if already acquired """ if not self.locked(): raise ValueError("Lock is not yet acquired") - result = self.client.sync(self.client.scheduler.lock_release, - name=self.name, id=self.id) + result = self.client.sync( + self.client.scheduler.lock_release, name=self.name, id=self.id + ) self._locked = False return result diff --git a/distributed/locket.py b/distributed/locket.py index 84f4af2ca1f..1ed7b023085 100644 --- a/distributed/locket.py +++ b/distributed/locket.py @@ -19,16 +19,22 @@ import ctypes.wintypes import msvcrt except ImportError: - raise ImportError("Platform not supported (failed to import fcntl, ctypes, msvcrt)") + raise ImportError( + "Platform not supported (failed to import fcntl, ctypes, msvcrt)" + ) else: - _kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + _kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) _WinAPI_LockFile = _kernel32.LockFile _WinAPI_LockFile.restype = ctypes.wintypes.BOOL - _WinAPI_LockFile.argtypes = [ctypes.wintypes.HANDLE] + [ctypes.wintypes.DWORD] * 4 + _WinAPI_LockFile.argtypes = [ctypes.wintypes.HANDLE] + [ + ctypes.wintypes.DWORD + ] * 4 _WinAPI_UnlockFile = _kernel32.UnlockFile _WinAPI_UnlockFile.restype = ctypes.wintypes.BOOL - _WinAPI_UnlockFile.argtypes = [ctypes.wintypes.HANDLE] + [ctypes.wintypes.DWORD] * 4 + _WinAPI_UnlockFile.argtypes = [ctypes.wintypes.HANDLE] + [ + ctypes.wintypes.DWORD + ] * 4 _lock_file_blocking_available = False @@ -46,8 +52,10 @@ def _lock_file_non_blocking(file_): def _unlock_file(file_): _WinAPI_UnlockFile(msvcrt.get_osfhandle(file_.fileno()), 0, 0, 1, 0) + else: _lock_file_blocking_available = True + def _lock_file_blocking(file_): fcntl.flock(file_.fileno(), fcntl.LOCK_EX) @@ -100,8 +108,7 @@ def _acquire_non_blocking(acquire, timeout, retry_period, path): success = acquire() if success: return - elif (timeout is not None and - time.time() - start_time > timeout): + elif timeout is not None and time.time() - start_time > timeout: raise LockError("Couldn't lock {0}".format(path)) else: time.sleep(retry_period) @@ -179,6 +186,7 @@ class _Locker(object): A lock wrapper to always apply the given *timeout* and *retry_period* to acquire() calls. 
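    A minimal usage sketch via ``lock_file`` (the path and timeout below are
    illustrative; ``WorkSpace`` in diskutils.py follows the same pattern):

        lock = lock_file("/tmp/example.dirlock", timeout=0)  # wraps a _Locker
        lock.acquire()              # applies the stored timeout/retry_period
        try:
            ...                     # work while holding the file lock
        finally:
            lock.release()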
""" + def __init__(self, lock, timeout=None, retry_period=None): self._lock = lock self._timeout = timeout diff --git a/distributed/metrics.py b/distributed/metrics.py index fb047faec79..6c0bdb4dc7e 100755 --- a/distributed/metrics.py +++ b/distributed/metrics.py @@ -83,7 +83,7 @@ def resync(self): # A high-resolution wall clock timer measuring the seconds since Unix epoch -if sys.platform.startswith('win'): +if sys.platform.startswith("win"): time = _WindowsTime().time else: # Under modern Unices, time.time() should be good enough @@ -97,7 +97,7 @@ def _native_thread_time(): def _linux_thread_time(): # Use hardcoded CLOCK_THREAD_CPUTIME_ID on Python 3 <= 3.6 - if sys.platform != 'linux': + if sys.platform != "linux": raise OSError return timemod.clock_gettime(3) @@ -134,8 +134,7 @@ def _detect_thread_time(): Return a per-thread CPU timer function if possible, otherwise a per-process CPU timer function, or at worse a wall-clock timer. """ - for func in [_native_thread_time, _linux_thread_time, - _native_process_time]: + for func in [_native_thread_time, _linux_thread_time, _native_process_time]: try: func() return func diff --git a/distributed/nanny.py b/distributed/nanny.py index c1965759685..356ebc3168d 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -21,8 +21,13 @@ from .process import AsyncProcess from .proctitle import enable_proctitle_on_children from .security import Security -from .utils import (get_ip, mp_context, silence_logging, json_load_robust, - PeriodicCallback) +from .utils import ( + get_ip, + mp_context, + silence_logging, + json_load_robust, + PeriodicCallback, +) from .worker import _ncores, run, parse_memory_limit, Worker logger = logging.getLogger(__name__) @@ -34,22 +39,43 @@ class Nanny(ServerNode): The nanny spins up Worker processes, watches then, and kills or restarts them as necessary. 
""" + process = None status = None - def __init__(self, scheduler_ip=None, scheduler_port=None, - scheduler_file=None, worker_port=0, ncores=None, loop=None, - local_dir='dask-worker-space', services=None, name=None, - memory_limit='auto', reconnect=True, validate=False, quiet=False, - resources=None, silence_logs=None, death_timeout=None, preload=(), - preload_argv=[], security=None, contact_address=None, - listen_address=None, worker_class=None, env=None, **worker_kwargs): + def __init__( + self, + scheduler_ip=None, + scheduler_port=None, + scheduler_file=None, + worker_port=0, + ncores=None, + loop=None, + local_dir="dask-worker-space", + services=None, + name=None, + memory_limit="auto", + reconnect=True, + validate=False, + quiet=False, + resources=None, + silence_logs=None, + death_timeout=None, + preload=(), + preload_argv=[], + security=None, + contact_address=None, + listen_address=None, + worker_class=None, + env=None, + **worker_kwargs + ): if scheduler_file: cfg = json_load_robust(scheduler_file) - self.scheduler_addr = cfg['address'] - elif scheduler_ip is None and dask.config.get('scheduler-address'): - self.scheduler_addr = dask.config.get('scheduler-address') + self.scheduler_addr = cfg["address"] + elif scheduler_ip is None and dask.config.get("scheduler-address"): + self.scheduler_addr = dask.config.get("scheduler-address") elif scheduler_port is None: self.scheduler_addr = coerce_to_address(scheduler_ip) else: @@ -67,12 +93,14 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.worker_kwargs = worker_kwargs self.contact_address = contact_address - self.memory_terminate_fraction = dask.config.get('distributed.worker.memory.terminate') + self.memory_terminate_fraction = dask.config.get( + "distributed.worker.memory.terminate" + ) self.security = security or Security() assert isinstance(self.security, Security) - self.connection_args = self.security.get_connection_args('worker') - self.listen_args = self.security.get_listen_args('worker') + self.connection_args = self.security.get_connection_args("worker") + self.listen_args = self.security.get_listen_args("worker") self.local_dir = local_dir @@ -89,22 +117,25 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, silence_logging(level=silence_logs) self.silence_logs = silence_logs - handlers = {'instantiate': self.instantiate, - 'kill': self.kill, - 'restart': self.restart, - # cannot call it 'close' on the rpc side for naming conflict - 'terminate': self._close, - 'run': self.run} - - super(Nanny, self).__init__(handlers, io_loop=self.loop, - connection_args=self.connection_args) + handlers = { + "instantiate": self.instantiate, + "kill": self.kill, + "restart": self.restart, + # cannot call it 'close' on the rpc side for naming conflict + "terminate": self._close, + "run": self.run, + } + + super(Nanny, self).__init__( + handlers, io_loop=self.loop, connection_args=self.connection_args + ) if self.memory_limit: pc = PeriodicCallback(self.memory_monitor, 100, io_loop=self.loop) - self.periodic_callbacks['memory'] = pc + self.periodic_callbacks["memory"] = pc self._listen_address = listen_address - self.status = 'init' + self.status = "init" def __repr__(self): return "" % (self.worker_address, self.ncores) @@ -117,11 +148,18 @@ def _unregister(self, timeout=10): if worker_address is None: return - allowed_errors = (gen.TimeoutError, CommClosedError, EnvironmentError, RPCClosed) + allowed_errors = ( + gen.TimeoutError, + CommClosedError, + EnvironmentError, + RPCClosed, + ) try: - yield 
gen.with_timeout(timedelta(seconds=timeout), - self.scheduler.unregister(address=self.worker_address), - quiet_exceptions=allowed_errors) + yield gen.with_timeout( + timedelta(seconds=timeout), + self.scheduler.unregister(address=self.worker_address), + quiet_exceptions=allowed_errors, + ) except allowed_errors: pass @@ -140,25 +178,24 @@ def _start(self, addr_or_port=0): # XXX Factor this out if not addr_or_port: # Default address is the required one to reach the scheduler - self.listen(get_local_address_for(self.scheduler.address), - listen_args=self.listen_args) + self.listen( + get_local_address_for(self.scheduler.address), + listen_args=self.listen_args, + ) self.ip = get_address_host(self.address) elif isinstance(addr_or_port, int): # addr_or_port is an integer => assume TCP - self.ip = get_ip( - get_address_host(self.scheduler.address) - ) - self.listen((self.ip, addr_or_port), - listen_args=self.listen_args) + self.ip = get_ip(get_address_host(self.scheduler.address)) + self.listen((self.ip, addr_or_port), listen_args=self.listen_args) else: self.listen(addr_or_port, listen_args=self.listen_args) self.ip = get_address_host(self.address) - logger.info(' Start Nanny at: %r', self.address) + logger.info(" Start Nanny at: %r", self.address) response = yield self.instantiate() - if response == 'running': + if response == "running": assert self.worker_address - self.status = 'running' + self.status = "running" else: yield self._close() @@ -181,7 +218,7 @@ def kill(self, comm=None, timeout=2): """ self.auto_restart = False if self.process is None: - raise gen.Return('OK') + raise gen.Return("OK") deadline = self.loop.time() + timeout yield self.process.kill(timeout=0.8 * (deadline - self.loop.time())) @@ -197,26 +234,29 @@ def instantiate(self, comm=None): start_arg = self._listen_address else: host = self.listener.bound_address[0] - start_arg = self.listener.prefix + unparse_host_port(host, - self._given_worker_port) + start_arg = self.listener.prefix + unparse_host_port( + host, self._given_worker_port + ) if self.process is None: - worker_kwargs = dict(scheduler_ip=self.scheduler_addr, - ncores=self.ncores, - local_dir=self.local_dir, - services=self.services, - service_ports={'nanny': self.port}, - name=self.name, - memory_limit=self.memory_limit, - reconnect=self.reconnect, - resources=self.resources, - validate=self.validate, - silence_logs=self.silence_logs, - death_timeout=self.death_timeout, - preload=self.preload, - preload_argv=self.preload_argv, - security=self.security, - contact_address=self.contact_address) + worker_kwargs = dict( + scheduler_ip=self.scheduler_addr, + ncores=self.ncores, + local_dir=self.local_dir, + services=self.services, + service_ports={"nanny": self.port}, + name=self.name, + memory_limit=self.memory_limit, + reconnect=self.reconnect, + resources=self.resources, + validate=self.validate, + silence_logs=self.silence_logs, + death_timeout=self.death_timeout, + preload=self.preload, + preload_argv=self.preload_argv, + security=self.security, + contact_address=self.contact_address, + ) worker_kwargs.update(self.worker_kwargs) self.process = WorkerProcess( worker_args=tuple(), @@ -232,12 +272,11 @@ def instantiate(self, comm=None): if self.death_timeout: try: result = yield gen.with_timeout( - timedelta(seconds=self.death_timeout), - self.process.start() + timedelta(seconds=self.death_timeout), self.process.start() ) except gen.TimeoutError: yield self._close(timeout=self.death_timeout) - raise gen.Return('timed out') + raise gen.Return("timed out") else: 
result = yield self.process.start() raise gen.Return(result) @@ -256,13 +295,13 @@ def _(): yield gen.with_timeout(timedelta(seconds=timeout), _()) except gen.TimeoutError: logger.error("Restart timed out, returning before finished") - raise gen.Return('timed out') + raise gen.Return("timed out") else: - raise gen.Return('OK') + raise gen.Return("OK") def memory_monitor(self): """ Track worker's memory. Restart if it goes above terminate fraction """ - if self.status != 'running': + if self.status != "running": return process = self.process.process if process is None: @@ -274,19 +313,21 @@ def memory_monitor(self): memory = proc.memory_info().rss frac = memory / self.memory_limit if self.memory_terminate_fraction and frac > self.memory_terminate_fraction: - logger.warning("Worker exceeded %d%% memory budget. Restarting", - 100 * self.memory_terminate_fraction) + logger.warning( + "Worker exceeded %d%% memory budget. Restarting", + 100 * self.memory_terminate_fraction, + ) process.terminate() def is_alive(self): - return self.process is not None and self.process.status == 'running' + return self.process is not None and self.process.status == "running" def run(self, *args, **kwargs): return run(self, *args, **kwargs) @gen.coroutine def _on_exit(self, exitcode): - if self.status not in ('closing', 'closed'): + if self.status not in ("closing", "closed"): try: yield self.scheduler.unregister(address=self.worker_address) except (EnvironmentError, CommClosedError): @@ -295,13 +336,14 @@ def _on_exit(self, exitcode): return try: - if self.status not in ('closing', 'closed'): + if self.status not in ("closing", "closed"): if self.auto_restart: logger.warning("Restarting worker") yield self.instantiate() except Exception: - logger.error("Failed to restart worker after its process exited", - exc_info=True) + logger.error( + "Failed to restart worker after its process exited", exc_info=True + ) @property def pid(self): @@ -312,9 +354,9 @@ def _close(self, comm=None, timeout=5, report=None): """ Close the worker process, stop all comms. """ - if self.status in ('closing', 'closed'): - raise gen.Return('OK') - self.status = 'closing' + if self.status in ("closing", "closed"): + raise gen.Return("OK") + self.status = "closing" logger.info("Closing Nanny at %r", self.address) self.stop() try: @@ -325,15 +367,22 @@ def _close(self, comm=None, timeout=5, report=None): self.process = None self.rpc.close() self.scheduler.close_rpc() - self.status = 'closed' - raise gen.Return('OK') + self.status = "closed" + raise gen.Return("OK") class WorkerProcess(object): - - def __init__(self, worker_args, worker_kwargs, worker_start_args, - silence_logs, on_exit, worker, env): - self.status = 'init' + def __init__( + self, + worker_args, + worker_kwargs, + worker_start_args, + silence_logs, + on_exit, + worker, + env, + ): + self.status = "init" self.silence_logs = silence_logs self.worker_args = worker_args self.worker_kwargs = worker_kwargs @@ -353,9 +402,9 @@ def start(self): Ensure the worker process is started. 
""" enable_proctitle_on_children() - if self.status == 'running': + if self.status == "running": raise gen.Return(self.status) - if self.status == 'starting': + if self.status == "starting": yield self.running.wait() raise gen.Return(self.status) @@ -365,29 +414,31 @@ def start(self): self.process = AsyncProcess( target=self._run, - kwargs=dict(worker_args=self.worker_args, - worker_kwargs=self.worker_kwargs, - worker_start_args=self.worker_start_args, - silence_logs=self.silence_logs, - init_result_q=self.init_result_q, - child_stop_q=self.child_stop_q, - uid=uid, - Worker=self.Worker, - env=self.env), + kwargs=dict( + worker_args=self.worker_args, + worker_kwargs=self.worker_kwargs, + worker_start_args=self.worker_start_args, + silence_logs=self.silence_logs, + init_result_q=self.init_result_q, + child_stop_q=self.child_stop_q, + uid=uid, + Worker=self.Worker, + env=self.env, + ), ) self.process.daemon = True self.process.set_exit_callback(self._on_exit) self.running = Event() self.stopped = Event() - self.status = 'starting' + self.status = "starting" yield self.process.start() msg = yield self._wait_until_connected(uid) if not msg: raise gen.Return(self.status) - self.worker_address = msg['address'] - self.worker_dir = msg['dir'] + self.worker_address = msg["address"] + self.worker_dir = msg["dir"] assert self.worker_address - self.status = 'running' + self.status = "running" self.running.set() init_q.close() @@ -405,27 +456,25 @@ def _death_message(self, pid, exitcode): if exitcode == 255: return "Worker process %d was killed by unknown signal" % (pid,) elif exitcode >= 0: - return "Worker process %d exited with status %d" % (pid, exitcode,) + return "Worker process %d exited with status %d" % (pid, exitcode) else: - return "Worker process %d was killed by signal %d" % (pid, -exitcode,) + return "Worker process %d was killed by signal %d" % (pid, -exitcode) def is_alive(self): return self.process is not None and self.process.is_alive() @property def pid(self): - return (self.process.pid - if self.process and self.process.is_alive() - else None) + return self.process.pid if self.process and self.process.is_alive() else None def mark_stopped(self): - if self.status != 'stopped': + if self.status != "stopped": r = self.process.exitcode assert r is not None if r != 0: msg = self._death_message(self.process.pid, r) logger.warning(msg) - self.status = 'stopped' + self.status = "stopped" self.stopped.set() # Release resources self.process.close() @@ -449,28 +498,31 @@ def kill(self, timeout=2, executor_wait=True): loop = IOLoop.current() deadline = loop.time() + timeout - if self.status == 'stopped': + if self.status == "stopped": return - if self.status == 'stopping': + if self.status == "stopping": yield self.stopped.wait() return - assert self.status in ('starting', 'running') - self.status = 'stopping' + assert self.status in ("starting", "running") + self.status = "stopping" process = self.process - self.child_stop_q.put({ - 'op': 'stop', - 'timeout': max(0, deadline - loop.time()) * 0.8, - 'executor_wait': executor_wait, - }) + self.child_stop_q.put( + { + "op": "stop", + "timeout": max(0, deadline - loop.time()) * 0.8, + "executor_wait": executor_wait, + } + ) self.child_stop_q.close() while process.is_alive() and loop.time() < deadline: yield gen.sleep(0.05) if process.is_alive(): - logger.warning("Worker process still alive after %d seconds, killing", - timeout) + logger.warning( + "Worker process still alive after %d seconds, killing", timeout + ) try: yield process.terminate() 
except Exception as e: @@ -480,7 +532,7 @@ def kill(self, timeout=2, executor_wait=True): def _wait_until_connected(self, uid): delay = 0.05 while True: - if self.status != 'starting': + if self.status != "starting": return try: msg = self.init_result_q.get_nowait() @@ -488,24 +540,35 @@ def _wait_until_connected(self, uid): yield gen.sleep(delay) continue - if msg['uid'] != uid: # ensure that we didn't cross queues + if msg["uid"] != uid: # ensure that we didn't cross queues continue - if 'exception' in msg: - logger.error("Failed while trying to start worker process: %s", - msg['exception']) + if "exception" in msg: + logger.error( + "Failed while trying to start worker process: %s", msg["exception"] + ) yield self.process.join() raise msg else: raise gen.Return(msg) @classmethod - def _run(cls, worker_args, worker_kwargs, worker_start_args, - silence_logs, init_result_q, child_stop_q, uid, env, Worker): # pragma: no cover + def _run( + cls, + worker_args, + worker_kwargs, + worker_start_args, + silence_logs, + init_result_q, + child_stop_q, + uid, + env, + Worker, + ): # pragma: no cover os.environ.update(env) try: from dask.multiprocessing import initialize_worker_process - except ImportError: # old Dask version + except ImportError: # old Dask version pass else: initialize_worker_process() @@ -521,10 +584,12 @@ def _run(cls, worker_args, worker_kwargs, worker_start_args, @gen.coroutine def do_stop(timeout=5, executor_wait=True): try: - yield worker._close(report=False, - nanny=False, - executor_wait=executor_wait, - timeout=timeout) + yield worker._close( + report=False, + nanny=False, + executor_wait=executor_wait, + timeout=timeout, + ) finally: loop.stop() @@ -540,7 +605,7 @@ def watch_stop_q(): pass else: child_stop_q.close() - assert msg.pop('op') == 'stop' + assert msg.pop("op") == "stop" loop.add_callback(do_stop, **msg) break @@ -557,13 +622,13 @@ def run(): yield worker._start(*worker_start_args) except Exception as e: logger.exception("Failed to start worker") - init_result_q.put({'uid': uid, 'exception': e}) + init_result_q.put({"uid": uid, "exception": e}) init_result_q.close() else: assert worker.address - init_result_q.put({'address': worker.address, - 'dir': worker.local_dir, - 'uid': uid}) + init_result_q.put( + {"address": worker.address, "dir": worker.local_dir, "uid": uid} + ) init_result_q.close() yield worker.wait_until_closed() logger.info("Worker closed") diff --git a/distributed/node.py b/distributed/node.py index 4123617620b..8a0b8c12195 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -11,41 +11,65 @@ class Node(object): Base class for nodes in a distributed cluster. """ - def __init__(self, connection_limit=512, deserialize=True, - connection_args=None, io_loop=None, - serializers=None, deserializers=None): + def __init__( + self, + connection_limit=512, + deserialize=True, + connection_args=None, + io_loop=None, + serializers=None, + deserializers=None, + ): self.io_loop = io_loop or IOLoop.current() - self.rpc = ConnectionPool(limit=connection_limit, - deserialize=deserialize, - serializers=serializers, - deserializers=deserializers, - connection_args=connection_args) + self.rpc = ConnectionPool( + limit=connection_limit, + deserialize=deserialize, + serializers=serializers, + deserializers=deserializers, + connection_args=connection_args, + ) class ServerNode(Node, Server): """ Base class for server nodes in a distributed cluster. """ + # TODO factor out security, listening, services, etc. here # XXX avoid inheriting from Server? 
there is some large potential for confusion # between base and derived attribute namespaces... - def __init__(self, handlers=None, blocked_handlers=None, stream_handlers=None, - connection_limit=512, deserialize=True, - connection_args=None, io_loop=None, serializers=None, - deserializers=None): - Node.__init__(self, deserialize=deserialize, - connection_limit=connection_limit, - connection_args=connection_args, - io_loop=io_loop, - serializers=serializers, - deserializers=deserializers) - Server.__init__(self, handlers=handlers, - blocked_handlers=blocked_handlers, - stream_handlers=stream_handlers, - connection_limit=connection_limit, - deserialize=deserialize, io_loop=self.io_loop) + def __init__( + self, + handlers=None, + blocked_handlers=None, + stream_handlers=None, + connection_limit=512, + deserialize=True, + connection_args=None, + io_loop=None, + serializers=None, + deserializers=None, + ): + Node.__init__( + self, + deserialize=deserialize, + connection_limit=connection_limit, + connection_args=connection_args, + io_loop=io_loop, + serializers=serializers, + deserializers=deserializers, + ) + Server.__init__( + self, + handlers=handlers, + blocked_handlers=blocked_handlers, + stream_handlers=stream_handlers, + connection_limit=connection_limit, + deserialize=deserialize, + io_loop=self.io_loop, + ) def versions(self, comm=None, packages=None): return get_versions(packages=packages) diff --git a/distributed/preloading.py b/distributed/preloading.py index 00fa4eaeae2..0f08f60f71c 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -25,25 +25,28 @@ def validate_preload_argv(ctx, param, value): for a in unexpected_args: raise click.NoSuchOption(a) raise click.UsageError( - "Got unexpected extra argument%s: (%s)" % - ("s" if len(value) > 1 else "", " ".join(value)) + "Got unexpected extra argument%s: (%s)" + % ("s" if len(value) > 1 else "", " ".join(value)) ) preload_modules = _import_modules(ctx.params.get("preload")) preload_commands = [ - m["dask_setup"] for m in preload_modules.values() + m["dask_setup"] + for m in preload_modules.values() if isinstance(m["dask_setup"], click.Command) ] if len(preload_commands) > 1: raise click.UsageError( - "Multiple --preload modules with click-configurable setup: %s" % - list(preload_modules.keys())) + "Multiple --preload modules with click-configurable setup: %s" + % list(preload_modules.keys()) + ) if value and not preload_commands: raise click.UsageError( - "Unknown argument specified: %r Was click-configurable --preload target provided?") + "Unknown argument specified: %r Was click-configurable --preload target provided?" 
+ ) if not preload_commands: return value else: @@ -98,7 +101,7 @@ def _import_modules(names, file_dir=None): module = sys.modules[name] result_modules[name] = { - attrname : getattr(module, attrname, None) + attrname: getattr(module, attrname, None) for attrname in ("dask_setup", "dask_teardown") } @@ -128,7 +131,9 @@ def preload_modules(names, parameter=None, file_dir=None, argv=None): if dask_setup: if isinstance(dask_setup, click.Command): - context = dask_setup.make_context("dask_setup", list(argv), allow_extra_args=False) + context = dask_setup.make_context( + "dask_setup", list(argv), allow_extra_args=False + ) dask_setup.callback(parameter, *context.args, **context.params) else: dask_setup(parameter) diff --git a/distributed/process.py b/distributed/process.py index 38e3af62c3b..5dd9368fdc1 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -55,8 +55,7 @@ class AsyncProcess(object): def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}): if not callable(target): - raise TypeError("`target` needs to be callable, not %r" - % (type(target),)) + raise TypeError("`target` needs to be callable, not %r" % (type(target),)) self._state = _ProcessState() self._loop = loop or IOLoop.current(instance=False) @@ -71,10 +70,11 @@ def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}): # for the assignment here. parent_alive_pipe, self._keep_child_alive = mp_context.Pipe(duplex=False) - self._process = mp_context.Process(target=self._run, name=name, - args=(target, args, kwargs, - parent_alive_pipe, - self._keep_child_alive)) + self._process = mp_context.Process( + target=self._run, + name=name, + args=(target, args, kwargs, parent_alive_pipe, self._keep_child_alive), + ) _dangling.add(self._process) self._name = self._process.name self._watch_q = PyQueue() @@ -95,13 +95,20 @@ def _start_threads(self): self._watch_message_thread = threading.Thread( target=self._watch_message_queue, name="AsyncProcess %s watch message queue" % self.name, - args=(weakref.ref(self), self._process, self._loop, - self._state, self._watch_q, self._exit_future,)) + args=( + weakref.ref(self), + self._process, + self._loop, + self._state, + self._watch_q, + self._exit_future, + ), + ) self._watch_message_thread.daemon = True self._watch_message_thread.start() def stop_thread(q): - q.put_nowait({'op': 'stop'}) + q.put_nowait({"op": "stop"}) # We don't join the thread here as a finalizer can be called # asynchronously from anywhere @@ -120,6 +127,7 @@ def _immediate_exit_when_closed(cls, parent_alive_pipe): """ Immediately exit the process when parent_alive_pipe is closed. 
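    That is: the parent keeps its end of the pipe open for its whole lifetime,
    so if the parent dies without a clean shutdown the child observes the pipe
    closing and exits immediately rather than lingering as an orphan (see
    monitor_parent below).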
""" + def monitor_parent(): try: # The parent_alive_pipe should be held open as long as the @@ -186,7 +194,8 @@ def _start(): thread = threading.Thread( target=AsyncProcess._watch_process, name="AsyncProcess %s watch process join" % name, - args=(selfref, process, state, q)) + args=(selfref, process, state, q), + ) thread.daemon = True thread.start() @@ -197,12 +206,12 @@ def _start(): while True: msg = q.get() logger.debug("[%s] got message %r" % (r, msg)) - op = msg['op'] - if op == 'start': - _call_and_set_future(loop, msg['future'], _start) - elif op == 'terminate': - _call_and_set_future(loop, msg['future'], process.terminate) - elif op == 'stop': + op = msg["op"] + if op == "start": + _call_and_set_future(loop, msg["future"], _start) + elif op == "terminate": + _call_and_set_future(loop, msg["future"], process.terminate) + elif op == "stop": break else: assert 0, msg @@ -213,8 +222,7 @@ def _watch_process(cls, selfref, process, state, q): process.join() exitcode = process.exitcode assert exitcode is not None - logger.debug("[%s] process %r exited with code %r", - r, state.pid, exitcode) + logger.debug("[%s] process %r exited with code %r", r, state.pid, exitcode) state.is_alive = False state.exitcode = exitcode # Make sure the process is removed from the global list @@ -235,7 +243,7 @@ def start(self): """ self._check_closed() fut = Future() - self._watch_q.put_nowait({'op': 'start', 'future': fut}) + self._watch_q.put_nowait({"op": "start", "future": fut}) return fut def terminate(self): @@ -246,7 +254,7 @@ def terminate(self): """ self._check_closed() fut = Future() - self._watch_q.put_nowait({'op': 'terminate', 'future': fut}) + self._watch_q.put_nowait({"op": "terminate", "future": fut}) return fut @gen.coroutine @@ -257,7 +265,7 @@ def join(self, timeout=None): This method is a coroutine. """ self._check_closed() - assert self._state.pid is not None, 'can only join a started process' + assert self._state.pid is not None, "can only join a started process" if self._state.exitcode is not None: return if timeout is None: @@ -287,7 +295,9 @@ def set_exit_callback(self, func): """ # XXX should this be a property instead? assert callable(func), "exit callback should be callable" - assert self._state.pid is None, "cannot set exit callback when process already started" + assert ( + self._state.pid is None + ), "cannot set exit callback when process already started" self._exit_callback = func def is_alive(self): diff --git a/distributed/proctitle.py b/distributed/proctitle.py index bdaf8bed5d6..50c9859e17e 100644 --- a/distributed/proctitle.py +++ b/distributed/proctitle.py @@ -16,7 +16,7 @@ def enable_proctitle_on_children(): Enable setting the process title on this process' children and grandchildren. 
""" - os.environ['DASK_PARENT'] = str(os.getpid()) + os.environ["DASK_PARENT"] = str(os.getpid()) def enable_proctitle_on_current(): @@ -37,7 +37,7 @@ def setproctitle(title): enabled = _enabled if not enabled: try: - enabled = int(os.environ.get('DASK_PARENT', '')) != os.getpid() + enabled = int(os.environ.get("DASK_PARENT", "")) != os.getpid() except ValueError: pass if enabled: diff --git a/distributed/profile.py b/distributed/profile.py index 54e62c288c4..385c7449e75 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -46,30 +46,34 @@ def identifier(frame): Strings are cheaper to use as indexes into dicts than tuples or dicts """ if frame is None: - return 'None' + return "None" else: - return ';'.join((frame.f_code.co_name, - frame.f_code.co_filename, - str(frame.f_code.co_firstlineno))) + return ";".join( + ( + frame.f_code.co_name, + frame.f_code.co_filename, + str(frame.f_code.co_firstlineno), + ) + ) def repr_frame(frame): """ Render a frame as a line for inclusion into a text traceback """ co = frame.f_code - text = ' File "%s", line %s, in %s' % (co.co_filename, - frame.f_lineno, - co.co_name) + text = ' File "%s", line %s, in %s' % (co.co_filename, frame.f_lineno, co.co_name) line = linecache.getline(co.co_filename, frame.f_lineno, frame.f_globals).lstrip() - return text + '\n\t' + line + return text + "\n\t" + line def info_frame(frame): co = frame.f_code line = linecache.getline(co.co_filename, frame.f_lineno, frame.f_globals).lstrip() - return {'filename': co.co_filename, - 'name': co.co_name, - 'line_number': frame.f_lineno, - 'line': line} + return { + "filename": co.co_filename, + "name": co.co_name, + "line_number": frame.f_lineno, + "line": line, + } def process(frame, child, state, stop=None, omit=None): @@ -96,7 +100,9 @@ def process(frame, child, state, stop=None, omit=None): return False prev = frame.f_back - if prev is not None and (stop is None or not prev.f_code.co_filename.endswith(stop)): + if prev is not None and ( + stop is None or not prev.f_code.co_filename.endswith(stop) + ): state = process(prev, frame, state, stop=stop) if state is False: return False @@ -104,45 +110,53 @@ def process(frame, child, state, stop=None, omit=None): ident = identifier(frame) try: - d = state['children'][ident] + d = state["children"][ident] except KeyError: - d = {'count': 0, - 'description': info_frame(frame), - 'children': {}, - 'identifier': ident} - state['children'][ident] = d + d = { + "count": 0, + "description": info_frame(frame), + "children": {}, + "identifier": ident, + } + state["children"][ident] = d - state['count'] += 1 + state["count"] += 1 if child is not None: return d else: - d['count'] += 1 + d["count"] += 1 def merge(*args): """ Merge multiple frame states together """ if not args: return create() - s = {arg['identifier'] for arg in args} + s = {arg["identifier"] for arg in args} if len(s) != 1: raise ValueError("Expected identifiers, got %s" % str(s)) children = defaultdict(list) for arg in args: - for child in arg['children']: - children[child].append(arg['children'][child]) + for child in arg["children"]: + children[child].append(arg["children"][child]) children = {k: merge(*v) for k, v in children.items()} - count = sum(arg['count'] for arg in args) - return {'description': args[0]['description'], - 'children': dict(children), - 'count': count, - 'identifier': args[0]['identifier']} + count = sum(arg["count"] for arg in args) + return { + "description": args[0]["description"], + "children": dict(children), + "count": count, + 
"identifier": args[0]["identifier"], + } def create(): - return {'count': 0, 'children': {}, 'identifier': 'root', 'description': - {'filename': '', 'name': '', 'line_number': 0, 'line': ''}} + return { + "count": 0, + "children": {}, + "identifier": "root", + "description": {"filename": "", "name": "", "line_number": 0, "line": ""}, + } def call_stack(frame): @@ -180,7 +194,7 @@ def plot_data(state, profile_interval=0.010): names = [] def traverse(state, start, stop, height): - if not state['count']: + if not state["count"]: return starts.append(start) stops.append(stop) @@ -188,49 +202,50 @@ def traverse(state, start, stop, height): width = stop - start widths.append(width) states.append(state) - times.append(format_time(state['count'] * profile_interval)) + times.append(format_time(state["count"] * profile_interval)) - desc = state['description'] - filenames.append(desc['filename']) - lines.append(desc['line']) - line_numbers.append(desc['line_number']) - names.append(desc['name']) + desc = state["description"] + filenames.append(desc["filename"]) + lines.append(desc["line"]) + line_numbers.append(desc["line_number"]) + names.append(desc["name"]) - ident = state['identifier'] + ident = state["identifier"] try: - colors.append(color_of(desc['filename'])) + colors.append(color_of(desc["filename"])) except IndexError: - colors.append('gray') + colors.append("gray") - delta = (stop - start) / state['count'] + delta = (stop - start) / state["count"] x = start - for name, child in state['children'].items(): - width = child['count'] * delta + for name, child in state["children"].items(): + width = child["count"] * delta traverse(child, x, x + width, height + 1) x += width traverse(state, 0, 1, 0) percentages = ["{:.2f}%".format(100 * w) for w in widths] - return {'left': starts, - 'right': stops, - 'bottom': heights, - 'width': widths, - 'top': [x + 1 for x in heights], - 'color': colors, - 'states': states, - 'filename': filenames, - 'line': lines, - 'line_number': line_numbers, - 'name': names, - 'time': times, - 'percentage': percentages} - - -def _watch(thread_id, log, interval='20ms', cycle='2s', omit=None, - stop=lambda: False): + return { + "left": starts, + "right": stops, + "bottom": heights, + "width": widths, + "top": [x + 1 for x in heights], + "color": colors, + "states": states, + "filename": filenames, + "line": lines, + "line_number": line_numbers, + "name": names, + "time": times, + "percentage": percentages, + } + + +def _watch(thread_id, log, interval="20ms", cycle="2s", omit=None, stop=lambda: False): interval = parse_timedelta(interval) cycle = parse_timedelta(cycle) @@ -251,21 +266,31 @@ def _watch(thread_id, log, interval='20ms', cycle='2s', omit=None, sleep(interval) -def watch(thread_id=None, interval='20ms', cycle='2s', maxlen=1000, omit=None, - stop=lambda: False): +def watch( + thread_id=None, + interval="20ms", + cycle="2s", + maxlen=1000, + omit=None, + stop=lambda: False, +): if thread_id is None: thread_id = get_thread_identity() log = deque(maxlen=maxlen) - thread = threading.Thread(target=_watch, - name='Profile', - kwargs={'thread_id': thread_id, - 'interval': interval, - 'cycle': cycle, - 'log': log, - 'omit': omit, - 'stop': stop}) + thread = threading.Thread( + target=_watch, + name="Profile", + kwargs={ + "thread_id": thread_id, + "interval": interval, + "cycle": cycle, + "log": log, + "omit": omit, + "stop": stop, + }, + ) thread.daemon = True thread.start() @@ -307,14 +332,22 @@ def plot_figure(data, **kwargs): from bokeh.plotting import 
ColumnDataSource, figure from bokeh.models import HoverTool - if 'states' in data: - data = toolz.dissoc(data, 'states') + if "states" in data: + data = toolz.dissoc(data, "states") source = ColumnDataSource(data=data) - fig = figure(tools='tap', **kwargs) - r = fig.quad('left', 'right', 'top', 'bottom', color='color', - line_color='black', line_width=2, source=source) + fig = figure(tools="tap", **kwargs) + r = fig.quad( + "left", + "right", + "top", + "bottom", + color="color", + line_color="black", + line_width=2, + source=source, + ) r.selection_glyph = None r.nonselection_glyph = None @@ -346,7 +379,7 @@ def plot_figure(data, **kwargs): Percentage:  @width - """ + """, ) fig.add_tools(hover) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index bd8f7331c8e..cf1a3df8994 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -3,12 +3,22 @@ from functools import partial from .compression import compressions, default_compression -from .core import (dumps, loads, maybe_compress, decompress, msgpack) +from .core import dumps, loads, maybe_compress, decompress, msgpack from .serialize import ( - serialize, deserialize, nested_deserialize, Serialize, Serialized, - to_serialize, register_serialization, dask_serialize, dask_deserialize, - serialize_bytes, deserialize_bytes, serialize_bytelist, - register_serialization_family, register_generic, + serialize, + deserialize, + nested_deserialize, + Serialize, + Serialized, + to_serialize, + register_serialization, + dask_serialize, + dask_deserialize, + serialize_bytes, + deserialize_bytes, + serialize_bytelist, + register_serialization_family, + register_generic, ) from ..utils import ignoring @@ -54,6 +64,7 @@ def _register_arrow(): @dask_deserialize.register_lazy("sklearn") def _register_sklearn(): import sklearn.base + register_generic(sklearn.base.BaseEstimator) diff --git a/distributed/protocol/arrow.py b/distributed/protocol/arrow.py index 012a91e6afc..cac146a575c 100644 --- a/distributed/protocol/arrow.py +++ b/distributed/protocol/arrow.py @@ -3,9 +3,12 @@ from .serialize import dask_serialize, dask_deserialize import pyarrow -if pyarrow.__version__ < '0.10': - raise ImportError("Need pyarrow >= 0.10 . " - "See https://arrow.apache.org/docs/python/install.html") + +if pyarrow.__version__ < "0.10": + raise ImportError( + "Need pyarrow >= 0.10 . 
" + "See https://arrow.apache.org/docs/python/install.html" + ) @dask_serialize.register(pyarrow.RecordBatch) diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 2a6de6bfca1..f729748acc8 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -13,8 +13,9 @@ try: import blosc + n = blosc.set_nthreads(2) - if hasattr('blosc', 'releasegil'): + if hasattr("blosc", "releasegil"): blosc.set_releasegil(True) except ImportError: blosc = False @@ -22,8 +23,7 @@ from ..utils import ignoring, ensure_bytes -compressions = {None: {'compress': identity, - 'decompress': identity}} +compressions = {None: {"compress": identity, "decompress": identity}} compressions[False] = compressions[None] # alias @@ -36,8 +36,8 @@ with ignoring(ImportError): import zlib - compressions['zlib'] = {'compress': zlib.compress, - 'decompress': zlib.decompress} + + compressions["zlib"] = {"compress": zlib.compress, "decompress": zlib.decompress} with ignoring(ImportError): import snappy @@ -48,9 +48,11 @@ def _fixed_snappy_decompress(data): data = bytes(data) return snappy.decompress(data) - compressions['snappy'] = {'compress': snappy.compress, - 'decompress': _fixed_snappy_decompress} - default_compression = 'snappy' + compressions["snappy"] = { + "compress": snappy.compress, + "decompress": _fixed_snappy_decompress, + } + default_compression = "snappy" with ignoring(ImportError): import lz4 @@ -58,6 +60,7 @@ def _fixed_snappy_decompress(data): try: # try using the new lz4 API import lz4.block + lz4_compress = lz4.block.compress lz4_decompress = lz4.block.decompress except ImportError: @@ -86,25 +89,31 @@ def _fixed_lz4_decompress(data): else: raise - compressions['lz4'] = {'compress': _fixed_lz4_compress, - 'decompress': _fixed_lz4_decompress} - default_compression = 'lz4' + compressions["lz4"] = { + "compress": _fixed_lz4_compress, + "decompress": _fixed_lz4_decompress, + } + default_compression = "lz4" with ignoring(ImportError): import blosc - compressions['blosc'] = {'compress': partial(blosc.compress, clevel=5, - cname='lz4'), - 'decompress': blosc.decompress} + + compressions["blosc"] = { + "compress": partial(blosc.compress, clevel=5, cname="lz4"), + "decompress": blosc.decompress, + } -default = dask.config.get('distributed.comm.compression') -if default != 'auto': +default = dask.config.get("distributed.comm.compression") +if default != "auto": if default in compressions: default_compression = default else: - raise ValueError("Default compression '%s' not found.\n" - "Choices include auto, %s" % ( - default, ', '.join(sorted(map(str, compressions))))) + raise ValueError( + "Default compression '%s' not found.\n" + "Choices include auto, %s" + % (default, ", ".join(sorted(map(str, compressions)))) + ) def byte_sample(b, size, n): @@ -125,7 +134,7 @@ def byte_sample(b, size, n): ends.append(starts[-1] + size) parts = [b[start:end] for start, end in zip(starts, ends)] - return b''.join(map(ensure_bytes, parts)) + return b"".join(map(ensure_bytes, parts)) def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5): @@ -139,21 +148,21 @@ def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5): return the original 4. 
We return the compressed result """ - compression = dask.config.get('distributed.comm.compression') - if compression == 'auto': + compression = dask.config.get("distributed.comm.compression") + if compression == "auto": compression = default_compression if not compression: return None, payload if len(payload) < min_size: return None, payload - if len(payload) > 2**31: # Too large, compression libraries often fail + if len(payload) > 2 ** 31: # Too large, compression libraries often fail return None, payload min_size = int(min_size) sample_size = int(sample_size) - compress = compressions[compression]['compress'] + compress = compressions[compression]["compress"] # Compress a sample, return original if not very compressed sample = byte_sample(payload, sample_size, nsamples) @@ -167,9 +176,10 @@ def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5): if default_compression and blosc and type(payload) is memoryview: # Blosc does itemsize-aware shuffling, resulting in better compression - compressed = blosc.compress(payload, typesize=payload.itemsize, - cname='lz4', clevel=5) - compression = 'blosc' + compressed = blosc.compress( + payload, typesize=payload.itemsize, cname="lz4", clevel=5 + ) + compression = "blosc" else: compressed = compress(ensure_bytes(payload)) @@ -181,5 +191,7 @@ def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5): def decompress(header, frames): """ Decompress frames according to information in the header """ - return [compressions[c]['decompress'](frame) - for c, frame in zip(header['compression'], frames)] + return [ + compressions[c]["decompress"](frame) + for c, frame in zip(header["compression"], frames) + ] diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 8d9e1b2b127..0b5f7eb0fea 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -11,8 +11,14 @@ from toolz import reduce from .compression import compressions, maybe_compress, decompress -from .serialize import (serialize, deserialize, Serialize, Serialized, - extract_serialize, msgpack_len_opts) +from .serialize import ( + serialize, + deserialize, + Serialize, + Serialized, + extract_serialize, + msgpack_len_opts, +) from .utils import frame_split_size, merge_frames from ..utils import nbytes @@ -20,18 +26,18 @@ try: - msgpack.loads(msgpack.dumps(''), raw=False, **msgpack_len_opts) - msgpack_opts = {'raw': False} + msgpack.loads(msgpack.dumps(""), raw=False, **msgpack_len_opts) + msgpack_opts = {"raw": False} msgpack_opts.update(msgpack_len_opts) except TypeError: # Backward compat with old msgpack (prior to 0.5.2) - msgpack_opts = {'encoding': 'utf-8'} + msgpack_opts = {"encoding": "utf-8"} logger = logging.getLogger(__name__) -def dumps(msg, serializers=None, on_error='message', context=None): +def dumps(msg, serializers=None, on_error="message", context=None): """ Transform Python message to bytestream suitable for communication """ try: data = {} @@ -43,56 +49,60 @@ def dumps(msg, serializers=None, on_error='message', context=None): if not data: # fast path without serialized data return small_header, small_payload - pre = {key: (value.header, value.frames) - for key, value in data.items() - if type(value) is Serialized} + pre = { + key: (value.header, value.frames) + for key, value in data.items() + if type(value) is Serialized + } - data = {key: serialize(value.data, - serializers=serializers, - on_error=on_error, - context=context) - for key, value in data.items() - if type(value) is Serialize} + data = { + key: 
serialize( + value.data, serializers=serializers, on_error=on_error, context=context + ) + for key, value in data.items() + if type(value) is Serialize + } - header = {'headers': {}, - 'keys': [], - 'bytestrings': list(bytestrings)} + header = {"headers": {}, "keys": [], "bytestrings": list(bytestrings)} out_frames = [] for key, (head, frames) in data.items(): - if 'lengths' not in head: - head['lengths'] = tuple(map(nbytes, frames)) - if 'compression' not in head: + if "lengths" not in head: + head["lengths"] = tuple(map(nbytes, frames)) + if "compression" not in head: frames = frame_split_size(frames) if frames: compression, frames = zip(*map(maybe_compress, frames)) else: compression = [] - head['compression'] = compression - head['count'] = len(frames) - header['headers'][key] = head - header['keys'].append(key) + head["compression"] = compression + head["count"] = len(frames) + header["headers"][key] = head + header["keys"].append(key) out_frames.extend(frames) for key, (head, frames) in pre.items(): - if 'lengths' not in head: - head['lengths'] = tuple(map(nbytes, frames)) - head['count'] = len(frames) - header['headers'][key] = head - header['keys'].append(key) + if "lengths" not in head: + head["lengths"] = tuple(map(nbytes, frames)) + head["count"] = len(frames) + header["headers"][key] = head + header["keys"].append(key) out_frames.extend(frames) for i, frame in enumerate(out_frames): if type(frame) is memoryview and frame.strides != (1,): try: - frame = frame.cast('b') + frame = frame.cast("b") except TypeError: frame = frame.tobytes() out_frames[i] = frame - return [small_header, small_payload, - msgpack.dumps(header, use_bin_type=True)] + out_frames + return [ + small_header, + small_payload, + msgpack.dumps(header, use_bin_type=True), + ] + out_frames except Exception: logger.critical("Failed to Serialize", exc_info=True) raise @@ -112,13 +122,13 @@ def loads(frames, deserialize=True, deserializers=None): header = frames.pop() header = msgpack.loads(header, use_list=False, **msgpack_opts) - keys = header['keys'] - headers = header['headers'] - bytestrings = set(header['bytestrings']) + keys = header["keys"] + headers = header["headers"] + bytestrings = set(header["bytestrings"]) for key in keys: head = headers[key] - count = head['count'] + count = head["count"] if count: fs = frames[-count::][::-1] del frames[-count:] @@ -126,7 +136,7 @@ def loads(frames, deserialize=True, deserializers=None): fs = [] if deserialize or key in bytestrings: - if 'compression' in head: + if "compression" in head: fs = decompress(head, fs) fs = merge_frames(head, fs) value = _deserialize(head, fs, deserializers=deserializers) @@ -166,12 +176,12 @@ def dumps_msgpack(msg): fmt, payload = maybe_compress(payload) if fmt: - header['compression'] = fmt + header["compression"] = fmt if header: header_bytes = msgpack.dumps(header, use_bin_type=True) else: - header_bytes = b'' + header_bytes = b"" return [header_bytes, payload] @@ -187,12 +197,14 @@ def loads_msgpack(header, payload): else: header = {} - if header.get('compression'): + if header.get("compression"): try: - decompress = compressions[header['compression']]['decompress'] + decompress = compressions[header["compression"]]["decompress"] payload = decompress(payload) except KeyError: - raise ValueError("Data is compressed as %s but we don't have this" - " installed" % str(header['compression'])) + raise ValueError( + "Data is compressed as %s but we don't have this" + " installed" % str(header["compression"]) + ) return msgpack.loads(payload, 
use_list=False, **msgpack_opts) diff --git a/distributed/protocol/h5py.py b/distributed/protocol/h5py.py index 9936920a759..cf08719e259 100644 --- a/distributed/protocol/h5py.py +++ b/distributed/protocol/h5py.py @@ -7,25 +7,26 @@ @dask_serialize.register(h5py.File) def serialize_h5py_file(f): - if f.mode != 'r': + if f.mode != "r": raise ValueError("Can only serialize read-only h5py files") - return {'filename': f.filename}, [] + return {"filename": f.filename}, [] @dask_deserialize.register(h5py.File) def deserialize_h5py_file(header, frames): import h5py - return h5py.File(header['filename'], mode='r') + + return h5py.File(header["filename"], mode="r") @dask_serialize.register((h5py.Group, h5py.Dataset)) def serialize_h5py_dataset(x): header, _ = serialize_h5py_file(x.file) - header['name'] = x.name + header["name"] = x.name return header, [] @dask_deserialize.register((h5py.Group, h5py.Dataset)) def deserialize_h5py_dataset(header, frames): file = deserialize_h5py_file(header, frames) - return file[header['name']] + return file[header["name"]] diff --git a/distributed/protocol/keras.py b/distributed/protocol/keras.py index a5437f60e18..4c6fc4b4d0a 100644 --- a/distributed/protocol/keras.py +++ b/distributed/protocol/keras.py @@ -8,15 +8,17 @@ @dask_serialize.register(keras.Model) def serialize_keras_model(model): import keras - if keras.__version__ < '1.2.0': - raise ImportError("Need Keras >= 1.2.0. " - "Try pip install keras --upgrade --no-deps") + + if keras.__version__ < "1.2.0": + raise ImportError( + "Need Keras >= 1.2.0. " "Try pip install keras --upgrade --no-deps" + ) header = model._updated_config() weights = model.get_weights() headers, frames = list(zip(*map(serialize, weights))) - header['headers'] = headers - header['nframes'] = [len(L) for L in frames] + header["headers"] = headers + header["nframes"] = [len(L) for L in frames] frames = [frame for L in frames for frame in L] return header, frames @@ -24,10 +26,11 @@ def serialize_keras_model(model): @dask_deserialize.register(keras.Model) def deserialize_keras_model(header, frames): from keras.models import model_from_config + n = 0 weights = [] - for head, length in zip(header['headers'], header['nframes']): - x = deserialize(head, frames[n: n + length]) + for head, length in zip(header["headers"], header["nframes"]): + x = deserialize(head, frames[n : n + length]) weights.append(x) n += length model = model_from_config(header) diff --git a/distributed/protocol/netcdf4.py b/distributed/protocol/netcdf4.py index 06711ad03cb..e04864d2b73 100644 --- a/distributed/protocol/netcdf4.py +++ b/distributed/protocol/netcdf4.py @@ -8,29 +8,29 @@ @dask_serialize.register(netCDF4.Dataset) def serialize_netcdf4_dataset(ds): # assume mode is read-only - return {'filename': ds.filepath()}, [] + return {"filename": ds.filepath()}, [] @dask_deserialize.register(netCDF4.Dataset) def deserialize_netcdf4_dataset(header, frames): - return netCDF4.Dataset(header['filename'], mode='r') + return netCDF4.Dataset(header["filename"], mode="r") @dask_serialize.register(netCDF4.Variable) def serialize_netcdf4_variable(x): header, _ = serialize(x.group()) - header['parent-type'] = header['type'] - header['parent-type-serialized'] = header['type-serialized'] - header['name'] = x.name + header["parent-type"] = header["type"] + header["parent-type-serialized"] = header["type-serialized"] + header["name"] = x.name return header, [] @dask_deserialize.register(netCDF4.Variable) def deserialize_netcdf4_variable(header, frames): - header['type'] = 
header['parent-type'] - header['type-serialized'] = header['parent-type-serialized'] + header["type"] = header["parent-type"] + header["type-serialized"] = header["parent-type-serialized"] parent = deserialize(header, frames) - return parent.variables[header['name']] + return parent.variables[header["name"]] @dask_serialize.register(netCDF4.Group) @@ -39,11 +39,11 @@ def serialize_netcdf4_group(g): while parent.parent: parent = parent.parent header, _ = serialize_netcdf4_dataset(parent) - header['path'] = g.path + header["path"] = g.path return header, [] @dask_deserialize.register(netCDF4.Group) def deserialize_netcdf4_group(header, frames): file = deserialize_netcdf4_dataset(header, frames) - return file[header['path']] + return file[header["path"]] diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index 3227a4bbaec..d8da4f204e4 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -24,13 +24,13 @@ def itemsize(dt): @dask_serialize.register(np.ndarray) def serialize_numpy_ndarray(x): if x.dtype.hasobject: - header = {'pickle': True} + header = {"pickle": True} frames = [pickle.dumps(x)] return header, frames # We cannot blindly pickle the dtype as some may fail pickling, # so we have a mixture of strategies. - if x.dtype.kind == 'V': + if x.dtype.kind == "V": # Preserving all the information works best when pickling try: # Only use stdlib pickle as cloudpickle is slow when failing @@ -53,31 +53,29 @@ def serialize_numpy_ndarray(x): elif x.flags.c_contiguous or x.flags.f_contiguous: # Avoid a copy and respect order when unserializing strides = x.strides - data = x.ravel(order='K') + data = x.ravel(order="K") else: x = np.ascontiguousarray(x) strides = x.strides data = x.ravel() if data.dtype.fields or data.dtype.itemsize > 8: - data = data.view('u%d' % gcd(x.dtype.itemsize, 8)) + data = data.view("u%d" % gcd(x.dtype.itemsize, 8)) try: data = data.data except ValueError: # "ValueError: cannot include dtype 'M' in a buffer" - data = data.view('u%d' % gcd(x.dtype.itemsize, 8)).data + data = data.view("u%d" % gcd(x.dtype.itemsize, 8)).data - header = {'dtype': dt, - 'shape': x.shape, - 'strides': strides} + header = {"dtype": dt, "shape": x.shape, "strides": strides} if x.nbytes > 1e5: frames = frame_split_size([data]) else: frames = [data] - header['lengths'] = [x.nbytes] + header["lengths"] = [x.nbytes] return header, frames @@ -88,17 +86,18 @@ def deserialize_numpy_ndarray(header, frames): if len(frames) > 1: frames = merge_frames(header, frames) - if header.get('pickle'): + if header.get("pickle"): return pickle.loads(frames[0]) - is_custom, dt = header['dtype'] + is_custom, dt = header["dtype"] if is_custom: dt = pickle.loads(dt) else: dt = np.dtype(dt) - x = np.ndarray(header['shape'], dtype=dt, buffer=frames[0], - strides=header['strides']) + x = np.ndarray( + header["shape"], dtype=dt, buffer=frames[0], strides=header["strides"] + ) return x @@ -116,13 +115,12 @@ def deserialize_numpy_ma_masked(header, frames): @dask_serialize.register(np.ma.core.MaskedArray) def serialize_numpy_maskedarray(x): data_header, frames = serialize_numpy_ndarray(x.data) - header = {'data-header': data_header, - 'nframes': len(frames)} + header = {"data-header": data_header, "nframes": len(frames)} # Serialize mask if present if x.mask is not np.ma.nomask: mask_header, mask_frames = serialize_numpy_ndarray(x.mask) - header['mask-header'] = mask_header + header["mask-header"] = mask_header frames += mask_frames # Only a few dtypes have python equivalents 
msgpack can serialize @@ -130,7 +128,7 @@ def serialize_numpy_maskedarray(x): serialized_fill_value = (False, x.fill_value.item()) else: serialized_fill_value = (True, pickle.dumps(x.fill_value)) - header['fill-value'] = serialized_fill_value + header["fill-value"] = serialized_fill_value return header, frames @@ -138,12 +136,12 @@ def serialize_numpy_maskedarray(x): @dask_deserialize.register(np.ma.core.MaskedArray) def deserialize_numpy_maskedarray(header, frames): data_header = header["data-header"] - data_frames = frames[:header["nframes"]] + data_frames = frames[: header["nframes"]] data = deserialize_numpy_ndarray(data_header, data_frames) - if 'mask-header' in header: + if "mask-header" in header: mask_header = header["mask-header"] - mask_frames = frames[header["nframes"]:] + mask_frames = frames[header["nframes"] :] mask = deserialize_numpy_ndarray(mask_header, mask_frames) else: mask = np.ma.nomask diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 8419541687f..080bb9037db 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -14,14 +14,16 @@ def _always_use_pickle_for(x): - mod, _, _ = x.__class__.__module__.partition('.') - if mod == 'numpy': + mod, _, _ = x.__class__.__module__.partition(".") + if mod == "numpy": import numpy as np + return isinstance(x, np.ndarray) - elif mod == 'pandas': + elif mod == "pandas": import pandas as pd + return isinstance(x, pd.core.generic.NDFrame) - elif mod == 'builtins': + elif mod == "builtins": return isinstance(x, (str, bytes)) else: return False @@ -37,12 +39,12 @@ def dumps(x): try: result = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) if len(result) < 1000: - if b'__main__' in result: + if b"__main__" in result: return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) else: return result else: - if _always_use_pickle_for(x) or b'__main__' not in result: + if _always_use_pickle_for(x) or b"__main__" not in result: return result else: return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index a6cfbd6d042..3b0a45c8a6f 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -4,6 +4,7 @@ import dask from dask.base import normalize_token + try: from cytoolz import valmap, get_in except ImportError: @@ -15,15 +16,14 @@ from ..compatibility import PY2 from ..utils import has_keyword from .compression import maybe_compress, decompress -from .utils import (unpack_frames, pack_frames_prelude, frame_split_size, - ensure_bytes) +from .utils import unpack_frames, pack_frames_prelude, frame_split_size, ensure_bytes lazy_registrations = {} -dask_serialize = dask.utils.Dispatch('dask_serialize') -dask_deserialize = dask.utils.Dispatch('dask_deserialize') +dask_serialize = dask.utils.Dispatch("dask_serialize") +dask_deserialize = dask.utils.Dispatch("dask_deserialize") def dask_dumps(x, context=None): @@ -33,34 +33,34 @@ def dask_dumps(x, context=None): dumps = dask_serialize.dispatch(type(x)) except TypeError: raise NotImplementedError(type_name) - if has_keyword(dumps, 'context'): + if has_keyword(dumps, "context"): header, frames = dumps(x, context=context) else: header, frames = dumps(x) - header['type'] = type_name - header['type-serialized'] = pickle.dumps(type(x)) - header['serializer'] = 'dask' + header["type"] = type_name + header["type-serialized"] = pickle.dumps(type(x)) + header["serializer"] = "dask" return header, frames def dask_loads(header, 
frames): - typ = pickle.loads(header['type-serialized']) + typ = pickle.loads(header["type-serialized"]) loads = dask_deserialize.dispatch(typ) return loads(header, frames) def pickle_dumps(x): - return {'serializer': 'pickle'}, [pickle.dumps(x)] + return {"serializer": "pickle"}, [pickle.dumps(x)] def pickle_loads(header, frames): - return pickle.loads(b''.join(frames)) + return pickle.loads(b"".join(frames)) msgpack_len_opts = { - ('max_%s_len' % x): 2**31 - 1 - for x in ['str', 'bin', 'array', 'map', 'ext']} + ("max_%s_len" % x): 2 ** 31 - 1 for x in ["str", "bin", "array", "map", "ext"] +} def msgpack_dumps(x): @@ -69,16 +69,17 @@ def msgpack_dumps(x): except Exception: raise NotImplementedError() else: - return {'serializer': 'msgpack'}, [frame] + return {"serializer": "msgpack"}, [frame] def msgpack_loads(header, frames): - return msgpack.loads(b''.join(frames), encoding='utf8', use_list=False, - **msgpack_len_opts) + return msgpack.loads( + b"".join(frames), encoding="utf8", use_list=False, **msgpack_len_opts + ) def serialization_error_loads(header, frames): - msg = '\n'.join([ensure_bytes(frame).decode('utf8') for frame in frames]) + msg = "\n".join([ensure_bytes(frame).decode("utf8") for frame in frames]) raise TypeError(msg) @@ -86,16 +87,16 @@ def serialization_error_loads(header, frames): def register_serialization_family(name, dumps, loads): - families[name] = (dumps, loads, dumps and has_keyword(dumps, 'context')) + families[name] = (dumps, loads, dumps and has_keyword(dumps, "context")) -register_serialization_family('dask', dask_dumps, dask_loads) -register_serialization_family('pickle', pickle_dumps, pickle_loads) -register_serialization_family('msgpack', msgpack_dumps, msgpack_loads) -register_serialization_family('error', None, serialization_error_loads) +register_serialization_family("dask", dask_dumps, dask_loads) +register_serialization_family("pickle", pickle_dumps, pickle_loads) +register_serialization_family("msgpack", msgpack_dumps, msgpack_loads) +register_serialization_family("error", None, serialization_error_loads) -def serialize(x, serializers=None, on_error='message', context=None): +def serialize(x, serializers=None, on_error="message", context=None): r""" Convert object to a header and list of bytestrings @@ -132,18 +133,18 @@ def serialize(x, serializers=None, on_error='message', context=None): register_serialization: Register custom serialization functions """ if serializers is None: - serializers = ('dask', 'pickle') # TODO: get from configuration + serializers = ("dask", "pickle") # TODO: get from configuration if isinstance(x, Serialized): return x.header, x.frames - tb = '' + tb = "" for name in serializers: dumps, loads, wants_context = families[name] try: header, frames = dumps(x, context=context) if wants_context else dumps(x) - header['serializer'] = name + header["serializer"] = name return header, frames except NotImplementedError: continue @@ -152,15 +153,15 @@ def serialize(x, serializers=None, on_error='message', context=None): break msg = "Could not serialize object of type %s." 
% type(x).__name__ - if on_error == 'message': + if on_error == "message": frames = [msg] if tb: frames.append(tb[:100000]) frames = [frame.encode() for frame in frames] - return {'serializer': 'error'}, frames - elif on_error == 'raise': + return {"serializer": "error"}, frames + elif on_error == "raise": raise TypeError(msg, str(x)[:10000]) @@ -180,10 +181,12 @@ def deserialize(header, frames, deserializers=None): -------- serialize """ - name = header.get('serializer') + name = header.get("serializer") if deserializers is not None and name not in deserializers: - raise TypeError("Data serialized with %s but only able to deserialize " - "data with %s" % (name, str(list(deserializers)))) + raise TypeError( + "Data serialized with %s but only able to deserialize " + "data with %s" % (name, str(list(deserializers))) + ) dumps, loads, wants_context = families[name] return loads(header, frames) @@ -209,8 +212,7 @@ def __repr__(self): return "" % str(self.data) def __eq__(self, other): - return (isinstance(other, Serialize) and - other.data == self.data) + return isinstance(other, Serialize) and other.data == self.data def __ne__(self, other): return not (self == other) @@ -237,13 +239,16 @@ def __init__(self, header, frames): def deserialize(self): from .core import decompress + frames = decompress(self.header, self.frames) return deserialize(self.header, frames) def __eq__(self, other): - return (isinstance(other, Serialized) and - other.header == self.header and - other.frames == self.frames) + return ( + isinstance(other, Serialized) + and other.header == self.header + and other.frames == self.frames + ) def __ne__(self, other): return not (self == other) @@ -296,16 +301,24 @@ def _extract_serialize(x, ser, path=()): typ = type(v) if typ is list or typ is dict: _extract_serialize(v, ser, path + (k,)) - elif (typ is Serialize or typ is Serialized - or typ in (bytes, bytearray) and len(v) > 2**16): + elif ( + typ is Serialize + or typ is Serialized + or typ in (bytes, bytearray) + and len(v) > 2 ** 16 + ): ser[path + (k,)] = v elif type(x) is list: for k, v in enumerate(x): typ = type(v) if typ is list or typ is dict: _extract_serialize(v, ser, path + (k,)) - elif (typ is Serialize or typ is Serialized - or typ in (bytes, bytearray) and len(v) > 2**16): + elif ( + typ is Serialize + or typ is Serialized + or typ in (bytes, bytearray) + and len(v) > 2 ** 16 + ): ser[path + (k,)] = v @@ -318,6 +331,7 @@ def nested_deserialize(x): >>> nested_deserialize(msg) {'op': 'update', 'data': 123} """ + def replace_inner(x): if type(x) is dict: x = x.copy() @@ -353,8 +367,8 @@ def serialize_bytelist(x, **kwargs): compression, frames = zip(*map(maybe_compress, frames)) else: compression = [] - header['compression'] = compression - header['count'] = len(frames) + header["compression"] = compression + header["count"] = len(frames) header = msgpack.dumps(header, use_bin_type=True) frames2 = [header] + list(frames) @@ -365,7 +379,7 @@ def serialize_bytes(x, **kwargs): L = serialize_bytelist(x, **kwargs) if PY2: L = [bytes(y) for y in L] - return b''.join(L) + return b"".join(L) def deserialize_bytes(b): @@ -441,7 +455,7 @@ def typename(typ): >>> typename(Scheduler) 'distributed.scheduler.Scheduler' """ - return typ.__module__ + '.' + typ.__name__ + return typ.__module__ + "." 
+ typ.__name__ @partial(normalize_token.register, Serialized) @@ -469,18 +483,24 @@ def _deserialize_bytes(header, frames): def _is_msgpack_serializable(v): typ = type(v) - return (typ is str or typ is int or typ is float or - isinstance(v, dict) and all(map(_is_msgpack_serializable, v.values())) - and all(typ is str for x in v.keys()) or - isinstance(v, (list, tuple)) and all(map(_is_msgpack_serializable, v))) + return ( + typ is str + or typ is int + or typ is float + or isinstance(v, dict) + and all(map(_is_msgpack_serializable, v.values())) + and all(typ is str for x in v.keys()) + or isinstance(v, (list, tuple)) + and all(map(_is_msgpack_serializable, v)) + ) def serialize_object_with_dict(est): header = { - 'serializer': 'dask', - 'type-serialized': pickle.dumps(type(est)), - 'simple': {}, - 'complex': {} + "serializer": "dask", + "type-serialized": pickle.dumps(type(est)), + "simple": {}, + "complex": {}, } frames = [] @@ -491,30 +511,32 @@ def serialize_object_with_dict(est): for k, v in d.items(): if _is_msgpack_serializable(v): - header['simple'][k] = v + header["simple"][k] = v else: if isinstance(v, dict): h, f = serialize_object_with_dict(v) else: h, f = serialize(v) - header['complex'][k] = {'header': h, - 'start': len(frames), - 'stop': len(frames) + len(f)} + header["complex"][k] = { + "header": h, + "start": len(frames), + "stop": len(frames) + len(f), + } frames += f return header, frames def deserialize_object_with_dict(header, frames): - cls = pickle.loads(header['type-serialized']) + cls = pickle.loads(header["type-serialized"]) if issubclass(cls, dict): dd = obj = {} else: obj = object.__new__(cls) dd = obj.__dict__ - dd.update(header['simple']) - for k, d in header['complex'].items(): - h = d['header'] - f = frames[d['start']: d['stop']] + dd.update(header["simple"]) + for k, d in header["complex"].items(): + h = d["header"] + f = frames[d["start"] : d["stop"]] v = deserialize(h, f) dd[k] = v diff --git a/distributed/protocol/sparse.py b/distributed/protocol/sparse.py index ca0c6f38a79..b5a437a32a4 100644 --- a/distributed/protocol/sparse.py +++ b/distributed/protocol/sparse.py @@ -10,22 +10,24 @@ def serialize_sparse(x): coords_header, coords_frames = serialize(x.coords) data_header, data_frames = serialize(x.data) - header = {'coords-header': coords_header, - 'data-header': data_header, - 'shape': x.shape, - 'nframes': [len(coords_frames), len(data_frames)]} + header = { + "coords-header": coords_header, + "data-header": data_header, + "shape": x.shape, + "nframes": [len(coords_frames), len(data_frames)], + } return header, coords_frames + data_frames @dask_deserialize.register(sparse.COO) def deserialize_sparse(header, frames): - coords_frames = frames[:header['nframes'][0]] - data_frames = frames[header['nframes'][0]:] + coords_frames = frames[: header["nframes"][0]] + data_frames = frames[header["nframes"][0] :] - coords = deserialize(header['coords-header'], coords_frames) - data = deserialize(header['data-header'], data_frames) + coords = deserialize(header["coords-header"], coords_frames) + data = deserialize(header["data-header"], data_frames) - shape = header['shape'] + shape = header["shape"] return sparse.COO(coords, data, shape=shape) diff --git a/distributed/protocol/tests/test_arrow.py b/distributed/protocol/tests/test_arrow.py index eca8de9f1a3..a363ee9511e 100644 --- a/distributed/protocol/tests/test_arrow.py +++ b/distributed/protocol/tests/test_arrow.py @@ -1,18 +1,18 @@ import pandas as pd import pytest -pa = pytest.importorskip('pyarrow') +pa = 
pytest.importorskip("pyarrow") from distributed.utils_test import gen_cluster from distributed.protocol import deserialize, serialize -df = pd.DataFrame({'A': list('abc'), 'B': [1,2,3]}) +df = pd.DataFrame({"A": list("abc"), "B": [1, 2, 3]}) tbl = pa.Table.from_pandas(df, preserve_index=False) batch = pa.RecordBatch.from_pandas(df, preserve_index=False) -@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"]) +@pytest.mark.parametrize("obj", [batch, tbl], ids=["RecordBatch", "Table"]) def test_roundtrip(obj): # Test that the serialize/deserialize functions actually # work independent of distributed @@ -25,7 +25,7 @@ def echo(arg): return arg -@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"]) +@pytest.mark.parametrize("obj", [batch, tbl], ids=["RecordBatch", "Table"]) def test_scatter(obj): @gen_cluster(client=True) def run_test(client, scheduler, worker1, worker2): @@ -33,4 +33,5 @@ def run_test(client, scheduler, worker1, worker2): fut = client.submit(echo, obj_fut) result = yield fut assert obj.equals(result) + run_test() diff --git a/distributed/protocol/tests/test_h5py.py b/distributed/protocol/tests/test_h5py.py index 3fffa9fecd4..f2f9a6625cb 100644 --- a/distributed/protocol/tests/test_h5py.py +++ b/distributed/protocol/tests/test_h5py.py @@ -3,7 +3,7 @@ import pytest -h5py = pytest.importorskip('h5py') +h5py = pytest.importorskip("h5py") from distributed.protocol import deserialize, serialize @@ -33,39 +33,39 @@ def wrapper(): @silence_h5py_issue775 def test_serialize_deserialize_file(): with tmpfile() as fn: - with h5py.File(fn, mode='a') as f: - f.create_dataset('/x', shape=(2, 2), dtype='i4') - with h5py.File(fn, mode='r') as f: + with h5py.File(fn, mode="a") as f: + f.create_dataset("/x", shape=(2, 2), dtype="i4") + with h5py.File(fn, mode="r") as f: g = deserialize(*serialize(f)) assert f.filename == g.filename assert isinstance(g, h5py.File) assert f.mode == g.mode - assert g['x'].shape == (2, 2) + assert g["x"].shape == (2, 2) @silence_h5py_issue775 def test_serialize_deserialize_group(): with tmpfile() as fn: - with h5py.File(fn, mode='a') as f: - f.create_dataset('/group1/group2/x', shape=(2, 2), dtype='i4') - with h5py.File(fn, mode='r') as f: - group = f['/group1/group2'] + with h5py.File(fn, mode="a") as f: + f.create_dataset("/group1/group2/x", shape=(2, 2), dtype="i4") + with h5py.File(fn, mode="r") as f: + group = f["/group1/group2"] group2 = deserialize(*serialize(group)) assert isinstance(group2, h5py.Group) assert group.file.filename == group2.file.filename - assert group2['x'].shape == (2, 2) + assert group2["x"].shape == (2, 2) @silence_h5py_issue775 def test_serialize_deserialize_dataset(): with tmpfile() as fn: - with h5py.File(fn, mode='a') as f: - x = f.create_dataset('/group1/group2/x', shape=(2, 2), dtype='i4') - with h5py.File(fn, mode='r') as f: - x = f['group1/group2/x'] + with h5py.File(fn, mode="a") as f: + x = f.create_dataset("/group1/group2/x", shape=(2, 2), dtype="i4") + with h5py.File(fn, mode="r") as f: + x = f["group1/group2/x"] y = deserialize(*serialize(x)) assert isinstance(y, h5py.Dataset) assert x.name == y.name @@ -76,8 +76,8 @@ def test_serialize_deserialize_dataset(): @silence_h5py_issue775 def test_raise_error_on_serialize_write_permissions(): with tmpfile() as fn: - with h5py.File(fn, mode='a') as f: - x = f.create_dataset('/x', shape=(2, 2), dtype='i4') + with h5py.File(fn, mode="a") as f: + x = f.create_dataset("/x", shape=(2, 2), dtype="i4") f.flush() with pytest.raises(TypeError): 
deserialize(*serialize(x)) @@ -95,14 +95,14 @@ def test_raise_error_on_serialize_write_permissions(): @gen_cluster(client=True) def test_h5py_serialize(c, s, a, b): from dask.utils import SerializableLock - lock = SerializableLock('hdf5') + + lock = SerializableLock("hdf5") with tmpfile() as fn: - with h5py.File(fn, mode='a') as f: - x = f.create_dataset('/group/x', shape=(4,), dtype='i4', - chunks=(2,)) + with h5py.File(fn, mode="a") as f: + x = f.create_dataset("/group/x", shape=(4,), dtype="i4", chunks=(2,)) x[:] = [1, 2, 3, 4] - with h5py.File(fn, mode='r') as f: - dset = f['/group/x'] + with h5py.File(fn, mode="r") as f: + dset = f["/group/x"] x = da.from_array(dset, chunks=dset.chunks, lock=lock) y = c.compute(x) y = yield y @@ -112,12 +112,11 @@ def test_h5py_serialize(c, s, a, b): @gen_cluster(client=True) def test_h5py_serialize_2(c, s, a, b): with tmpfile() as fn: - with h5py.File(fn, mode='a') as f: - x = f.create_dataset('/group/x', shape=(12,), dtype='i4', - chunks=(4,)) + with h5py.File(fn, mode="a") as f: + x = f.create_dataset("/group/x", shape=(12,), dtype="i4", chunks=(4,)) x[:] = [1, 2, 3, 4] * 3 - with h5py.File(fn, mode='r') as f: - dset = f['/group/x'] + with h5py.File(fn, mode="r") as f: + dset = f["/group/x"] x = da.from_array(dset, chunks=(3,)) y = c.compute(x.sum()) y = yield y diff --git a/distributed/protocol/tests/test_keras.py b/distributed/protocol/tests/test_keras.py index b246b33bfe8..da8cdf6374a 100644 --- a/distributed/protocol/tests/test_keras.py +++ b/distributed/protocol/tests/test_keras.py @@ -1,9 +1,8 @@ - import numpy as np from numpy.testing import assert_allclose import pytest -keras = pytest.importorskip('keras') +keras = pytest.importorskip("keras") from distributed.protocol import serialize, deserialize, dumps, loads, to_serialize @@ -12,7 +11,7 @@ def test_serialize_deserialize_model(): model = keras.models.Sequential() model.add(keras.layers.Dense(5, input_dim=3)) model.add(keras.layers.Dense(2)) - model.compile(optimizer='sgd', loss='mse') + model.compile(optimizer="sgd", loss="mse") x = np.random.random((1, 3)) y = np.random.random((1, 2)) model.train_on_batch(x, y) @@ -20,7 +19,7 @@ def test_serialize_deserialize_model(): loaded = deserialize(*serialize(model)) assert_allclose(loaded.predict(x), model.predict(x)) - data = {'model': to_serialize(model)} + data = {"model": to_serialize(model)} frames = dumps(data) result = loads(frames) - assert_allclose(result['model'].predict(x), model.predict(x)) + assert_allclose(result["model"].predict(x), model.predict(x)) diff --git a/distributed/protocol/tests/test_netcdf4.py b/distributed/protocol/tests/test_netcdf4.py index 381f6468182..f1ddcead3ef 100644 --- a/distributed/protocol/tests/test_netcdf4.py +++ b/distributed/protocol/tests/test_netcdf4.py @@ -1,7 +1,7 @@ import pytest -netCDF4 = pytest.importorskip('netCDF4') -np = pytest.importorskip('numpy') +netCDF4 = pytest.importorskip("netCDF4") +np = pytest.importorskip("numpy") from distributed.protocol import deserialize, serialize @@ -9,51 +9,51 @@ def create_test_dataset(fn): - with netCDF4.Dataset(fn, mode='w') as ds: - ds.createDimension('x', 3) - v = ds.createVariable('x', np.int32, ('x',)) + with netCDF4.Dataset(fn, mode="w") as ds: + ds.createDimension("x", 3) + v = ds.createVariable("x", np.int32, ("x",)) v[:] = np.arange(3) - g = ds.createGroup('group') - g2 = ds.createGroup('group/group1') + g = ds.createGroup("group") + g2 = ds.createGroup("group/group1") - v2 = ds.createVariable('group/y', np.int32, ('x',)) + v2 = 
ds.createVariable("group/y", np.int32, ("x",)) v2[:] = np.arange(3) + 1 - v3 = ds.createVariable('group/group1/z', np.int32, ('x',)) + v3 = ds.createVariable("group/group1/z", np.int32, ("x",)) v3[:] = np.arange(3) + 2 def test_serialize_deserialize_dataset(): with tmpfile() as fn: create_test_dataset(fn) - with netCDF4.Dataset(fn, mode='r') as f: + with netCDF4.Dataset(fn, mode="r") as f: g = deserialize(*serialize(f)) assert f.filepath() == g.filepath() assert isinstance(g, netCDF4.Dataset) - assert g.variables['x'].dimensions == ('x',) - assert g.variables['x'].dtype == np.int32 - assert (g.variables['x'][:] == np.arange(3)).all() + assert g.variables["x"].dimensions == ("x",) + assert g.variables["x"].dtype == np.int32 + assert (g.variables["x"][:] == np.arange(3)).all() def test_serialize_deserialize_variable(): with tmpfile() as fn: create_test_dataset(fn) - with netCDF4.Dataset(fn, mode='r') as f: - x = f.variables['x'] + with netCDF4.Dataset(fn, mode="r") as f: + x = f.variables["x"] y = deserialize(*serialize(x)) assert isinstance(y, netCDF4.Variable) - assert y.dimensions == ('x',) - assert (x.dtype == y.dtype) + assert y.dimensions == ("x",) + assert x.dtype == y.dtype assert (x[:] == y[:]).all() def test_serialize_deserialize_group(): with tmpfile() as fn: create_test_dataset(fn) - with netCDF4.Dataset(fn, mode='r') as f: - for path in ['group', 'group/group1']: + with netCDF4.Dataset(fn, mode="r") as f: + for path in ["group", "group/group1"]: g = f[path] h = deserialize(*serialize(g)) assert isinstance(h, netCDF4.Group) @@ -61,15 +61,17 @@ def test_serialize_deserialize_group(): assert list(g.groups) == list(h.groups) assert list(g.variables) == list(h.variables) - vars = [f.variables['x'], - f['group'].variables['y'], - f['group/group1'].variables['z']] + vars = [ + f.variables["x"], + f["group"].variables["y"], + f["group/group1"].variables["z"], + ] for x in vars: y = deserialize(*serialize(x)) assert isinstance(y, netCDF4.Variable) - assert y.dimensions == ('x',) - assert (x.dtype == y.dtype) + assert y.dimensions == ("x",) + assert x.dtype == y.dtype assert (x[:] == y[:]).all() @@ -83,8 +85,8 @@ def test_serialize_deserialize_group(): def test_netcdf4_serialize(c, s, a, b): with tmpfile() as fn: create_test_dataset(fn) - with netCDF4.Dataset(fn, mode='r') as f: - dset = f.variables['x'] + with netCDF4.Dataset(fn, mode="r") as f: + dset = f.variables["x"] x = da.from_array(dset, chunks=2) y = c.compute(x) y = yield y diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index c86a5f03199..849c2964fd6 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -7,8 +7,15 @@ import pytest from distributed.compatibility import PY2 -from distributed.protocol import (serialize, deserialize, decompress, dumps, - loads, to_serialize, msgpack) +from distributed.protocol import ( + serialize, + deserialize, + decompress, + dumps, + loads, + to_serialize, + msgpack, +) from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE from distributed.utils import tmpfile, nbytes from distributed.utils_test import slow, gen_cluster @@ -19,55 +26,58 @@ def test_serialize(): x = np.ones((5, 5)) header, frames = serialize(x) - assert header['type'] + assert header["type"] assert len(frames) == 1 - if 'compression' in header: + if "compression" in header: frames = decompress(header, frames) result = deserialize(header, frames) assert (result == x).all() -@pytest.mark.parametrize('x', [ - np.ones(5), - 
np.array(5), - np.random.random((5, 5)), - np.random.random((5, 5))[::2, :], - np.random.random((5, 5))[:, ::2], - np.asfortranarray(np.random.random((5, 5))), - np.asfortranarray(np.random.random((5, 5)))[::2, :], - np.asfortranarray(np.random.random((5, 5)))[:, ::2], - np.random.random(5).astype('f4'), - np.random.random(5).astype('>i8'), - np.random.random(5).astype('i8"), + np.random.random(5).astype(" 2 result = loads(frames) - assert result == {'x': 1, 'data': 123} + assert result == {"x": 1, "data": 123} result2 = loads(frames, deserialize=False) - assert result2['x'] == 1 - assert isinstance(result2['data'], Serialized) - assert any(a is b - for a in result2['data'].frames - for b in frames) + assert result2["x"] == 1 + assert isinstance(result2["data"], Serialized) + assert any(a is b for a in result2["data"].frames for b in frames) frames2 = dumps(result2) assert all(map(eq_frames, frames, frames2)) @@ -210,13 +207,11 @@ def test_dumps_loads_Serialize(): def test_dumps_loads_Serialized(): - msg = {'x': 1, - 'data': Serialized(*serialize(123)), - } + msg = {"x": 1, "data": Serialized(*serialize(123))} frames = dumps(msg) assert len(frames) > 2 result = loads(frames) - assert result == {'x': 1, 'data': 123} + assert result == {"x": 1, "data": 123} result2 = loads(frames, deserialize=False) assert result2 == msg @@ -228,18 +223,17 @@ def test_dumps_loads_Serialized(): assert result == result3 -@pytest.mark.skipif(sys.version_info[0] < 3, - reason='NumPy doesnt use memoryviews') +@pytest.mark.skipif(sys.version_info[0] < 3, reason="NumPy doesnt use memoryviews") def test_maybe_compress_memoryviews(): - np = pytest.importorskip('numpy') - pytest.importorskip('lz4') - x = np.arange(1000000, dtype='int64') + np = pytest.importorskip("numpy") + pytest.importorskip("lz4") + x = np.arange(1000000, dtype="int64") compression, payload = maybe_compress(x.data) try: import blosc # noqa: F401 except ImportError: - assert compression == 'lz4' + assert compression == "lz4" assert len(payload) < x.nbytes * 0.75 else: - assert compression == 'blosc' + assert compression == "blosc" assert len(payload) < x.nbytes / 10 diff --git a/distributed/protocol/tests/test_protocol_utils.py b/distributed/protocol/tests/test_protocol_utils.py index 0db79ccaafa..f4b98ab0e1d 100644 --- a/distributed/protocol/tests/test_protocol_utils.py +++ b/distributed/protocol/tests/test_protocol_utils.py @@ -5,20 +5,20 @@ def test_merge_frames(): - result = merge_frames({'lengths': [3, 4]}, [b'12', b'34', b'567']) - expected = [b'123', b'4567'] + result = merge_frames({"lengths": [3, 4]}, [b"12", b"34", b"567"]) + expected = [b"123", b"4567"] assert list(map(ensure_bytes, result)) == expected - b = b'123' - assert merge_frames({'lengths': [3]}, [b])[0] is b + b = b"123" + assert merge_frames({"lengths": [3]}, [b])[0] is b - L = [b'123', b'456'] - assert merge_frames({'lengths': [3, 3]}, L) is L + L = [b"123", b"456"] + assert merge_frames({"lengths": [3, 3]}, L) is L def test_pack_frames(): - frames = [b'123', b'asdf'] + frames = [b"123", b"asdf"] b = pack_frames(frames) assert isinstance(b, bytes) frames2 = unpack_frames(b) diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 64d7adc8d41..da43021d550 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -9,11 +9,20 @@ from toolz import identity from distributed import wait -from distributed.protocol import (register_serialization, serialize, - deserialize, 
nested_deserialize, Serialize, - Serialized, to_serialize, serialize_bytes, - deserialize_bytes, serialize_bytelist, - register_serialization_family, dask_serialize) +from distributed.protocol import ( + register_serialization, + serialize, + deserialize, + nested_deserialize, + Serialize, + Serialized, + to_serialize, + serialize_bytes, + deserialize_bytes, + serialize_bytelist, + register_serialization_family, + dask_serialize, +) from distributed.utils import nbytes from distributed.utils_test import inc, gen_test from distributed.comm.utils import to_frames, from_frames @@ -24,7 +33,7 @@ def __init__(self, data): self.data = data def __getstate__(self): - raise Exception('Not picklable') + raise Exception("Not picklable") def serialize_myobj(x): @@ -41,7 +50,7 @@ def deserialize_myobj(header, frames): def test_dumps_serialize(): for x in [123, [1, 2, 3]]: header, frames = serialize(x) - assert header['serializer'] == 'pickle' + assert header["serializer"] == "pickle" assert len(frames) == 1 result = deserialize(header, frames) @@ -49,7 +58,7 @@ def test_dumps_serialize(): x = MyObj(123) header, frames = serialize(x) - assert header['type'] + assert header["type"] assert len(frames) == 1 result = deserialize(header, frames) @@ -57,7 +66,7 @@ def test_dumps_serialize(): def test_serialize_bytestrings(): - for b in (b'123', bytearray(b'4567')): + for b in (b"123", bytearray(b"4567")): header, frames = serialize(b) assert frames[0] is b bb = deserialize(header, frames) @@ -66,7 +75,7 @@ def test_serialize_bytestrings(): def test_Serialize(): s = Serialize(123) - assert '123' in str(s) + assert "123" in str(s) assert s.data == 123 t = Serialize((1, 2)) @@ -92,18 +101,18 @@ def test_Serialized(): def test_nested_deserialize(): - x = {'op': 'update', - 'x': [to_serialize(123), to_serialize(456), 789], - 'y': {'a': ['abc', Serialized(*serialize('def'))], - 'b': b'ghi'} - } + x = { + "op": "update", + "x": [to_serialize(123), to_serialize(456), 789], + "y": {"a": ["abc", Serialized(*serialize("def"))], "b": b"ghi"}, + } x_orig = copy.deepcopy(x) - assert nested_deserialize(x) == {'op': 'update', - 'x': [123, 456, 789], - 'y': {'a': ['abc', 'def'], - 'b': b'ghi'} - } + assert nested_deserialize(x) == { + "op": "update", + "x": [123, 456, 789], + "y": {"a": ["abc", "def"], "b": b"ghi"}, + } assert x == x_orig # x wasn't mutated @@ -146,7 +155,7 @@ def test_inter_worker_comms(c, s, a, b): class Empty(object): def __getstate__(self): - raise Exception('Not picklable') + raise Exception("Not picklable") def serialize_empty(x): @@ -168,6 +177,7 @@ def test_empty(): def test_empty_loads(): from distributed.protocol import loads, dumps + e = Empty() e2 = loads(dumps([to_serialize(e)])) assert isinstance(e2[0], Empty) @@ -175,13 +185,14 @@ def test_empty_loads(): def test_empty_loads_deep(): from distributed.protocol import loads, dumps + e = Empty() e2 = loads(dumps([[[to_serialize(e)]]])) assert isinstance(e2[0][0][0], Empty) def test_serialize_bytes(): - for x in [1, 'abc', np.arange(5)]: + for x in [1, "abc", np.arange(5)]: b = serialize_bytes(x) assert isinstance(b, bytes) y = deserialize_bytes(b) @@ -189,12 +200,12 @@ def test_serialize_bytes(): def test_serialize_list_compress(): - pytest.importorskip('lz4') + pytest.importorskip("lz4") x = np.ones(1000000) L = serialize_bytelist(x) assert sum(map(nbytes, L)) < x.nbytes / 2 - b = b''.join(L) + b = b"".join(L) y = deserialize_bytes(b) assert (x == y).all() @@ -217,7 +228,7 @@ def __getstate__(self): assert "Sneaky" not in str(info.value) assert 
"MyClass" in str(info.value) - header, frames = serialize(obj, serializers=['pickle']) + header, frames = serialize(obj, serializers=["pickle"]) with pytest.raises(Exception) as info: deserialize(header, frames) @@ -226,28 +237,27 @@ def __getstate__(self): def test_errors(): - msg = {'data': {'foo': to_serialize(inc)}} + msg = {"data": {"foo": to_serialize(inc)}} - header, frames = serialize(msg, serializers=['msgpack', 'pickle']) - assert header['serializer'] == 'pickle' + header, frames = serialize(msg, serializers=["msgpack", "pickle"]) + assert header["serializer"] == "pickle" - header, frames = serialize(msg, serializers=['msgpack']) - assert header['serializer'] == 'error' + header, frames = serialize(msg, serializers=["msgpack"]) + assert header["serializer"] == "error" with pytest.raises(TypeError): - serialize(msg, serializers=['msgpack'], on_error='raise') + serialize(msg, serializers=["msgpack"], on_error="raise") @gen_test() def test_err_on_bad_deserializer(): - frames = yield to_frames({'x': to_serialize(1234)}, - serializers=['pickle']) + frames = yield to_frames({"x": to_serialize(1234)}, serializers=["pickle"]) - result = yield from_frames(frames, deserializers=['pickle', 'foo']) - assert result == {'x': 1234} + result = yield from_frames(frames, deserializers=["pickle", "foo"]) + assert result == {"x": 1234} with pytest.raises(TypeError) as info: - yield from_frames(frames, deserializers=['msgpack']) + yield from_frames(frames, deserializers=["msgpack"]) class MyObject(object): @@ -256,10 +266,12 @@ def __init__(self, **kwargs): def my_dumps(obj, context=None): - if type(obj).__name__ == 'MyObject': - header = {'serializer': 'my-ser'} - frames = [msgpack.dumps(obj.__dict__, use_bin_type=True), - msgpack.dumps(context, use_bin_type=True)] + if type(obj).__name__ == "MyObject": + header = {"serializer": "my-ser"} + frames = [ + msgpack.dumps(obj.__dict__, use_bin_type=True), + msgpack.dumps(context, use_bin_type=True), + ] return header, frames else: raise NotImplementedError() @@ -274,11 +286,13 @@ def my_loads(header, frames): return obj -@gen_cluster(client=True, - client_kwargs={'serializers': ['my-ser', 'pickle']}, - worker_kwargs={'serializers': ['my-ser', 'pickle']}) +@gen_cluster( + client=True, + client_kwargs={"serializers": ["my-ser", "pickle"]}, + worker_kwargs={"serializers": ["my-ser", "pickle"]}, +) def test_context_specific_serialization(c, s, a, b): - register_serialization_family('my-ser', my_dumps, my_loads) + register_serialization_family("my-ser", my_dumps, my_loads) try: # Create the object on A, force communication to B @@ -295,16 +309,17 @@ def check(dask_worker): return my_obj.context result = yield c.run(check, workers=[b.address]) - expected = {'sender': a.address, 'recipient': b.address} - assert result[b.address]['sender'] == a.address # see origin worker + expected = {"sender": a.address, "recipient": b.address} + assert result[b.address]["sender"] == a.address # see origin worker z = yield y # bring object to local process assert z.x == 1 and z.y == 2 - assert z.context['sender'] == b.address + assert z.context["sender"] == b.address finally: from distributed.protocol.serialize import families - del families['my-ser'] + + del families["my-ser"] @gen_cluster(client=True) @@ -325,13 +340,13 @@ def check(dask_worker): return my_obj.context result = yield c.run(check, workers=[b.address]) - expected = {'sender': a.address, 'recipient': b.address} - assert result[b.address]['sender'] == a.address # see origin worker + expected = {"sender": 
a.address, "recipient": b.address} + assert result[b.address]["sender"] == a.address # see origin worker z = yield y # bring object to local process assert z.x == 1 and z.y == 2 - assert z.context['sender'] == b.address + assert z.context["sender"] == b.address def test_serialize_raises(): @@ -345,4 +360,4 @@ def dumps(f): with pytest.raises(Exception) as info: deserialize(*serialize(Foo())) - assert 'Hello-123' in str(info.value) + assert "Hello-123" in str(info.value) diff --git a/distributed/protocol/tests/test_sklearn.py b/distributed/protocol/tests/test_sklearn.py index 4fa8aeb5369..051a0440f3a 100644 --- a/distributed/protocol/tests/test_sklearn.py +++ b/distributed/protocol/tests/test_sklearn.py @@ -1,5 +1,6 @@ import pytest -pytest.importorskip('sklearn') + +pytest.importorskip("sklearn") import sklearn.linear_model @@ -11,7 +12,7 @@ def test_basic(): est.fit([[0, 0], [1, 1], [2, 2]], [0, 1, 2]) header, frames = serialize(est) - assert header['serializer'] == 'dask' + assert header["serializer"] == "dask" est2 = deserialize(header, frames) diff --git a/distributed/protocol/tests/test_sparse.py b/distributed/protocol/tests/test_sparse.py index 2ff97c143a1..89f9da09bc2 100644 --- a/distributed/protocol/tests/test_sparse.py +++ b/distributed/protocol/tests/test_sparse.py @@ -1,9 +1,8 @@ - import numpy as np from numpy.testing import assert_allclose import pytest -sparse = pytest.importorskip('sparse') +sparse = pytest.importorskip("sparse") from distributed.protocol import deserialize, serialize @@ -14,7 +13,7 @@ def test_serialize_deserialize_sparse(): y = sparse.COO(x) header, frames = serialize(y) - assert 'sparse' in header['type'] + assert "sparse" in header["type"] z = deserialize(*serialize(y)) assert_allclose(y.data, z.data) diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py index cac8fa05d66..6cc8bb20986 100644 --- a/distributed/protocol/tests/test_torch.py +++ b/distributed/protocol/tests/test_torch.py @@ -1,15 +1,15 @@ from distributed.protocol import serialize, deserialize import pytest -np = pytest.importorskip('numpy') -torch = pytest.importorskip('torch') +np = pytest.importorskip("numpy") +torch = pytest.importorskip("torch") def test_tensor(): x = np.arange(10) t = torch.Tensor(x) header, frames = serialize(t) - assert header['serializer'] == 'dask' + assert header["serializer"] == "dask" t2 = deserialize(header, frames) assert (x == t2.numpy()).all() @@ -25,7 +25,7 @@ def test_grad(): def test_resnet(): - torchvision = pytest.importorskip('torchvision') + torchvision = pytest.importorskip("torchvision") model = torchvision.models.resnet.resnet18() header, frames = serialize(model) diff --git a/distributed/protocol/torch.py b/distributed/protocol/torch.py index 9a171b6d84f..e69be68b0c1 100644 --- a/distributed/protocol/torch.py +++ b/distributed/protocol/torch.py @@ -1,5 +1,4 @@ -from .serialize import (serialize, dask_serialize, dask_deserialize, - register_generic) +from .serialize import serialize, dask_serialize, dask_deserialize, register_generic import torch import numpy as np @@ -11,32 +10,33 @@ def serialize_torch_Tensor(t): header, frames = serialize(t.detach_().numpy()) if t.grad is not None: grad_header, grad_frames = serialize(t.grad.numpy()) - header['grad'] = {'header': grad_header, 'start': len(frames)} + header["grad"] = {"header": grad_header, "start": len(frames)} frames += grad_frames - header['requires_grad'] = requires_grad_ - header['device'] = t.device.type + header["requires_grad"] = requires_grad_ 
+ header["device"] = t.device.type return header, frames @dask_deserialize.register(torch.Tensor) def deserialize_torch_Tensor(header, frames): - if header.get('grad', False): - i = header['grad']['start'] + if header.get("grad", False): + i = header["grad"]["start"] frames, grad_frames = frames[:i], frames[i:] - grad = dask_deserialize.dispatch(np.ndarray)(header['grad']['header'], - grad_frames) + grad = dask_deserialize.dispatch(np.ndarray)( + header["grad"]["header"], grad_frames + ) else: grad = None x = dask_deserialize.dispatch(np.ndarray)(header, frames) - if header['device'] == 'cpu': + if header["device"] == "cpu": t = torch.from_numpy(x) - if header['requires_grad']: + if header["requires_grad"]: t = t.requires_grad_(True) else: - t = torch.tensor(data=x, - device=header['device'], - requires_grad=header['requires_grad']) + t = torch.tensor( + data=x, device=header["device"], requires_grad=header["requires_grad"] + ) if grad is not None: t.grad = torch.from_numpy(grad) return t @@ -45,14 +45,14 @@ def deserialize_torch_Tensor(header, frames): @dask_serialize.register(torch.nn.Parameter) def serialize_torch_Parameters(p): header, frames = serialize(p.detach()) - header['requires_grad'] = p.requires_grad + header["requires_grad"] = p.requires_grad return header, frames @dask_deserialize.register(torch.nn.Parameter) def deserialize_torch_Parameters(header, frames): t = dask_deserialize.dispatch(torch.Tensor)(header, frames) - return torch.nn.Parameter(data=t, requires_grad=header['requires_grad']) + return torch.nn.Parameter(data=t, requires_grad=header["requires_grad"]) register_generic(torch.nn.Module) diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 2d4258e9383..90d30342951 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -4,7 +4,7 @@ from ..utils import ensure_bytes, nbytes -BIG_BYTES_SHARD_SIZE = 2**26 +BIG_BYTES_SHARD_SIZE = 2 ** 26 def frame_split_size(frames, n=BIG_BYTES_SHARD_SIZE): @@ -34,7 +34,7 @@ def frame_split_size(frames, n=BIG_BYTES_SHARD_SIZE): except AttributeError: itemsize = 1 for i in range(0, nbytes(frame) // itemsize, n // itemsize): - out.append(frame[i: i + n // itemsize]) + out.append(frame[i : i + n // itemsize]) else: out.append(frame) return out @@ -50,7 +50,7 @@ def merge_frames(header, frames): >>> merge_frames({'lengths': [6]}, [b'123', b'456']) [b'123456'] """ - lengths = list(header['lengths']) + lengths = list(header["lengths"]) if not frames: return frames @@ -77,15 +77,16 @@ def merge_frames(header, frames): L.append(mv[:l]) frames.append(mv[l:]) l = 0 - out.append(b''.join(map(ensure_bytes, L))) + out.append(b"".join(map(ensure_bytes, L))) return out def pack_frames_prelude(frames): lengths = [len(f) for f in frames] - lengths = ([struct.pack('Q', len(frames))] + - [struct.pack('Q', nbytes(frame)) for frame in frames]) - return b''.join(lengths) + lengths = [struct.pack("Q", len(frames))] + [ + struct.pack("Q", nbytes(frame)) for frame in frames + ] + return b"".join(lengths) def pack_frames(frames): @@ -102,7 +103,7 @@ def pack_frames(frames): if not isinstance(frames, list): frames = list(frames) - return b''.join(prelude + frames) + return b"".join(prelude + frames) def unpack_frames(b): @@ -115,13 +116,13 @@ def unpack_frames(b): -------- pack_frames """ - (n_frames,) = struct.unpack('Q', b[:8]) + (n_frames,) = struct.unpack("Q", b[:8]) frames = [] start = 8 + n_frames * 8 for i in range(n_frames): - (length,) = struct.unpack('Q', b[(i + 1) * 8: (i + 2) * 8]) - frame = 
b[start: start + length] + (length,) = struct.unpack("Q", b[(i + 1) * 8 : (i + 2) * 8]) + frame = b[start : start + length] frames.append(frame) start += length diff --git a/distributed/publish.py b/distributed/publish.py index cf38b7d9490..a21f5ef37ed 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -15,26 +15,30 @@ def __init__(self, scheduler): self.scheduler = scheduler self.datasets = dict() - handlers = {'publish_list': self.list, - 'publish_put': self.put, - 'publish_get': self.get, - 'publish_delete': self.delete} + handlers = { + "publish_list": self.list, + "publish_put": self.put, + "publish_get": self.get, + "publish_delete": self.delete, + } self.scheduler.handlers.update(handlers) - self.scheduler.extensions['publish'] = self + self.scheduler.extensions["publish"] = self def put(self, stream=None, keys=None, data=None, name=None, client=None): with log_errors(): if name in self.datasets: raise KeyError("Dataset %s already exists" % name) - self.scheduler.client_desires_keys(keys, 'published-%s' % tokey(name)) - self.datasets[name] = {'data': data, 'keys': keys} - return {'status': 'OK', 'name': name} + self.scheduler.client_desires_keys(keys, "published-%s" % tokey(name)) + self.datasets[name] = {"data": data, "keys": keys} + return {"status": "OK", "name": name} def delete(self, stream=None, name=None): with log_errors(): - out = self.datasets.pop(name, {'keys': []}) - self.scheduler.client_releases_keys(out['keys'], 'published-%s' % tokey(name)) + out = self.datasets.pop(name, {"keys": []}) + self.scheduler.client_releases_keys( + out["keys"], "published-%s" % tokey(name) + ) def list(self, *args): with log_errors(): @@ -53,6 +57,7 @@ class Datasets(MutableMapping): client : distributed.client.Client """ + def __init__(self, client): self.__client = client diff --git a/distributed/pubsub.py b/distributed/pubsub.py index cdbe9e95a7a..5e086492923 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -16,46 +16,51 @@ class PubSubSchedulerExtension(object): """ Extend Dask's scheduler with routes to handle PubSub machinery """ + def __init__(self, scheduler): self.scheduler = scheduler self.publishers = defaultdict(set) self.subscribers = defaultdict(set) self.client_subscribers = defaultdict(set) - self.scheduler.handlers.update({ - 'pubsub_add_publisher': self.add_publisher, - }) + self.scheduler.handlers.update({"pubsub_add_publisher": self.add_publisher}) - self.scheduler.stream_handlers.update({ - 'pubsub-add-subscriber': self.add_subscriber, - 'pubsub-remove-publisher': self.remove_publisher, - 'pubsub-remove-subscriber': self.remove_subscriber, - 'pubsub-msg': self.handle_message, - }) + self.scheduler.stream_handlers.update( + { + "pubsub-add-subscriber": self.add_subscriber, + "pubsub-remove-publisher": self.remove_publisher, + "pubsub-remove-subscriber": self.remove_subscriber, + "pubsub-msg": self.handle_message, + } + ) - self.scheduler.extensions['pubsub'] = self + self.scheduler.extensions["pubsub"] = self def add_publisher(self, comm=None, name=None, worker=None): logger.debug("Add publisher: %s %s", name, worker) self.publishers[name].add(worker) - return {'subscribers': {addr: {} for addr in self.subscribers[name]}, - 'publish-scheduler': name in self.client_subscribers and - len(self.client_subscribers[name]) > 0} + return { + "subscribers": {addr: {} for addr in self.subscribers[name]}, + "publish-scheduler": name in self.client_subscribers + and len(self.client_subscribers[name]) > 0, + } def add_subscriber(self, comm=None, 
name=None, worker=None, client=None): if worker: logger.debug("Add worker subscriber: %s %s", name, worker) self.subscribers[name].add(worker) for pub in self.publishers[name]: - self.scheduler.worker_send(pub, {'op': 'pubsub-add-subscriber', - 'address': worker, - 'name': name}) + self.scheduler.worker_send( + pub, + {"op": "pubsub-add-subscriber", "address": worker, "name": name}, + ) elif client: logger.debug("Add client subscriber: %s %s", name, client) for pub in self.publishers[name]: - self.scheduler.worker_send(pub, {'op': 'pubsub-publish-scheduler', - 'name': name, - 'publish': True}) + self.scheduler.worker_send( + pub, + {"op": "pubsub-publish-scheduler", "name": name, "publish": True}, + ) self.client_subscribers[name].add(client) def remove_publisher(self, comm=None, name=None, worker=None): @@ -72,18 +77,24 @@ def remove_subscriber(self, comm=None, name=None, worker=None, client=None): logger.debug("Add worker subscriber: %s %s", name, worker) self.subscribers[name].remove(worker) for pub in self.publishers[name]: - self.scheduler.worker_send(pub, {'op': 'pubsub-remove-subscriber', - 'address': worker, - 'name': name}) + self.scheduler.worker_send( + pub, + {"op": "pubsub-remove-subscriber", "address": worker, "name": name}, + ) elif client: logger.debug("Add client subscriber: %s %s", name, client) self.client_subscribers[name].remove(client) if not self.client_subscribers[name]: del self.client_subscribers[name] for pub in self.publishers[name]: - self.scheduler.worker_send(pub, {'op': 'pubsub-publish-scheduler', - 'name': name, - 'publish': False}) + self.scheduler.worker_send( + pub, + { + "op": "pubsub-publish-scheduler", + "name": name, + "publish": False, + }, + ) if not self.subscribers[name] and not self.publishers[name]: logger.debug("Remove PubSub topic %s", name) @@ -93,35 +104,38 @@ def remove_subscriber(self, comm=None, name=None, worker=None, client=None): def handle_message(self, name=None, msg=None, worker=None, client=None): for c in list(self.client_subscribers[name]): try: - self.scheduler.client_comms[c].send({'op': 'pubsub-msg', - 'name': name, - 'msg': msg}) + self.scheduler.client_comms[c].send( + {"op": "pubsub-msg", "name": name, "msg": msg} + ) except (KeyError, CommClosedError): self.remove_subscriber(name=name, client=c) if client: for sub in self.subscribers[name]: - self.scheduler.worker_send(sub, {'op': 'pubsub-msg', - 'name': name, - 'msg': msg}) + self.scheduler.worker_send( + sub, {"op": "pubsub-msg", "name": name, "msg": msg} + ) class PubSubWorkerExtension(object): """ Extend Dask's Worker with routes to handle PubSub machinery """ + def __init__(self, worker): self.worker = worker - self.worker.stream_handlers.update({ - 'pubsub-add-subscriber': self.add_subscriber, - 'pubsub-remove-subscriber': self.remove_subscriber, - 'pubsub-msg': self.handle_message, - 'pubsub-publish-scheduler': self.publish_scheduler, - }) + self.worker.stream_handlers.update( + { + "pubsub-add-subscriber": self.add_subscriber, + "pubsub-remove-subscriber": self.remove_subscriber, + "pubsub-msg": self.handle_message, + "pubsub-publish-scheduler": self.publish_scheduler, + } + ) self.subscribers = defaultdict(weakref.WeakSet) self.publishers = defaultdict(weakref.WeakSet) self.publish_to_scheduler = defaultdict(lambda: False) - self.worker.extensions['pubsub'] = self # circular reference + self.worker.extensions["pubsub"] = self # circular reference def add_subscriber(self, name=None, address=None, **info): for pub in self.publishers[name]: @@ -144,15 +158,13 @@ def 
trigger_cleanup(self): def cleanup(self): for name, s in dict(self.subscribers).items(): if not len(s): - msg = {'op': 'pubsub-remove-subscriber', - 'name': name} + msg = {"op": "pubsub-remove-subscriber", "name": name} self.worker.batched_stream.send(msg) del self.subscribers[name] for name, p in dict(self.publishers).items(): if not len(p): - msg = {'op': 'pubsub-remove-publisher', - 'name': name} + msg = {"op": "pubsub-remove-publisher", "name": name} self.worker.batched_stream.send(msg) del self.publishers[name] del self.publish_to_scheduler[name] @@ -160,22 +172,22 @@ def cleanup(self): class PubSubClientExtension(object): """ Extend Dask's Client with handlers to handle PubSub machinery """ + def __init__(self, client): self.client = client - self.client._stream_handlers.update({ - 'pubsub-msg': self.handle_message - }) + self.client._stream_handlers.update({"pubsub-msg": self.handle_message}) self.subscribers = defaultdict(weakref.WeakSet) - self.client.extensions['pubsub'] = self # TODO: circular reference + self.client.extensions["pubsub"] = self # TODO: circular reference def handle_message(self, name=None, msg=None): for sub in self.subscribers[name]: sub._put(msg) if not self.subscribers[name]: - self.client.scheduler_comm.send({'op': 'pubsub-remove-subscribers', - 'name': name}) + self.client.scheduler_comm.send( + {"op": "pubsub-remove-subscribers", "name": name} + ) def trigger_cleanup(self): self.client.loop.add_callback(self.cleanup) @@ -183,8 +195,7 @@ def trigger_cleanup(self): def cleanup(self): for name, s in self.subscribers.items(): if not s: - msg = {'op': 'pubsub-remove-subscriber', - 'name': name} + msg = {"op": "pubsub-remove-subscriber", "name": name} self.client.scheduler_comm.send(msg) @@ -265,9 +276,11 @@ class Pub(object): -------- Sub """ + def __init__(self, name, worker=None, client=None): if worker is None and client is None: from distributed import get_worker, get_client + try: worker = get_worker() except Exception: @@ -291,7 +304,7 @@ def __init__(self, name, worker=None, client=None): self.loop.add_callback(self._start) if self.worker: - pubsub = self.worker.extensions['pubsub'] + pubsub = self.worker.extensions["pubsub"] self.loop.add_callback(pubsub.publishers[name].add, self) finalize(self, pubsub.trigger_cleanup) @@ -299,12 +312,11 @@ def __init__(self, name, worker=None, client=None): def _start(self): if self.worker: result = yield self.scheduler.pubsub_add_publisher( - name=self.name, - worker=self.worker.address + name=self.name, worker=self.worker.address ) - pubsub = self.worker.extensions['pubsub'] - self.subscribers.update(result['subscribers']) - pubsub.publish_to_scheduler[self.name] = result['publish-scheduler'] + pubsub = self.worker.extensions["pubsub"] + self.subscribers.update(result["subscribers"]) + pubsub.publish_to_scheduler[self.name] = result["publish-scheduler"] self._started = True @@ -317,13 +329,13 @@ def _put(self, msg): self._buffer.append(msg) return - data = {'op': 'pubsub-msg', 'name': self.name, 'msg': to_serialize(msg)} + data = {"op": "pubsub-msg", "name": self.name, "msg": to_serialize(msg)} if self.worker: for sub in self.subscribers: self.worker.send_to_worker(sub, data) - if self.worker.extensions['pubsub'].publish_to_scheduler[self.name]: + if self.worker.extensions["pubsub"].publish_to_scheduler[self.name]: self.worker.batched_stream.send(data) elif self.client: self.client.scheduler_comm.send(data) @@ -340,9 +352,11 @@ class Sub(object): -------- Pub: for full docstring """ + def __init__(self, name, 
worker=None, client=None): if worker is None and client is None: from distributed.worker import get_worker, get_client + try: worker = get_worker() except Exception: @@ -359,12 +373,12 @@ def __init__(self, name, worker=None, client=None): self.condition = tornado.locks.Condition() if self.worker: - pubsub = self.worker.extensions['pubsub'] + pubsub = self.worker.extensions["pubsub"] elif self.client: - pubsub = self.client.extensions['pubsub'] + pubsub = self.client.extensions["pubsub"] self.loop.add_callback(pubsub.subscribers[name].add, self) - msg = {'op': 'pubsub-add-subscriber', 'name': self.name} + msg = {"op": "pubsub-add-subscriber", "name": self.name} if self.worker: self.loop.add_callback(self.worker.batched_stream.send, msg) elif self.client: diff --git a/distributed/pytest_resourceleaks.py b/distributed/pytest_resourceleaks.py index 29331e1b404..bb62d3916d0 100644 --- a/distributed/pytest_resourceleaks.py +++ b/distributed/pytest_resourceleaks.py @@ -15,66 +15,70 @@ def pytest_addoption(parser): - group = parser.getgroup('resource leaks') + group = parser.getgroup("resource leaks") group.addoption( - '-L', '--leaks', - action='store', - dest='leaks', - help='''\ + "-L", + "--leaks", + action="store", + dest="leaks", + help="""\ List of resources to monitor for leaks before and after each test. Can be 'all' or a comma-separated list of resource names (possible values: {known_checkers}). -'''.format(known_checkers=', '.join(sorted("'%s'" % s for s in all_checkers))) +""".format( + known_checkers=", ".join(sorted("'%s'" % s for s in all_checkers)) + ), ) group.addoption( - '--leaks-timeout', - action='store', - type='float', - dest='leaks_timeout', + "--leaks-timeout", + action="store", + type="float", + dest="leaks_timeout", default=0.5, - help='''\ + help="""\ Wait at most this number of seconds to mark a test leaking (default: %(default)s). -''' +""", ) group.addoption( - '--leaks-fail', - action='store_true', - dest='leaks_mark_failed', + "--leaks-fail", + action="store_true", + dest="leaks_mark_failed", default=False, - help='''Mark leaked tests failed.''' + help="""Mark leaked tests failed.""", ) group.addoption( - '--leak-retries', - action='store', + "--leak-retries", + action="store", type=int, - dest='leak_retries', + dest="leak_retries", default=1, - help='''\ + help="""\ Max number of times to retry a test when it leaks, to ignore warmup-related issues (default: 1). 
-''' +""", ) def pytest_configure(config): - leaks = config.getvalue('leaks') + leaks = config.getvalue("leaks") if leaks: - if leaks == 'all': + if leaks == "all": leaks = sorted(all_checkers) else: - leaks = leaks.split(',') + leaks = leaks.split(",") unknown = sorted(set(leaks) - set(all_checkers)) if unknown: raise ValueError("unknown resources: %r" % (unknown,)) checkers = [all_checkers[leak]() for leak in leaks] - checker = LeakChecker(checkers=checkers, - grace_delay=config.getvalue('leaks_timeout'), - mark_failed=config.getvalue('leaks_mark_failed'), - max_retries=config.getvalue('leak_retries'), - ) - config.pluginmanager.register(checker, 'leaks_checker') + checker = LeakChecker( + checkers=checkers, + grace_delay=config.getvalue("leaks_timeout"), + mark_failed=config.getvalue("leaks_mark_failed"), + max_retries=config.getvalue("leak_retries"), + ) + config.pluginmanager.register(checker, "leaks_checker") all_checkers = {} @@ -91,7 +95,6 @@ def decorate(cls): class ResourceChecker(object): - def on_start_test(self): pass @@ -111,12 +114,12 @@ def format(self, before, after): raise NotImplementedError -@register_checker('fds') +@register_checker("fds") class FDChecker(ResourceChecker): - def measure(self): - if os.name == 'posix': + if os.name == "posix": import psutil + return psutil.Process().num_fds() else: return 0 @@ -128,11 +131,11 @@ def format(self, before, after): return "leaked %d file descriptor(s)" % (after - before) -@register_checker('memory') +@register_checker("memory") class RSSMemoryChecker(ResourceChecker): - def measure(self): import psutil + return psutil.Process().memory_info().rss def has_leak(self, before, after): @@ -142,9 +145,8 @@ def format(self, before, after): return "leaked %d MB of RSS memory" % ((after - before) / 1e6) -@register_checker('threads') +@register_checker("threads") class ActiveThreadsChecker(ResourceChecker): - def measure(self): return set(threading.enumerate()) @@ -154,23 +156,22 @@ def has_leak(self, before, after): def format(self, before, after): leaked = after - before assert leaked - return ("leaked %d Python threads: %s" - % (len(leaked), sorted(leaked, key=str))) - + return "leaked %d Python threads: %s" % (len(leaked), sorted(leaked, key=str)) -class _ChildProcess(collections.namedtuple('_ChildProcess', - ('pid', 'name', 'cmdline'))): +class _ChildProcess( + collections.namedtuple("_ChildProcess", ("pid", "name", "cmdline")) +): @classmethod def from_process(cls, p): return cls(p.pid, p.name(), p.cmdline()) -@register_checker('processes') +@register_checker("processes") class ChildProcessesChecker(ResourceChecker): - def measure(self): import psutil + # We use pid and creation time as keys to disambiguate between # processes (and protect against pid reuse) # Other properties such as cmdline may change for a given process @@ -181,12 +182,18 @@ def measure(self): with c.oneshot(): if c.ppid() == p.pid and os.path.samefile(c.exe(), sys.executable): cmdline = c.cmdline() - if any(a.startswith('from multiprocessing.semaphore_tracker import main') - for a in cmdline): + if any( + a.startswith( + "from multiprocessing.semaphore_tracker import main" + ) + for a in cmdline + ): # Skip multiprocessing semaphore tracker continue - if any(a.startswith('from multiprocessing.forkserver import main') - for a in cmdline): + if any( + a.startswith("from multiprocessing.forkserver import main") + for a in cmdline + ): # Skip forkserver process, the forkserver's children # however will be recorded normally continue @@ -204,15 +211,14 @@ def 
format(self, before, after): formatted = [] for key in sorted(leaked): p = after[key] - formatted.append(' - pid={p.pid}, name={p.name!r}, cmdline={p.cmdline!r}' - .format(p=p)) - return ("leaked %d processes:\n%s" - % (len(leaked), '\n'.join(formatted))) + formatted.append( + " - pid={p.pid}, name={p.name!r}, cmdline={p.cmdline!r}".format(p=p) + ) + return "leaked %d processes:\n%s" % (len(leaked), "\n".join(formatted)) -@register_checker('tracemalloc') +@register_checker("tracemalloc") class TracemallocMemoryChecker(ResourceChecker): - def __init__(self): global tracemalloc import tracemalloc @@ -225,6 +231,7 @@ def on_stop_test(self): def measure(self): import tracemalloc + current, peak = tracemalloc.get_traced_memory() snap = tracemalloc.take_snapshot() return current, snap @@ -235,13 +242,15 @@ def has_leak(self, before, after): def format(self, before, after): bytes_before, snap_before = before bytes_after, snap_after = after - diff = snap_after.compare_to(snap_before, 'traceback') + diff = snap_after.compare_to(snap_before, "traceback") ndiff = 5 min_size_diff = 2e5 lines = [] - lines += ["leaked %.1f MB of traced Python memory" - % ((bytes_after - bytes_before) / 1e6)] + lines += [ + "leaked %.1f MB of traced Python memory" + % ((bytes_after - bytes_before) / 1e6) + ] for stat in diff[:ndiff]: size_diff = stat.size_diff or stat.size if size_diff < min_size_diff: @@ -276,8 +285,7 @@ def cleanup(self): gc.collect() def checks_for_item(self, nodeid): - return [c for c in self.checkers - if c not in self.skip_checkers.get(nodeid, ())] + return [c for c in self.checkers if c not in self.skip_checkers.get(nodeid, ())] def measure(self, nodeid): # Return items in order @@ -293,7 +301,7 @@ def measure_before_test(self, nodeid): def measure_after_test(self, nodeid): outcomes = self.outcomes[nodeid] assert outcomes - if outcomes != {'passed'}: + if outcomes != {"passed"}: # Test failed or skipped return @@ -337,6 +345,7 @@ def run_test_again(): # This invokes our setup/teardown hooks again # Inspired by https://pypi.python.org/pypi/pytest-rerunfailures from _pytest.runner import runtestprotocol + item._initrequest() # Re-init fixtures reports = runtestprotocol(item, nextitem=nextitem, log=False) @@ -350,6 +359,7 @@ def run_test_again(): except Exception as e: print("--- Exception when re-running test ---") import traceback + traceback.print_exc() else: leaks = self.leaks.get(nodeid) @@ -376,15 +386,17 @@ def pytest_runtest_protocol(self, item, nextitem): assert nodeid not in self.counters self.counters[nodeid] = {c: [] for c in self.checkers} - leaking = item.get_marker('leaking') + leaking = item.get_marker("leaking") if leaking is not None: unknown = sorted(set(leaking.args) - set(all_checkers)) if unknown: - raise ValueError("pytest.mark.leaking: unknown resources %r" - % (unknown,)) + raise ValueError( + "pytest.mark.leaking: unknown resources %r" % (unknown,) + ) classes = tuple(all_checkers[a] for a in leaking.args) - self.skip_checkers[nodeid] = {c for c in self.checkers - if isinstance(c, classes)} + self.skip_checkers[nodeid] = { + c for c in self.checkers if isinstance(c, classes) + } yield @@ -410,29 +422,31 @@ def pytest_report_teststatus(self, report): outcomes.add(report.outcome) outcome = yield if not self._retrying: - if report.when == 'teardown': + if report.when == "teardown": leaks = self.leaks.get(report.nodeid) if leaks: if self.mark_failed: - outcome.force_result(('failed', 'L', 'LEAKED')) - report.outcome = 'failed' + outcome.force_result(("failed", "L", "LEAKED")) + 
report.outcome = "failed" report.longrepr = "\n".join( - ["%s %s" % (nodeid, checker.format(before, after)) - for checker, before, after in leaks]) + [ + "%s %s" % (nodeid, checker.format(before, after)) + for checker, before, after in leaks + ] + ) else: - outcome.force_result(('leaked', 'L', 'LEAKED')) + outcome.force_result(("leaked", "L", "LEAKED")) # XXX should we log retried tests @pytest.hookimpl def pytest_terminal_summary(self, terminalreporter, exitstatus): tr = terminalreporter - leaked = tr.getreports('leaked') + leaked = tr.getreports("leaked") if leaked: # If mark_failed is False, leaks are output as a separate # results section - tr.write_sep("=", 'RESOURCE LEAKS') + tr.write_sep("=", "RESOURCE LEAKS") for rep in leaked: nodeid = rep.nodeid for checker, before, after in self.leaks[nodeid]: - tr.line("%s %s" - % (rep.nodeid, checker.format(before, after))) + tr.line("%s %s" % (rep.nodeid, checker.format(before, after))) diff --git a/distributed/queues.py b/distributed/queues.py index fda6daaae5c..72f0f9fe52c 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -33,19 +33,20 @@ def __init__(self, scheduler): self.client_refcount = dict() self.future_refcount = defaultdict(lambda: 0) - self.scheduler.handlers.update({ - 'queue_create': self.create, - 'queue_put': self.put, - 'queue_get': self.get, - 'queue_qsize': self.qsize} + self.scheduler.handlers.update( + { + "queue_create": self.create, + "queue_put": self.put, + "queue_get": self.get, + "queue_qsize": self.qsize, + } ) - self.scheduler.stream_handlers.update({ - 'queue-future-release': self.future_release, - 'queue_release': self.release, - }) + self.scheduler.stream_handlers.update( + {"queue-future-release": self.future_release, "queue_release": self.release} + ) - self.scheduler.extensions['queues'] = self + self.scheduler.extensions["queues"] = self def create(self, stream=None, name=None, client=None, maxsize=0): if name not in self.queues: @@ -64,18 +65,20 @@ def release(self, stream=None, name=None, client=None): futures = self.queues[name]._queue del self.queues[name] self.scheduler.client_releases_keys( - keys=[d['value'] for d in futures if d['type'] == 'Future'], - client='queue-%s' % name + keys=[d["value"] for d in futures if d["type"] == "Future"], + client="queue-%s" % name, ) @gen.coroutine - def put(self, stream=None, name=None, key=None, data=None, client=None, timeout=None): + def put( + self, stream=None, name=None, key=None, data=None, client=None, timeout=None + ): if key is not None: - record = {'type': 'Future', 'value': key} + record = {"type": "Future", "value": key} self.future_refcount[name, key] += 1 - self.scheduler.client_desires_keys(keys=[key], client='queue-%s' % name) + self.scheduler.client_desires_keys(keys=[key], client="queue-%s" % name) else: - record = {'type': 'msgpack', 'value': data} + record = {"type": "msgpack", "value": data} if timeout is not None: timeout = datetime.timedelta(seconds=(timeout)) yield self.queues[name].put(record, timeout=timeout) @@ -83,25 +86,23 @@ def put(self, stream=None, name=None, key=None, data=None, client=None, timeout= def future_release(self, name=None, key=None, client=None): self.future_refcount[name, key] -= 1 if self.future_refcount[name, key] == 0: - self.scheduler.client_releases_keys(keys=[key], - client='queue-%s' % name) + self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name) del self.future_refcount[name, key] @gen.coroutine - def get(self, stream=None, name=None, client=None, timeout=None, - 
batch=False): + def get(self, stream=None, name=None, client=None, timeout=None, batch=False): def process(record): """ Add task status if known """ - if record['type'] == 'Future': + if record["type"] == "Future": record = record.copy() - key = record['value'] + key = record["value"] ts = self.scheduler.tasks.get(key) - state = ts.state if ts is not None else 'lost' + state = ts.state if ts is not None else "lost" - record['state'] = state - if state == 'erred': - record['exception'] = ts.exception_blame.exception - record['traceback'] = ts.exception_blame.traceback + record["state"] = state + if state == "erred": + record["exception"] = ts.exception_blame.exception + record["traceback"] = ts.exception_blame.traceback return record @@ -114,8 +115,10 @@ def process(record): out.append(record) else: if timeout is not None: - msg = ("Dask queues don't support simultaneous use of " - "integer batch sizes and timeouts") + msg = ( + "Dask queues don't support simultaneous use of " + "integer batch sizes and timeouts" + ) raise NotImplementedError(msg) for i in range(batch): record = yield q.get() @@ -164,13 +167,20 @@ class Queue(object): def __init__(self, name=None, client=None, maxsize=0): self.client = client or _get_global_client() - self.name = name or 'queue-' + uuid.uuid4().hex - if self.client.asynchronous or getattr(thread_state, 'on_event_loop_thread', False): - self._started = self.client.scheduler.queue_create(name=self.name, - maxsize=maxsize) + self.name = name or "queue-" + uuid.uuid4().hex + if self.client.asynchronous or getattr( + thread_state, "on_event_loop_thread", False + ): + self._started = self.client.scheduler.queue_create( + name=self.name, maxsize=maxsize + ) else: - sync(self.client.loop, self.client.scheduler.queue_create, - name=self.name, maxsize=maxsize) + sync( + self.client.loop, + self.client.scheduler.queue_create, + name=self.name, + maxsize=maxsize, + ) self._started = gen.moment def __await__(self): @@ -178,18 +188,19 @@ def __await__(self): def _(): yield self._started raise gen.Return(self) + return _().__await__() @gen.coroutine def _put(self, value, timeout=None): if isinstance(value, Future): - yield self.client.scheduler.queue_put(key=tokey(value.key), - timeout=timeout, - name=self.name) + yield self.client.scheduler.queue_put( + key=tokey(value.key), timeout=timeout, name=self.name + ) else: - yield self.client.scheduler.queue_put(data=value, - timeout=timeout, - name=self.name) + yield self.client.scheduler.queue_put( + data=value, timeout=timeout, name=self.name + ) def put(self, value, timeout=None, **kwargs): """ Put data into the queue """ @@ -207,8 +218,7 @@ def get(self, timeout=None, batch=False, **kwargs): If an integer than return that many elements from the queue If False (default) then return one item at a time """ - return self.client.sync(self._get, timeout=timeout, batch=batch, - **kwargs) + return self.client.sync(self._get, timeout=timeout, batch=batch, **kwargs) def qsize(self, **kwargs): """ Current number of elements in the queue """ @@ -216,21 +226,20 @@ def qsize(self, **kwargs): @gen.coroutine def _get(self, timeout=None, batch=False): - resp = yield self.client.scheduler.queue_get(timeout=timeout, - name=self.name, - batch=batch) + resp = yield self.client.scheduler.queue_get( + timeout=timeout, name=self.name, batch=batch + ) def process(d): - if d['type'] == 'Future': - value = Future(d['value'], self.client, inform=True, - state=d['state']) - if d['state'] == 'erred': - value._state.set_error(d['exception'], 
d['traceback']) - self.client._send_to_scheduler({'op': 'queue-future-release', - 'name': self.name, - 'key': d['value']}) + if d["type"] == "Future": + value = Future(d["value"], self.client, inform=True, state=d["state"]) + if d["state"] == "erred": + value._state.set_error(d["exception"], d["traceback"]) + self.client._send_to_scheduler( + {"op": "queue-future-release", "name": self.name, "key": d["value"]} + ) else: - value = d['value'] + value = d["value"] return value @@ -247,9 +256,8 @@ def _qsize(self): raise gen.Return(result) def close(self): - if self.client.status == 'running': # TODO: can leave zombie futures - self.client._send_to_scheduler({'op': 'queue_release', - 'name': self.name}) + if self.client.status == "running": # TODO: can leave zombie futures + self.client._send_to_scheduler({"op": "queue_release", "name": self.name}) def __getstate__(self): return (self.name, self.client.scheduler.address) diff --git a/distributed/recreate_exceptions.py b/distributed/recreate_exceptions.py index cd252af8ac6..d5351bb4d59 100644 --- a/distributed/recreate_exceptions.py +++ b/distributed/recreate_exceptions.py @@ -20,8 +20,8 @@ class ReplayExceptionScheduler(object): def __init__(self, scheduler): self.scheduler = scheduler - self.scheduler.handlers['cause_of_failure'] = self.cause_of_failure - self.scheduler.extensions['exceptions'] = self + self.scheduler.handlers["cause_of_failure"] = self.cause_of_failure + self.scheduler.extensions["exceptions"] = self def cause_of_failure(self, *args, **kwargs): """ @@ -39,7 +39,7 @@ def cause_of_failure(self, *args, **kwargs): deps: keys that the task depends on """ - keys = kwargs.pop('keys', []) + keys = kwargs.pop("keys", []) for key in keys: if isinstance(key, list): key = tuple(key) # ensure not a list from msgpack @@ -48,9 +48,11 @@ def cause_of_failure(self, *args, **kwargs): if ts is not None and ts.exception_blame is not None: cause = ts.exception_blame # NOTE: cannot serialize sets - return {'deps': [dts.key for dts in cause.dependencies], - 'cause': cause.key, - 'task': cause.run_spec} + return { + "deps": [dts.key for dts in cause.dependencies], + "cause": cause.key, + "task": cause.run_spec, + } class ReplayExceptionClient(object): @@ -66,7 +68,7 @@ class ReplayExceptionClient(object): def __init__(self, client): self.client = client - self.client.extensions['exceptions'] = self + self.client.extensions["exceptions"] = self # monkey patch self.client.recreate_error_locally = self.recreate_error_locally self.client._recreate_error_locally = self._recreate_error_locally @@ -80,12 +82,11 @@ def scheduler(self): @gen.coroutine def _get_futures_error(self, future): # only get errors for futures that errored. - futures = [f for f in futures_of(future) if f.status == 'error'] + futures = [f for f in futures_of(future) if f.status == "error"] if not futures: raise ValueError("No errored futures passed") - out = yield self.scheduler.cause_of_failure( - keys=[f.key for f in futures]) - deps, task = out['deps'], out['task'] + out = yield self.scheduler.cause_of_failure(keys=[f.key for f in futures]) + deps, task = out["deps"], out["task"] if isinstance(task, dict): function, args, kwargs = _deserialize(**task) raise gen.Return((function, args, kwargs, deps)) @@ -177,6 +178,7 @@ def recreate_error_locally(self, future): Nothing; the function runs and should raise an exception, allowing the debugger to run. 
""" - func, args, kwargs = sync(self.client.loop, - self._recreate_error_locally, future) + func, args, kwargs = sync( + self.client.loop, self._recreate_error_locally, future + ) func(*args, **kwargs) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cf26d0c955d..5e5e2843c2c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -16,6 +16,7 @@ import psutil import sortedcontainers + try: from cytoolz import frequencies, merge, pluck, merge_sorted, first except ImportError: @@ -28,20 +29,34 @@ import dask from .batched import BatchedSend -from .comm import (normalize_address, resolve_address, - get_address_host, unparse_host_port) +from .comm import ( + normalize_address, + resolve_address, + get_address_host, + unparse_host_port, +) from .compatibility import finalize, unicode, Mapping, Set -from .core import (rpc, connect, send_recv, - clean_exception, CommClosedError) +from .core import rpc, connect, send_recv, clean_exception, CommClosedError from . import profile from .metrics import time from .node import ServerNode from .proctitle import setproctitle from .security import Security -from .utils import (All, ignoring, get_ip, get_fileno_limit, log_errors, - key_split, validate_key, no_default, DequeHandler, - parse_timedelta, PeriodicCallback, shutting_down) -from .utils_comm import (scatter_to_workers, gather_from_workers) +from .utils import ( + All, + ignoring, + get_ip, + get_fileno_limit, + log_errors, + key_split, + validate_key, + no_default, + DequeHandler, + parse_timedelta, + PeriodicCallback, + shutting_down, +) +from .utils_comm import scatter_to_workers, gather_from_workers from .utils_perf import enable_gc_diagnosis, disable_gc_diagnosis from .publish import PublishExtension @@ -56,11 +71,11 @@ logger = logging.getLogger(__name__) -BANDWIDTH = dask.config.get('distributed.scheduler.bandwidth') -ALLOWED_FAILURES = dask.config.get('distributed.scheduler.allowed-failures') +BANDWIDTH = dask.config.get("distributed.scheduler.bandwidth") +ALLOWED_FAILURES = dask.config.get("distributed.scheduler.allowed-failures") -LOG_PDB = dask.config.get('distributed.admin.pdb-on-err') -DEFAULT_DATA_SIZE = dask.config.get('distributed.scheduler.default-data-size') +LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") +DEFAULT_DATA_SIZE = dask.config.get("distributed.scheduler.default-data-size") DEFAULT_EXTENSIONS = [ LockExtension, @@ -71,10 +86,10 @@ PubSubSchedulerExtension, ] -if dask.config.get('distributed.scheduler.work-stealing'): +if dask.config.get("distributed.scheduler.work-stealing"): DEFAULT_EXTENSIONS.append(WorkStealing) -ALL_TASK_STATES = {'released', 'waiting', 'no-worker', 'processing', 'erred', 'memory'} +ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"} class ClientState(object): @@ -97,11 +112,8 @@ class ClientState(object): collection) gets garbage-collected. """ - __slots__ = ( - 'client_key', - 'wants_what', - 'last_seen', - ) + + __slots__ = ("client_key", "wants_what", "last_seen") def __init__(self, client): self.client_key = client @@ -189,31 +201,40 @@ class WorkerState(object): actors to which this worker has a reference. """ + # XXX need a state field to signal active/removed? 
__slots__ = ( - 'actors', - 'address', - 'has_what', - 'last_seen', - 'local_directory', - 'memory_limit', - 'metrics', - 'name', - 'nbytes', - 'ncores', - 'occupancy', - 'pid', - 'processing', - 'resources', - 'services', - 'status', - 'time_delay', - 'used_resources', + "actors", + "address", + "has_what", + "last_seen", + "local_directory", + "memory_limit", + "metrics", + "name", + "nbytes", + "ncores", + "occupancy", + "pid", + "processing", + "resources", + "services", + "status", + "time_delay", + "used_resources", ) - def __init__(self, address=None, pid=0, name=None, ncores=0, memory_limit=0, - local_directory=None, services=None): + def __init__( + self, + address=None, + pid=0, + name=None, + ncores=0, + memory_limit=0, + local_directory=None, + services=None, + ): self.address = address self.pid = pid self.name = name @@ -222,7 +243,7 @@ def __init__(self, address=None, pid=0, name=None, ncores=0, memory_limit=0, self.local_directory = local_directory self.services = services or {} - self.status = 'running' + self.status = "running" self.nbytes = 0 self.occupancy = 0 self.metrics = {} @@ -240,25 +261,28 @@ def host(self): return get_address_host(self.address) def __repr__(self): - return "" % (self.address, - len(self.has_what), len(self.processing)) + return "" % ( + self.address, + len(self.has_what), + len(self.processing), + ) def __str__(self): return self.address def identity(self): return { - 'type': 'Worker', - 'id': self.name, - 'host': self.host, - 'resources': self.resources, - 'local_directory': self.local_directory, - 'name': self.name, - 'ncores': self.ncores, - 'memory_limit': self.memory_limit, - 'last_seen': self.last_seen, - 'services': self.services, - 'metrics': self.metrics + "type": "Worker", + "id": self.name, + "host": self.host, + "resources": self.resources, + "local_directory": self.local_directory, + "name": self.name, + "ncores": self.ncores, + "memory_limit": self.memory_limit, + "last_seen": self.last_seen, + "services": self.services, + "metrics": self.metrics, } @@ -486,47 +510,48 @@ class TaskState(object): Whether or not this task is an Actor. 
""" + __slots__ = ( # === General description === - 'actor', + "actor", # Key name - 'key', + "key", # Key prefix (see key_split()) - 'prefix', + "prefix", # How to run the task (None if pure data) - 'run_spec', + "run_spec", # Alive dependents and dependencies - 'dependencies', - 'dependents', + "dependencies", + "dependents", # Compute priority - 'priority', + "priority", # Restrictions - 'host_restrictions', - 'worker_restrictions', # not WorkerStates but addresses - 'resource_restrictions', - 'loose_restrictions', + "host_restrictions", + "worker_restrictions", # not WorkerStates but addresses + "resource_restrictions", + "loose_restrictions", # === Task state === - 'state', + "state", # Whether some dependencies were forgotten - 'has_lost_dependencies', + "has_lost_dependencies", # If in 'waiting' state, which tasks need to complete # before we can run - 'waiting_on', + "waiting_on", # If in 'waiting' or 'processing' state, which tasks needs us # to complete before they can run - 'waiters', + "waiters", # In in 'processing' state, which worker we are processing on - 'processing_on', + "processing_on", # If in 'memory' state, Which workers have us - 'who_has', + "who_has", # Which clients want us - 'who_wants', - 'exception', - 'traceback', - 'exception_blame', - 'suspicious', - 'retries', - 'nbytes', + "who_wants", + "exception", + "traceback", + "exception_blame", + "suspicious", + "retries", + "nbytes", ) def __init__(self, key, run_spec): @@ -581,6 +606,7 @@ def validate(self): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() @@ -611,6 +637,7 @@ class _OptionalStateLegacyMapping(_StateLegacyMapping): Similar to _StateLegacyMapping, but a false-y value is interpreted as a missing key. """ + # For tasks etc. def __iter__(self): @@ -636,6 +663,7 @@ class _StateLegacySet(Set): Similar to _StateLegacyMapping, but exposes a set containing all values with a true value. 
""" + # For loose_restrictions def __init__(self, states, accessor): @@ -759,21 +787,23 @@ class Scheduler(ServerNode): * **coroutines:** ``[Futures]``: A list of active futures that control operation """ + default_port = 8786 def __init__( - self, - loop=None, - delete_interval='500ms', - synchronize_worker_interval='60s', - services=None, - allowed_failures=ALLOWED_FAILURES, - extensions=None, - validate=False, - scheduler_file=None, - security=None, - worker_ttl=None, - **kwargs): + self, + loop=None, + delete_interval="500ms", + synchronize_worker_interval="60s", + services=None, + allowed_failures=ALLOWED_FAILURES, + extensions=None, + validate=False, + scheduler_file=None, + security=None, + worker_ttl=None, + **kwargs + ): self._setup_logging() @@ -782,19 +812,21 @@ def __init__( self.validate = validate self.status = None self.proc = psutil.Process() - self.delete_interval = parse_timedelta(delete_interval, default='ms') - self.synchronize_worker_interval = parse_timedelta(synchronize_worker_interval, default='ms') + self.delete_interval = parse_timedelta(delete_interval, default="ms") + self.synchronize_worker_interval = parse_timedelta( + synchronize_worker_interval, default="ms" + ) self.digests = None self.service_specs = services or {} self.services = {} self.scheduler_file = scheduler_file - worker_ttl = worker_ttl or dask.config.get('distributed.scheduler.worker-ttl') + worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None self.security = security or Security() assert isinstance(self.security, Security) - self.connection_args = self.security.get_connection_args('scheduler') - self.listen_args = self.security.get_listen_args('scheduler') + self.connection_args = self.security.get_connection_args("scheduler") + self.listen_args = self.security.get_listen_args("scheduler") # Communication state self.loop = loop or IOLoop.current() @@ -807,42 +839,43 @@ def __init__( # Task state self.tasks = dict() for old_attr, new_attr, wrap in [ - ('priority', 'priority', None), - ('dependencies', 'dependencies', _legacy_task_key_set), - ('dependents', 'dependents', _legacy_task_key_set), - ('retries', 'retries', None)]: + ("priority", "priority", None), + ("dependencies", "dependencies", _legacy_task_key_set), + ("dependents", "dependents", _legacy_task_key_set), + ("retries", "retries", None), + ]: func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, - _StateLegacyMapping(self.tasks, func)) + setattr(self, old_attr, _StateLegacyMapping(self.tasks, func)) for old_attr, new_attr, wrap in [ - ('nbytes', 'nbytes', None), - ('who_wants', 'who_wants', _legacy_client_key_set), - ('who_has', 'who_has', _legacy_worker_key_set), - ('waiting', 'waiting_on', _legacy_task_key_set), - ('waiting_data', 'waiters', _legacy_task_key_set), - ('rprocessing', 'processing_on', None), - ('host_restrictions', 'host_restrictions', None), - ('worker_restrictions', 'worker_restrictions', None), - ('resource_restrictions', 'resource_restrictions', None), - ('suspicious_tasks', 'suspicious', None), - ('exceptions', 'exception', None), - ('tracebacks', 'traceback', None), - ('exceptions_blame', 'exception_blame', _task_key_or_none)]: + ("nbytes", "nbytes", None), + ("who_wants", "who_wants", _legacy_client_key_set), + ("who_has", "who_has", _legacy_worker_key_set), + ("waiting", "waiting_on", _legacy_task_key_set), + ("waiting_data", "waiters", _legacy_task_key_set), + 
("rprocessing", "processing_on", None), + ("host_restrictions", "host_restrictions", None), + ("worker_restrictions", "worker_restrictions", None), + ("resource_restrictions", "resource_restrictions", None), + ("suspicious_tasks", "suspicious", None), + ("exceptions", "exception", None), + ("tracebacks", "traceback", None), + ("exceptions_blame", "exception_blame", _task_key_or_none), + ]: func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, - _OptionalStateLegacyMapping(self.tasks, func)) + setattr(self, old_attr, _OptionalStateLegacyMapping(self.tasks, func)) - for old_attr, new_attr, wrap in [('loose_restrictions', 'loose_restrictions', None)]: + for old_attr, new_attr, wrap in [ + ("loose_restrictions", "loose_restrictions", None) + ]: func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, - _StateLegacySet(self.tasks, func)) + setattr(self, old_attr, _StateLegacySet(self.tasks, func)) self.generation = 0 self._last_client = None @@ -859,32 +892,33 @@ def __init__( # Client state self.clients = dict() - for old_attr, new_attr, wrap in [('wants_what', 'wants_what', _legacy_task_key_set)]: + for old_attr, new_attr, wrap in [ + ("wants_what", "wants_what", _legacy_task_key_set) + ]: func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, - _StateLegacyMapping(self.clients, func)) - self.clients['fire-and-forget'] = ClientState('fire-and-forget') + setattr(self, old_attr, _StateLegacyMapping(self.clients, func)) + self.clients["fire-and-forget"] = ClientState("fire-and-forget") # Worker state self.workers = sortedcontainers.SortedDict() for old_attr, new_attr, wrap in [ - ('ncores', 'ncores', None), - ('worker_bytes', 'nbytes', None), - ('worker_resources', 'resources', None), - ('used_resources', 'used_resources', None), - ('occupancy', 'occupancy', None), - ('worker_info', 'metrics', None), - ('processing', 'processing', _legacy_task_key_dict), - ('has_what', 'has_what', _legacy_task_key_set)]: + ("ncores", "ncores", None), + ("worker_bytes", "nbytes", None), + ("worker_resources", "resources", None), + ("used_resources", "used_resources", None), + ("occupancy", "occupancy", None), + ("worker_info", "metrics", None), + ("processing", "processing", _legacy_task_key_dict), + ("has_what", "has_what", _legacy_task_key_set), + ]: func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, - _StateLegacyMapping(self.workers, func)) + setattr(self, old_attr, _StateLegacyMapping(self.workers, func)) - self.idle = sortedcontainers.SortedSet(key=operator.attrgetter('address')) + self.idle = sortedcontainers.SortedSet(key=operator.attrgetter("address")) self.saturated = set() self.total_ncores = 0 @@ -895,91 +929,99 @@ def __init__( self._task_state_collections = [self.unrunnable] - self._worker_collections = [self.workers, self.host_info, - self.resources, self.aliases] + self._worker_collections = [ + self.workers, + self.host_info, + self.resources, + self.aliases, + ] self.extensions = {} self.plugins = [] - self.transition_log = deque(maxlen=dask.config.get('distributed.scheduler.transition-log-length')) - self.log = deque(maxlen=dask.config.get('distributed.scheduler.transition-log-length')) + self.transition_log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) + self.log = deque( + 
maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) self.worker_setups = [] worker_handlers = { - 'task-finished': self.handle_task_finished, - 'task-erred': self.handle_task_erred, - 'release': self.handle_release_data, - 'release-worker-data': self.release_worker_data, - 'add-keys': self.add_keys, - 'missing-data': self.handle_missing_data, - 'long-running': self.handle_long_running, - 'reschedule': self.reschedule + "task-finished": self.handle_task_finished, + "task-erred": self.handle_task_erred, + "release": self.handle_release_data, + "release-worker-data": self.release_worker_data, + "add-keys": self.add_keys, + "missing-data": self.handle_missing_data, + "long-running": self.handle_long_running, + "reschedule": self.reschedule, } client_handlers = { - 'update-graph': self.update_graph, - 'client-desires-keys': self.client_desires_keys, - 'update-data': self.update_data, - 'report-key': self.report_on_key, - 'client-releases-keys': self.client_releases_keys, - 'heartbeat-client': self.client_heartbeat, - 'close-client': self.remove_client, - 'restart': self.restart + "update-graph": self.update_graph, + "client-desires-keys": self.client_desires_keys, + "update-data": self.update_data, + "report-key": self.report_on_key, + "client-releases-keys": self.client_releases_keys, + "heartbeat-client": self.client_heartbeat, + "close-client": self.remove_client, + "restart": self.restart, } self.handlers = { - 'register-client': self.add_client, - 'scatter': self.scatter, - 'register-worker': self.add_worker, - 'unregister': self.remove_worker, - 'gather': self.gather, - 'cancel': self.stimulus_cancel, - 'retry': self.stimulus_retry, - 'feed': self.feed, - 'terminate': self.close, - 'broadcast': self.broadcast, - 'proxy': self.proxy, - 'ncores': self.get_ncores, - 'has_what': self.get_has_what, - 'who_has': self.get_who_has, - 'processing': self.get_processing, - 'call_stack': self.get_call_stack, - 'profile': self.get_profile, - 'logs': self.get_logs, - 'worker_logs': self.get_worker_logs, - 'nbytes': self.get_nbytes, - 'versions': self.versions, - 'add_keys': self.add_keys, - 'rebalance': self.rebalance, - 'replicate': self.replicate, - 'start_ipython': self.start_ipython, - 'run_function': self.run_function, - 'update_data': self.update_data, - 'set_resources': self.add_resources, - 'retire_workers': self.retire_workers, - 'get_metadata': self.get_metadata, - 'set_metadata': self.set_metadata, - 'heartbeat_worker': self.heartbeat_worker, - 'get_task_status': self.get_task_status, - 'get_task_stream': self.get_task_stream, - 'register_worker_callbacks': self.register_worker_callbacks + "register-client": self.add_client, + "scatter": self.scatter, + "register-worker": self.add_worker, + "unregister": self.remove_worker, + "gather": self.gather, + "cancel": self.stimulus_cancel, + "retry": self.stimulus_retry, + "feed": self.feed, + "terminate": self.close, + "broadcast": self.broadcast, + "proxy": self.proxy, + "ncores": self.get_ncores, + "has_what": self.get_has_what, + "who_has": self.get_who_has, + "processing": self.get_processing, + "call_stack": self.get_call_stack, + "profile": self.get_profile, + "logs": self.get_logs, + "worker_logs": self.get_worker_logs, + "nbytes": self.get_nbytes, + "versions": self.versions, + "add_keys": self.add_keys, + "rebalance": self.rebalance, + "replicate": self.replicate, + "start_ipython": self.start_ipython, + "run_function": self.run_function, + "update_data": self.update_data, + "set_resources": self.add_resources, + 
"retire_workers": self.retire_workers, + "get_metadata": self.get_metadata, + "set_metadata": self.set_metadata, + "heartbeat_worker": self.heartbeat_worker, + "get_task_status": self.get_task_status, + "get_task_stream": self.get_task_stream, + "register_worker_callbacks": self.register_worker_callbacks, } self._transitions = { - ('released', 'waiting'): self.transition_released_waiting, - ('waiting', 'released'): self.transition_waiting_released, - ('waiting', 'processing'): self.transition_waiting_processing, - ('waiting', 'memory'): self.transition_waiting_memory, - ('processing', 'released'): self.transition_processing_released, - ('processing', 'memory'): self.transition_processing_memory, - ('processing', 'erred'): self.transition_processing_erred, - ('no-worker', 'released'): self.transition_no_worker_released, - ('no-worker', 'waiting'): self.transition_no_worker_waiting, - ('released', 'forgotten'): self.transition_released_forgotten, - ('memory', 'forgotten'): self.transition_memory_forgotten, - ('erred', 'forgotten'): self.transition_released_forgotten, - ('erred', 'released'): self.transition_erred_released, - ('memory', 'released'): self.transition_memory_released, - ('released', 'erred'): self.transition_released_erred + ("released", "waiting"): self.transition_released_waiting, + ("waiting", "released"): self.transition_waiting_released, + ("waiting", "processing"): self.transition_waiting_processing, + ("waiting", "memory"): self.transition_waiting_memory, + ("processing", "released"): self.transition_processing_released, + ("processing", "memory"): self.transition_processing_memory, + ("processing", "erred"): self.transition_processing_erred, + ("no-worker", "released"): self.transition_no_worker_released, + ("no-worker", "waiting"): self.transition_no_worker_waiting, + ("released", "forgotten"): self.transition_released_forgotten, + ("memory", "forgotten"): self.transition_memory_forgotten, + ("erred", "forgotten"): self.transition_released_forgotten, + ("erred", "released"): self.transition_erred_released, + ("memory", "released"): self.transition_memory_released, + ("released", "erred"): self.transition_released_erred, } connection_limit = get_fileno_limit() / 2 @@ -988,15 +1030,15 @@ def __init__( handlers=self.handlers, stream_handlers=merge(worker_handlers, client_handlers), io_loop=self.loop, - connection_limit=connection_limit, deserialize=False, + connection_limit=connection_limit, + deserialize=False, connection_args=self.connection_args, - **kwargs) + **kwargs + ) if self.worker_ttl: - pc = PeriodicCallback(self.check_worker_ttl, - self.worker_ttl, - io_loop=loop) - self.periodic_callbacks['worker-ttl'] = pc + pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl, io_loop=loop) + self.periodic_callbacks["worker-ttl"] = pc if extensions is None: extensions = DEFAULT_EXTENSIONS @@ -1011,16 +1053,22 @@ def __init__( def __repr__(self): return '' % ( - self.address, len(self.workers), self.total_ncores) + self.address, + len(self.workers), + self.total_ncores, + ) def identity(self, comm=None): """ Basic information about ourselves and our cluster """ - d = {'type': type(self).__name__, - 'id': str(self.id), - 'address': self.address, - 'services': {key: v.port for (key, v) in self.services.items()}, - 'workers': {worker.address: worker.identity() - for worker in self.workers.values()}} + d = { + "type": type(self).__name__, + "id": str(self.id), + "address": self.address, + "services": {key: v.port for (key, v) in self.services.items()}, + "workers": { + 
worker.address: worker.identity() for worker in self.workers.values() + }, + } return d def get_worker_service_addr(self, worker, service_name, protocol=False): @@ -1042,17 +1090,17 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): if port is None: return None elif protocol: - return '%(protocol)s://%(host)s:%(port)d' % { - 'protocol': ws.address.split('://')[0], - 'host': ws.host, - 'port': port + return "%(protocol)s://%(host)s:%(port)d" % { + "protocol": ws.address.split("://")[0], + "host": ws.host, + "port": port, } else: return ws.host, port def start_services(self, default_listen_ip): - if default_listen_ip == '0.0.0.0': - default_listen_ip = '' # for IPV6 + if default_listen_ip == "0.0.0.0": + default_listen_ip = "" # for IPV6 for k, v in self.service_specs.items(): listen_ip = None @@ -1062,7 +1110,7 @@ def start_services(self, default_listen_ip): port = 0 if isinstance(port, (str, unicode)): - port = port.split(':') + port = port.split(":") if isinstance(port, (tuple, list)): listen_ip, port = (port[0], int(port[1])) @@ -1074,12 +1122,17 @@ def start_services(self, default_listen_ip): try: service = v(self, io_loop=self.loop, **kwargs) - service.listen((listen_ip if listen_ip is not None else default_listen_ip, port)) + service.listen( + (listen_ip if listen_ip is not None else default_listen_ip, port) + ) self.services[k] = service except Exception as e: - warnings.warn("\nCould not launch service '%s' on port %s. " % (k, port) + - "Got the following message:\n\n" + str(e), - stacklevel=3) + warnings.warn( + "\nCould not launch service '%s' on port %s. " % (k, port) + + "Got the following message:\n\n" + + str(e), + stacklevel=3, + ) def stop_services(self): for service in self.services.values(): @@ -1101,36 +1154,36 @@ def start(self, addr_or_port=8786, start_queues=True): if exc: raise exc - if self.status != 'running': + if self.status != "running": if isinstance(addr_or_port, int): # Listen on all interfaces. `get_ip()` is not suitable # as it would prevent connecting via 127.0.0.1. 
- self.listen(('', addr_or_port), listen_args=self.listen_args) + self.listen(("", addr_or_port), listen_args=self.listen_args) self.ip = get_ip() - listen_ip = '' + listen_ip = "" else: self.listen(addr_or_port, listen_args=self.listen_args) self.ip = get_address_host(self.listen_address) listen_ip = self.ip - if listen_ip == '0.0.0.0': - listen_ip = '' + if listen_ip == "0.0.0.0": + listen_ip = "" - if isinstance(addr_or_port, str) and addr_or_port.startswith('inproc://'): - listen_ip = 'localhost' + if isinstance(addr_or_port, str) and addr_or_port.startswith("inproc://"): + listen_ip = "localhost" # Services listen on all addresses self.start_services(listen_ip) - self.status = 'running' + self.status = "running" logger.info(" Scheduler at: %25s", self.address) for k, v in self.services.items(): - logger.info("%11s at: %25s", k, '%s:%d' % (listen_ip, v.port)) + logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) self.loop.add_callback(self.reevaluate_occupancy) if self.scheduler_file: - with open(self.scheduler_file, 'w') as f: + with open(self.scheduler_file, "w") as f: json.dump(self.identity(), f, indent=2) fn = self.scheduler_file # remove file when we close the process @@ -1161,9 +1214,9 @@ def close(self, comm=None, fast=False): -------- Scheduler.cleanup """ - if self.status.startswith('clos'): + if self.status.startswith("clos"): return - self.status = 'closing' + self.status = "closing" logger.info("Scheduler closing...") setproctitle("dask-scheduler [closing]") @@ -1181,8 +1234,8 @@ def close(self, comm=None, fast=False): futures = [] for w, comm in list(self.stream_comms.items()): if not comm.closed(): - comm.send({'op': 'close', 'report': False}) - comm.send({'op': 'close-stream'}) + comm.send({"op": "close", "report": False}) + comm.send({"op": "close-stream"}) with ignoring(AttributeError): futures.append(comm.close()) @@ -1197,7 +1250,7 @@ def close(self, comm=None, fast=False): self.rpc.close() - self.status = 'closed' + self.status = "closed" self.stop() yield super(Scheduler, self).close() @@ -1214,16 +1267,20 @@ def close_worker(self, stream=None, worker=None, safe=None): """ logger.info("Closing worker %s", worker) with log_errors(): - self.log_event(worker, {'action': 'close-worker'}) - nanny_addr = self.get_worker_service_addr(worker, 'nanny', protocol=True) + self.log_event(worker, {"action": "close-worker"}) + nanny_addr = self.get_worker_service_addr(worker, "nanny", protocol=True) address = nanny_addr or worker - self.worker_send(worker, {'op': 'close', 'report': False}) + self.worker_send(worker, {"op": "close", "report": False}) self.remove_worker(address=worker, safe=safe) def _setup_logging(self): - self._deque_handler = DequeHandler(n=dask.config.get('distributed.admin.log-length')) - self._deque_handler.setFormatter(logging.Formatter(dask.config.get('distributed.admin.log-format'))) + self._deque_handler = DequeHandler( + n=dask.config.get("distributed.admin.log-length") + ) + self._deque_handler.setFormatter( + logging.Formatter(dask.config.get("distributed.admin.log-format")) + ) logger.addHandler(self._deque_handler) finalize(self, logger.removeHandler, self._deque_handler) @@ -1232,8 +1289,16 @@ def _setup_logging(self): ########### @gen.coroutine - def heartbeat_worker(self, comm=None, address=None, resolve_address=True, - now=None, resources=None, host_info=None, metrics=None): + def heartbeat_worker( + self, + comm=None, + address=None, + resolve_address=True, + now=None, + resources=None, + host_info=None, + metrics=None, + ): address 
= self.coerce_address(address, resolve_address) address = normalize_address(address) host = get_address_host(address) @@ -1243,11 +1308,11 @@ def heartbeat_worker(self, comm=None, address=None, resolve_address=True, metrics = metrics or {} host_info = host_info or {} - self.host_info[host]['last-seen'] = local_now + self.host_info[host]["last-seen"] = local_now ws = self.workers.get(address) if not ws: - return {'status': 'missing'} + return {"status": "missing"} ws.last_seen = time() @@ -1263,17 +1328,33 @@ def heartbeat_worker(self, comm=None, address=None, resolve_address=True, if resources: self.add_resources(worker=address, resources=resources) - self.log_event(address, merge({'action': 'heartbeat'}, metrics)) + self.log_event(address, merge({"action": "heartbeat"}, metrics)) - return {'status': 'OK', - 'time': time(), - 'heartbeat-interval': heartbeat_interval(len(self.workers))} + return { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(self.workers)), + } @gen.coroutine - def add_worker(self, comm=None, address=None, keys=(), ncores=None, - name=None, resolve_address=True, nbytes=None, now=None, - resources=None, host_info=None, memory_limit=None, - metrics=None, pid=0, services=None, local_directory=None): + def add_worker( + self, + comm=None, + address=None, + keys=(), + ncores=None, + name=None, + resolve_address=True, + nbytes=None, + now=None, + resources=None, + host_info=None, + memory_limit=None, + metrics=None, + pid=0, + services=None, + local_directory=None, + ): """ Add a new worker to the cluster """ with log_errors(): address = self.coerce_address(address, resolve_address) @@ -1285,36 +1366,41 @@ def add_worker(self, comm=None, address=None, keys=(), ncores=None, raise ValueError("Worker already exists %s" % address) self.workers[address] = ws = WorkerState( - address=address, - pid=pid, - ncores=ncores, - memory_limit=memory_limit, - name=name, - local_directory=local_directory, - services=services + address=address, + pid=pid, + ncores=ncores, + memory_limit=memory_limit, + name=name, + local_directory=local_directory, + services=services, ) if name in self.aliases: - msg = {'status': 'error', - 'message': 'name taken, %s' % name, - 'time': time()} + msg = { + "status": "error", + "message": "name taken, %s" % name, + "time": time(), + } yield comm.write(msg) return - if 'addresses' not in self.host_info[host]: - self.host_info[host].update({'addresses': set(), 'cores': 0}) + if "addresses" not in self.host_info[host]: + self.host_info[host].update({"addresses": set(), "cores": 0}) - self.host_info[host]['addresses'].add(address) - self.host_info[host]['cores'] += ncores + self.host_info[host]["addresses"].add(address) + self.host_info[host]["cores"] += ncores self.total_ncores += ncores self.aliases[name] = address - response = self.heartbeat_worker(address=address, - resolve_address=resolve_address, - now=now, resources=resources, - host_info=host_info, - metrics=metrics) + response = self.heartbeat_worker( + address=address, + resolve_address=resolve_address, + now=now, + resources=resources, + host_info=host_info, + metrics=metrics, + ) # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot exist before this. 
self.check_idle_saturated(ws) @@ -1322,7 +1408,7 @@ def add_worker(self, comm=None, address=None, keys=(), ncores=None, # for key in keys: # TODO # self.mark_key_in_memory(key, [address]) - self.stream_comms[address] = BatchedSend(interval='5ms', loop=self.loop) + self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) if ws.ncores > len(ws.processing): self.idle.add(ws) @@ -1336,37 +1422,51 @@ def add_worker(self, comm=None, address=None, keys=(), ncores=None, if nbytes: for key in nbytes: ts = self.tasks.get(key) - if ts is not None and ts.state in ('processing', 'waiting'): - recommendations = self.transition(key, 'memory', - worker=address, - nbytes=nbytes[key]) + if ts is not None and ts.state in ("processing", "waiting"): + recommendations = self.transition( + key, "memory", worker=address, nbytes=nbytes[key] + ) self.transitions(recommendations) recommendations = {} for ts in list(self.unrunnable): valid = self.valid_workers(ts) if valid is True or ws in valid: - recommendations[ts.key] = 'waiting' + recommendations[ts.key] = "waiting" if recommendations: self.transitions(recommendations) - self.log_event(address, {'action': 'add-worker'}) - self.log_event('all', {'action': 'add-worker', - 'worker': address}) + self.log_event(address, {"action": "add-worker"}) + self.log_event("all", {"action": "add-worker", "worker": address}) logger.info("Register %s", str(address)) - yield comm.write({'status': 'OK', - 'time': time(), - 'heartbeat-interval': heartbeat_interval(len(self.workers)), - 'worker-setups': self.worker_setups}) + yield comm.write( + { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(self.workers)), + "worker-setups": self.worker_setups, + } + ) yield self.handle_worker(comm=comm, worker=address) - def update_graph(self, client=None, tasks=None, keys=None, - dependencies=None, restrictions=None, priority=None, - loose_restrictions=None, resources=None, - submitting_task=None, retries=None, user_priority=0, - actors=None, fifo_timeout=0): + def update_graph( + self, + client=None, + tasks=None, + keys=None, + dependencies=None, + restrictions=None, + priority=None, + loose_restrictions=None, + resources=None, + submitting_task=None, + retries=None, + user_priority=0, + actors=None, + fifo_timeout=0, + ): """ Add new computations to the internal dask graph @@ -1376,8 +1476,9 @@ def update_graph(self, client=None, tasks=None, keys=None, fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) if len(tasks) > 1: - self.log_event(['all', client], {'action': 'update_graph', - 'count': len(tasks)}) + self.log_event( + ["all", client], {"action": "update_graph", "count": len(tasks)} + ) # Remove aliases for k in list(tasks): @@ -1390,14 +1491,15 @@ def update_graph(self, client=None, tasks=None, keys=None, while len(tasks) != n: # walk through new tasks, cancel any bad deps n = len(tasks) for k, deps in list(dependencies.items()): - if any(dep not in self.tasks and dep not in tasks - for dep in deps): # bad key - logger.info('User asked for computation on lost data, %s', k) + if any( + dep not in self.tasks and dep not in tasks for dep in deps + ): # bad key + logger.info("User asked for computation on lost data, %s", k) del tasks[k] del dependencies[k] if k in keys: keys.remove(k) - self.report({'op': 'cancelled-key', 'key': k}, client=client) + self.report({"op": "cancelled-key", "key": k}, client=client) self.client_releases_keys(keys=[k], client=client) # Remove any self-dependencies (happens on test_publish_bag() and 
others) @@ -1410,7 +1512,7 @@ def update_graph(self, client=None, tasks=None, keys=None, # Avoid computation that is already finished already_in_memory = set() # tasks that are already done for k, v in dependencies.items(): - if v and k in self.tasks and self.tasks[k].state in ('memory', 'erred'): + if v and k in self.tasks and self.tasks[k].state in ("memory", "erred"): already_in_memory.add(k) if already_in_memory: @@ -1450,7 +1552,7 @@ def update_graph(self, client=None, tasks=None, keys=None, ts = self.tasks.get(k) if ts is None: ts = self.tasks[k] = TaskState(k, tasks.get(k)) - ts.state = 'released' + ts.state = "released" elif not ts.run_spec: ts.run_spec = tasks.get(k) @@ -1480,7 +1582,9 @@ def update_graph(self, client=None, tasks=None, keys=None, for actor in actors or []: self.tasks[actor].actor = True - priority = priority or dask.order.order(tasks) # TODO: define order wrt old graph + priority = priority or dask.order.order( + tasks + ) # TODO: define order wrt old graph if submitting_task: # sub-tasks get better priority than parent tasks ts = self.tasks.get(submitting_task) @@ -1501,8 +1605,7 @@ def update_graph(self, client=None, tasks=None, keys=None, ts.priority = (-user_priority.get(key, 0), generation, priority[key]) # Ensure all runnables have a priority - runnables = [ts for ts in touched_tasks - if ts.run_spec] + runnables = [ts for ts in touched_tasks if ts.run_spec] for ts in runnables: if ts.priority is None and ts.run_spec: ts.priority = (self.generation, 0) @@ -1553,38 +1656,42 @@ def update_graph(self, client=None, tasks=None, keys=None, # Compute recommendations recommendations = OrderedDict() - for ts in sorted(runnables, key=operator.attrgetter('priority'), - reverse=True): - if ts.state == 'released' and ts.run_spec: - recommendations[ts.key] = 'waiting' + for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): + if ts.state == "released" and ts.run_spec: + recommendations[ts.key] = "waiting" for ts in touched_tasks: for dts in ts.dependencies: if dts.exception_blame: ts.exception_blame = dts.exception_blame - recommendations[ts.key] = 'erred' + recommendations[ts.key] = "erred" break for plugin in self.plugins[:]: try: - plugin.update_graph(self, client=client, tasks=tasks, - keys=keys, restrictions=restrictions or {}, - dependencies=dependencies, - priority=priority, - loose_restrictions=loose_restrictions, - resources=resources) + plugin.update_graph( + self, + client=client, + tasks=tasks, + keys=keys, + restrictions=restrictions or {}, + dependencies=dependencies, + priority=priority, + loose_restrictions=loose_restrictions, + resources=resources, + ) except Exception as e: logger.exception(e) self.transitions(recommendations) for ts in touched_tasks: - if ts.state in ('memory', 'erred'): + if ts.state in ("memory", "erred"): self.report_on_key(ts.key, client=client) end = time() if self.digests is not None: - self.digests['update-graph-duration'].add(end - start) + self.digests["update-graph-duration"].add(end - start) # TODO: balance workers @@ -1597,24 +1704,29 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): return {} ws = self.workers[worker] - if ts.state == 'processing': - recommendations = self.transition(key, 'memory', worker=worker, - **kwargs) + if ts.state == "processing": + recommendations = self.transition(key, "memory", worker=worker, **kwargs) - if ts.state == 'memory': + if ts.state == "memory": assert ws in ts.who_has else: - logger.debug("Received already computed task, worker: %s, state: 
%s" - ", key: %s, who_has: %s", - worker, ts.state, key, ts.who_has) + logger.debug( + "Received already computed task, worker: %s, state: %s" + ", key: %s, who_has: %s", + worker, + ts.state, + key, + ts.who_has, + ) if ws not in ts.who_has: - self.worker_send(worker, {'op': 'release-task', 'key': key}) + self.worker_send(worker, {"op": "release-task", "key": key}) recommendations = {} return recommendations - def stimulus_task_erred(self, key=None, worker=None, - exception=None, traceback=None, **kwargs): + def stimulus_task_erred( + self, key=None, worker=None, exception=None, traceback=None, **kwargs + ): """ Mark that a task has erred on a particular worker """ logger.debug("Stimulus task erred %s, %s", key, worker) @@ -1622,45 +1734,49 @@ def stimulus_task_erred(self, key=None, worker=None, if ts is None: return {} - if ts.state == 'processing': + if ts.state == "processing": retries = ts.retries if retries > 0: ts.retries = retries - 1 - recommendations = self.transition(key, 'waiting') + recommendations = self.transition(key, "waiting") else: - recommendations = self.transition(key, 'erred', - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs) + recommendations = self.transition( + key, + "erred", + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs + ) else: recommendations = {} return recommendations - def stimulus_missing_data(self, cause=None, key=None, worker=None, - ensure=True, **kwargs): + def stimulus_missing_data( + self, cause=None, key=None, worker=None, ensure=True, **kwargs + ): """ Mark that certain keys have gone missing. Recover. """ with log_errors(): logger.debug("Stimulus missing data %s, %s", key, worker) ts = self.tasks.get(key) - if ts is None or ts.state == 'memory': + if ts is None or ts.state == "memory": return {} cts = self.tasks.get(cause) recommendations = OrderedDict() - if cts is not None and cts.state == 'memory': # couldn't find this + if cts is not None and cts.state == "memory": # couldn't find this for ws in cts.who_has: # TODO: this behavior is extreme ws.has_what.remove(cts) ws.nbytes -= cts.get_nbytes() cts.who_has.clear() - recommendations[cause] = 'released' + recommendations[cause] = "released" if key: - recommendations[key] = 'released' + recommendations[key] = "released" self.transitions(recommendations) @@ -1672,7 +1788,7 @@ def stimulus_missing_data(self, cause=None, key=None, worker=None, def stimulus_retry(self, comm=None, keys=None, client=None): logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: - self.log_event(client, {'action': 'retry', 'count': len(keys)}) + self.log_event(client, {"action": "retry", "count": len(keys)}) stack = list(keys) seen = set() @@ -1680,14 +1796,15 @@ def stimulus_retry(self, comm=None, keys=None, client=None): while stack: key = stack.pop() seen.add(key) - erred_deps = [dts.key for dts in self.tasks[key].dependencies - if dts.state == 'erred'] + erred_deps = [ + dts.key for dts in self.tasks[key].dependencies if dts.state == "erred" + ] if erred_deps: stack.extend(erred_deps) else: roots.append(key) - recommendations = {key: 'waiting' for key in roots} + recommendations = {key: "waiting" for key in roots} self.transitions(recommendations) if self.validate: @@ -1705,31 +1822,36 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): state. 
""" with log_errors(): - if self.status == 'closed': + if self.status == "closed": return if address not in self.workers: - return 'already-removed' + return "already-removed" address = self.coerce_address(address) host = get_address_host(address) ws = self.workers[address] - self.log_event(['all', address], {'action': 'remove-worker', - 'worker': address, - 'processing-tasks': dict(ws.processing)}) + self.log_event( + ["all", address], + { + "action": "remove-worker", + "worker": address, + "processing-tasks": dict(ws.processing), + }, + ) logger.info("Remove worker %s", address) if close: with ignoring(AttributeError, CommClosedError): - self.stream_comms[address].send({'op': 'close', 'report': False}) + self.stream_comms[address].send({"op": "close", "report": False}) self.remove_resources(address) - self.host_info[host]['cores'] -= ws.ncores - self.host_info[host]['addresses'].remove(address) + self.host_info[host]["cores"] -= ws.ncores + self.host_info[host]["addresses"].remove(address) self.total_ncores -= ws.ncores - if not self.host_info[host]['addresses']: + if not self.host_info[host]["addresses"]: del self.host_info[host] self.rpc.remove(address) @@ -1738,29 +1860,29 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): self.idle.discard(ws) self.saturated.discard(ws) del self.workers[address] - ws.status = 'closed' + ws.status = "closed" self.total_occupancy -= ws.occupancy recommendations = OrderedDict() for ts in list(ws.processing): k = ts.key - recommendations[k] = 'released' + recommendations[k] = "released" if not safe: ts.suspicious += 1 if ts.suspicious > self.allowed_failures: del recommendations[k] e = pickle.dumps(KilledWorker(k, address)) - r = self.transition(k, 'erred', exception=e, cause=k) + r = self.transition(k, "erred", exception=e, cause=k) recommendations.update(r) for ts in ws.has_what: ts.who_has.remove(ws) if not ts.who_has: if ts.run_spec: - recommendations[ts.key] = 'released' + recommendations[ts.key] = "released" else: # pure data - recommendations[ts.key] = 'forgotten' + recommendations[ts.key] = "forgotten" ws.has_what.clear() self.transitions(recommendations) @@ -1780,21 +1902,21 @@ def remove_worker_from_events(): if address not in self.workers and address in self.events: del self.events[address] - cleanup_delay = parse_timedelta(dask.config.get('distributed.scheduler.events-cleanup-delay')) - self.loop.call_later( - cleanup_delay, - remove_worker_from_events + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") ) + self.loop.call_later(cleanup_delay, remove_worker_from_events) logger.debug("Removed worker %s", address) - return 'OK' + return "OK" def stimulus_cancel(self, comm, keys=None, client=None, force=False): """ Stop execution on a list of keys """ logger.info("Client %s requests to cancel %d keys", client, len(keys)) if client: - self.log_event(client, {'action': 'cancel', 'count': len(keys), - 'force': force}) + self.log_event( + client, {"action": "cancel", "count": len(keys), "force": force} + ) for key in keys: self.cancel_key(key, client, force=force) @@ -1805,14 +1927,15 @@ def cancel_key(self, key, client, retries=5, force=False): cs = self.clients[client] if ts is None or not ts.who_wants: # no key yet, lets try again in a moment if retries: - self.loop.add_future(gen.sleep(0.2), - lambda _: self.cancel_key(key, client, retries - 1)) + self.loop.add_future( + gen.sleep(0.2), lambda _: self.cancel_key(key, client, retries - 1) + ) return if force or ts.who_wants 
== {cs}: # no one else wants this key for dts in list(ts.dependents): self.cancel_key(dts.key, client, force=force) logger.info("Scheduler cancels key %s. Force=%s", key, force) - self.report({'op': 'cancelled-key', 'key': key}) + self.report({"op": "cancelled-key", "key": key}) clients = list(ts.who_wants) if force else [cs] for c in clients: self.client_releases_keys(keys=[key], client=c.client_key) @@ -1827,11 +1950,11 @@ def client_desires_keys(self, keys=None, client=None): if ts is None: # For publish, queues etc. ts = self.tasks[k] = TaskState(k, None) - ts.state = 'released' + ts.state = "released" ts.who_wants.add(cs) cs.wants_what.add(ts) - if ts.state in ('memory', 'erred'): + if ts.state in ("memory", "erred"): self.report_on_key(k, client=client) def client_releases_keys(self, keys=None, client=None): @@ -1852,9 +1975,9 @@ def client_releases_keys(self, keys=None, client=None): for ts in tasks2: if not ts.dependents: # No live dependents, can forget - recommendations[ts.key] = 'forgotten' - elif ts.state != 'erred' and not ts.waiters: - recommendations[ts.key] = 'released' + recommendations[ts.key] = "forgotten" + elif ts.state != "erred" and not ts.waiters: + recommendations[ts.key] = "released" self.transitions(recommendations) @@ -1868,13 +1991,12 @@ def client_heartbeat(self, client=None): def validate_released(self, key): ts = self.tasks[key] - assert ts.state == 'released' + assert ts.state == "released" assert not ts.waiters assert not ts.waiting_on assert not ts.who_has assert not ts.processing_on - assert not any(ts in dts.waiters - for dts in ts.dependencies) + assert not any(ts in dts.waiters for dts in ts.dependencies) assert ts not in self.unrunnable def validate_waiting(self, key): @@ -1906,7 +2028,7 @@ def validate_memory(self, key): assert not ts.waiting_on assert ts not in self.unrunnable for dts in ts.dependents: - assert (dts in ts.waiters) == (dts.state in ('waiting', 'processing')) + assert (dts in ts.waiters) == (dts.state in ("waiting", "processing")) assert ts not in dts.waiting_on def validate_no_worker(self, key): @@ -1933,16 +2055,18 @@ def validate_key(self, key, ts=None): else: ts.validate() try: - func = getattr(self, 'validate_' + ts.state.replace('-', '_')) + func = getattr(self, "validate_" + ts.state.replace("-", "_")) except AttributeError: - logger.error("self.validate_%s not found", - ts.state.replace('-', '_')) + logger.error( + "self.validate_%s not found", ts.state.replace("-", "_") + ) else: func(key) except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1972,8 +2096,10 @@ def validate_state(self, allow_overlap=False): assert cs.client_key == c a = {w: ws.nbytes for w, ws in self.workers.items()} - b = {w: sum(ts.get_nbytes() for ts in ws.has_what) - for w, ws in self.workers.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws.has_what) + for w, ws in self.workers.items() + } assert a == b, (a, b) actual_total_occupancy = 0 @@ -1981,8 +2107,10 @@ def validate_state(self, allow_overlap=False): assert abs(sum(ws.processing.values()) - ws.occupancy) < 1e-8 actual_total_occupancy += ws.occupancy - assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, \ - (actual_total_occupancy, self.total_occupancy) + assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, ( + actual_total_occupancy, + self.total_occupancy, + ) ################### # Manage Messages # @@ -2000,27 +2128,29 @@ def report(self, msg, ts=None, client=None): comm = self.client_comms[client] comm.send(msg) except 
CommClosedError: - if self.status == 'running': + if self.status == "running": logger.critical("Tried writing to closed comm: %s", msg) except KeyError: pass - if ts is None and 'key' in msg: - ts = self.tasks.get(msg['key']) + if ts is None and "key" in msg: + ts = self.tasks.get(msg["key"]) if ts is None: # Notify all clients comms = self.client_comms.values() else: # Notify clients interested in key - comms = [self.client_comms[c.client_key] - for c in ts.who_wants - if c.client_key in self.client_comms] + comms = [ + self.client_comms[c.client_key] + for c in ts.who_wants + if c.client_key in self.client_comms + ] for c in comms: try: c.send(msg) # logger.debug("Scheduler sends message to client %s", msg) except CommClosedError: - if self.status == 'running': + if self.status == "running": logger.critical("Tried writing to closed comm: %s", msg) @gen.coroutine @@ -2031,46 +2161,45 @@ def add_client(self, comm, client=None): """ assert client is not None logger.info("Receive client connection: %s", client) - self.log_event(['all', client], {'action': 'add-client', - 'client': client}) + self.log_event(["all", client], {"action": "add-client", "client": client}) self.clients[client] = ClientState(client) try: - bcomm = BatchedSend(interval='2ms', loop=self.loop) + bcomm = BatchedSend(interval="2ms", loop=self.loop) bcomm.start(comm) self.client_comms[client] = bcomm - bcomm.send({'op': 'stream-start'}) + bcomm.send({"op": "stream-start"}) try: - yield self.handle_stream(comm=comm, extra={'client': client}) + yield self.handle_stream(comm=comm, extra={"client": client}) finally: self.remove_client(client=client) - logger.debug('Finished handling client %s', client) + logger.debug("Finished handling client %s", client) finally: if not comm.closed(): - self.client_comms[client].send({'op': 'stream-closed'}) + self.client_comms[client].send({"op": "stream-closed"}) try: if not shutting_down(): yield self.client_comms[client].close() del self.client_comms[client] - if self.status == 'running': + if self.status == "running": logger.info("Close client connection: %s", client) except TypeError: # comm becomes None during GC pass def remove_client(self, client=None): """ Remove client from network """ - if self.status == 'running': + if self.status == "running": logger.info("Remove client %s", client) - self.log_event(['all', client], {'action': 'remove-client', - 'client': client}) + self.log_event(["all", client], {"action": "remove-client", "client": client}) try: cs = self.clients[client] except KeyError: # XXX is this a legitimate condition? 
pass else: - self.client_releases_keys(keys=[ts.key for ts in cs.wants_what], - client=cs.client_key) + self.client_releases_keys( + keys=[ts.key for ts in cs.wants_what], client=cs.client_key + ) del self.clients[client] @gen.coroutine @@ -2079,46 +2208,49 @@ def remove_client_from_events(): if client not in self.clients and client in self.events: del self.events[client] - cleanup_delay = parse_timedelta(dask.config.get('distributed.scheduler.events-cleanup-delay')) - self.loop.call_later( - cleanup_delay, - remove_client_from_events + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") ) + self.loop.call_later(cleanup_delay, remove_client_from_events) def send_task_to_worker(self, worker, key): """ Send a single computational task to a worker """ try: ts = self.tasks[key] - msg = {'op': 'compute-task', - 'key': key, - 'priority': ts.priority, - 'duration': self.get_task_duration(ts)} + msg = { + "op": "compute-task", + "key": key, + "priority": ts.priority, + "duration": self.get_task_duration(ts), + } if ts.resource_restrictions: - msg['resource_restrictions'] = ts.resource_restrictions + msg["resource_restrictions"] = ts.resource_restrictions if ts.actor: - msg['actor'] = True + msg["actor"] = True deps = ts.dependencies if deps: - msg['who_has'] = {dep.key: [ws.address for ws in dep.who_has] - for dep in deps} - msg['nbytes'] = {dep.key: dep.nbytes for dep in deps} + msg["who_has"] = { + dep.key: [ws.address for ws in dep.who_has] for dep in deps + } + msg["nbytes"] = {dep.key: dep.nbytes for dep in deps} if self.validate and deps: - assert all(msg['who_has'].values()) + assert all(msg["who_has"].values()) task = ts.run_spec if type(task) is dict: msg.update(task) else: - msg['task'] = task + msg["task"] = task self.worker_send(worker, msg) except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -2148,7 +2280,7 @@ def handle_release_data(self, key=None, worker=None, client=None, **msg): def handle_missing_data(self, key=None, errant_worker=None, **kwargs): logger.debug("handle missing data key=%s worker=%s", key, errant_worker) - self.log.append(('missing', key, errant_worker)) + self.log.append(("missing", key, errant_worker)) ts = self.tasks.get(key) if ts is None or not ts.who_has: @@ -2161,9 +2293,9 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): ws.nbytes -= ts.get_nbytes() if not ts.who_has: if ts.run_spec: - self.transitions({key: 'released'}) + self.transitions({key: "released"}) else: - self.transitions({key: 'forgotten'}) + self.transitions({key: "forgotten"}) def release_worker_data(self, stream=None, keys=None, worker=None): ws = self.workers[worker] @@ -2177,7 +2309,7 @@ def release_worker_data(self, stream=None, keys=None, worker=None): wh = ts.who_has wh.remove(ws) if not wh: - recommendations[ts.key] = 'released' + recommendations[ts.key] = "released" if recommendations: self.transitions(recommendations) @@ -2188,13 +2320,14 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): duration accounting as if the task has stopped. """ ts = self.tasks[key] - if 'stealing' in self.extensions: - self.extensions['stealing'].remove_key_from_stealable(ts) + if "stealing" in self.extensions: + self.extensions["stealing"].remove_key_from_stealable(ts) ws = ts.processing_on if ws is None: - logger.debug("Received long-running signal from duplicate task. " - "Ignoring.") + logger.debug( + "Received long-running signal from duplicate task. " "Ignoring." 
+ ) return if compute_duration: @@ -2204,8 +2337,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): if not old_duration: avg_duration = new_duration else: - avg_duration = (0.5 * old_duration - + 0.5 * new_duration) + avg_duration = 0.5 * old_duration + 0.5 * new_duration self.task_duration[prefix] = avg_duration @@ -2229,7 +2361,7 @@ def handle_worker(self, comm=None, worker=None): worker_comm.start(comm) logger.info("Starting worker compute stream, %s", worker) try: - yield self.handle_stream(comm=comm, extra={'worker': worker}) + yield self.handle_stream(comm=comm, extra={"worker": worker}) finally: if worker in self.stream_comms: worker_comm.abort() @@ -2269,8 +2401,15 @@ def worker_send(self, worker, msg): ############################ @gen.coroutine - def scatter(self, comm=None, data=None, workers=None, client=None, - broadcast=False, timeout=2): + def scatter( + self, + comm=None, + data=None, + workers=None, + client=None, + broadcast=False, + timeout=2, + ): """ Send data out to workers See also @@ -2291,9 +2430,9 @@ def scatter(self, comm=None, data=None, workers=None, client=None, assert isinstance(data, dict) - keys, who_has, nbytes = yield scatter_to_workers(ncores, data, - rpc=self.rpc, - report=False) + keys, who_has, nbytes = yield scatter_to_workers( + ncores, data, rpc=self.rpc, report=False + ) self.update_data(who_has=who_has, nbytes=nbytes, client=client) @@ -2304,9 +2443,9 @@ def scatter(self, comm=None, data=None, workers=None, client=None, n = broadcast yield self.replicate(keys=keys, workers=workers, n=n) - self.log_event([client, 'all'], {'action': 'scatter', - 'client': client, - 'count': len(data)}) + self.log_event( + [client, "all"], {"action": "scatter", "client": client, "count": len(data)} + ) raise gen.Return(keys) @gen.coroutine @@ -2322,16 +2461,22 @@ def gather(self, comm=None, keys=None, serializers=None): who_has[key] = [] data, missing_keys, missing_workers = yield gather_from_workers( - who_has, rpc=self.rpc, close=False, serializers=serializers) + who_has, rpc=self.rpc, close=False, serializers=serializers + ) if not missing_keys: - result = {'status': 'OK', 'data': data} + result = {"status": "OK", "data": data} else: - missing_states = [(self.tasks[key].state - if key in self.tasks else None) - for key in missing_keys] - logger.debug("Couldn't gather keys %s state: %s workers: %s", - missing_keys, missing_states, missing_workers) - result = {'status': 'error', 'keys': missing_keys} + missing_states = [ + (self.tasks[key].state if key in self.tasks else None) + for key in missing_keys + ] + logger.debug( + "Couldn't gather keys %s state: %s workers: %s", + missing_keys, + missing_states, + missing_workers, + ) + result = {"status": "error", "keys": missing_keys} with log_errors(): for worker in missing_workers: self.remove_worker(address=worker) # this is extreme @@ -2339,18 +2484,20 @@ def gather(self, comm=None, keys=None, serializers=None): if not workers: continue ts = self.tasks[key] - logger.exception("Workers don't have promised key: %s, %s", - str(workers), str(key)) + logger.exception( + "Workers don't have promised key: %s, %s", + str(workers), + str(key), + ) for worker in workers: ws = self.workers.get(worker) if ws is not None and ts in ws.has_what: ws.has_what.remove(ts) ts.who_has.remove(ws) ws.nbytes -= ts.get_nbytes() - self.transitions({key: 'released'}) + self.transitions({key: "released"}) - self.log_event('all', {'action': 'gather', - 'count': len(keys)}) + self.log_event("all", {"action": 
"gather", "count": len(keys)}) raise gen.Return(result) def clear_task_state(self): @@ -2369,11 +2516,14 @@ def restart(self, client=None, timeout=3): logger.info("Send lost future signal to clients") for cs in self.clients.values(): - self.client_releases_keys(keys=[ts.key for ts in cs.wants_what], - client=cs.client_key) + self.client_releases_keys( + keys=[ts.key for ts in cs.wants_what], client=cs.client_key + ) - nannies = {addr: self.get_worker_service_addr(addr, 'nanny', protocol=True) - for addr in self.workers} + nannies = { + addr: self.get_worker_service_addr(addr, "nanny", protocol=True) + for addr in self.workers + } for addr in list(self.workers): try: @@ -2381,8 +2531,9 @@ def restart(self, client=None, timeout=3): # otherwise the nanny will kill it anyway self.remove_worker(address=addr, close=addr not in nannies) except Exception as e: - logger.info("Exception while restarting. This is normal", - exc_info=True) + logger.info( + "Exception while restarting. This is normal", exc_info=True + ) self.clear_task_state() @@ -2394,38 +2545,54 @@ def restart(self, client=None, timeout=3): logger.debug("Send kill signal to nannies: %s", nannies) - nannies = [rpc(nanny_address, connection_args=self.connection_args) - for nanny_address in nannies.values() - if nanny_address is not None] + nannies = [ + rpc(nanny_address, connection_args=self.connection_args) + for nanny_address in nannies.values() + if nanny_address is not None + ] try: - resps = All([nanny.restart(close=True, timeout=timeout * 0.8, - executor_wait=False) - for nanny in nannies]) + resps = All( + [ + nanny.restart( + close=True, timeout=timeout * 0.8, executor_wait=False + ) + for nanny in nannies + ] + ) resps = yield gen.with_timeout(timedelta(seconds=timeout), resps) - if not all(resp == 'OK' for resp in resps): - logger.error("Not all workers responded positively: %s", - resps, exc_info=True) + if not all(resp == "OK" for resp in resps): + logger.error( + "Not all workers responded positively: %s", resps, exc_info=True + ) except gen.TimeoutError: - logger.error("Nannies didn't report back restarted within " - "timeout. Continuuing with restart process") + logger.error( + "Nannies didn't report back restarted within " + "timeout. 
Continuuing with restart process" + ) finally: for nanny in nannies: nanny.close_rpc() self.start() - self.log_event([client, 'all'], {'action': 'restart', - 'client': client}) + self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() while time() < start + 10 and len(self.workers) < n_workers: yield gen.sleep(0.01) - self.report({'op': 'restart'}) + self.report({"op": "restart"}) @gen.coroutine - def broadcast(self, comm=None, msg=None, workers=None, hosts=None, - nanny=False, serializers=None): + def broadcast( + self, + comm=None, + msg=None, + workers=None, + hosts=None, + nanny=False, + serializers=None, + ): """ Broadcast message to workers, return all results """ if workers is None: if hosts is None: @@ -2435,33 +2602,36 @@ def broadcast(self, comm=None, msg=None, workers=None, hosts=None, if hosts is not None: for host in hosts: if host in self.host_info: - workers.extend(self.host_info[host]['addresses']) + workers.extend(self.host_info[host]["addresses"]) # TODO replace with worker_list if nanny: - addresses = [self.get_worker_service_addr(w, 'nanny', protocol=True) - for w in workers] + addresses = [ + self.get_worker_service_addr(w, "nanny", protocol=True) for w in workers + ] else: addresses = workers @gen.coroutine def send_message(addr): - comm = yield connect(addr, deserialize=self.deserialize, - connection_args=self.connection_args) + comm = yield connect( + addr, deserialize=self.deserialize, connection_args=self.connection_args + ) resp = yield send_recv(comm, close=True, serializers=serializers, **msg) raise gen.Return(resp) - results = yield All([send_message(address) - for address in addresses - if address is not None]) + results = yield All( + [send_message(address) for address in addresses if address is not None] + ) raise Return(dict(zip(workers, results))) @gen.coroutine def proxy(self, comm=None, msg=None, worker=None, serializers=None): """ Proxy a communication through the scheduler to some other worker """ - d = yield self.broadcast(comm=comm, msg=msg, workers=[worker], - serializers=serializers) + d = yield self.broadcast( + comm=comm, msg=msg, workers=[worker], serializers=serializers + ) raise gen.Return(d[worker]) @gen.coroutine @@ -2481,8 +2651,7 @@ def rebalance(self, comm=None, keys=None, workers=None): tasks = {self.tasks[k] for k in keys} missing_data = [ts.key for ts in tasks if not ts.who_has] if missing_data: - raise Return({'status': 'missing-data', - 'keys': missing_data}) + raise Return({"status": "missing-data", "keys": missing_data}) else: tasks = set(self.tasks.values()) @@ -2499,27 +2668,31 @@ def rebalance(self, comm=None, keys=None, workers=None): for vv in v: tasks_by_worker[vv].add(k) - worker_bytes = {ws: sum(ts.get_nbytes() for ts in v) - for ws, v in tasks_by_worker.items()} + worker_bytes = { + ws: sum(ts.get_nbytes() for ts in v) + for ws, v in tasks_by_worker.items() + } avg = sum(worker_bytes.values()) / len(worker_bytes) - sorted_workers = list(map(first, sorted(worker_bytes.items(), - key=second, reverse=True))) + sorted_workers = list( + map(first, sorted(worker_bytes.items(), key=second, reverse=True)) + ) recipients = iter(reversed(sorted_workers)) recipient = next(recipients) msgs = [] # (sender, recipient, key) - for sender in sorted_workers[:len(workers) // 2]: - sender_keys = {ts: ts.get_nbytes() - for ts in tasks_by_worker[sender]} - sender_keys = iter(sorted(sender_keys.items(), - key=second, reverse=True)) + for sender in sorted_workers[: len(workers) // 2]: + sender_keys = {ts: 
ts.get_nbytes() for ts in tasks_by_worker[sender]} + sender_keys = iter( + sorted(sender_keys.items(), key=second, reverse=True) + ) try: while worker_bytes[sender] > avg: - while (worker_bytes[recipient] < avg and - worker_bytes[sender] > avg): + while ( + worker_bytes[recipient] < avg and worker_bytes[sender] > avg + ): ts, nb = next(sender_keys) if ts not in tasks_by_worker[recipient]: tasks_by_worker[recipient].add(ts) @@ -2538,44 +2711,62 @@ def rebalance(self, comm=None, keys=None, workers=None): to_recipients[recipient.address][ts.key].append(sender.address) to_senders[sender.address].append(ts.key) - result = yield {r: self.rpc(addr=r).gather(who_has=v) - for r, v in to_recipients.items()} + result = yield { + r: self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items() + } for r, v in to_recipients.items(): - self.log_event(r, {'action': 'rebalance', - 'who_has': v}) - - self.log_event('all', {'action': 'rebalance', - 'total-keys': len(tasks), - 'senders': valmap(len, to_senders), - 'recipients': valmap(len, to_recipients), - 'moved_keys': len(msgs)}) + self.log_event(r, {"action": "rebalance", "who_has": v}) + + self.log_event( + "all", + { + "action": "rebalance", + "total-keys": len(tasks), + "senders": valmap(len, to_senders), + "recipients": valmap(len, to_recipients), + "moved_keys": len(msgs), + }, + ) - if not all(r['status'] == 'OK' for r in result.values()): - raise Return({'status': 'missing-data', - 'keys': sum([r['keys'] for r in result - if 'keys' in r], [])}) + if not all(r["status"] == "OK" for r in result.values()): + raise Return( + { + "status": "missing-data", + "keys": sum([r["keys"] for r in result if "keys" in r], []), + } + ) for sender, recipient, ts in msgs: - assert ts.state == 'memory' + assert ts.state == "memory" ts.who_has.add(recipient) recipient.has_what.add(ts) recipient.nbytes += ts.get_nbytes() - self.log.append(('rebalance', ts.key, time(), - sender.address, recipient.address)) + self.log.append( + ("rebalance", ts.key, time(), sender.address, recipient.address) + ) - result = yield {r: self.rpc(addr=r).delete_data(keys=v, report=False) - for r, v in to_senders.items()} + result = yield { + r: self.rpc(addr=r).delete_data(keys=v, report=False) + for r, v in to_senders.items() + } for sender, recipient, ts in msgs: ts.who_has.remove(sender) sender.has_what.remove(ts) sender.nbytes -= ts.get_nbytes() - raise Return({'status': 'OK'}) + raise Return({"status": "OK"}) @gen.coroutine - def replicate(self, comm=None, keys=None, n=None, workers=None, - branching_factor=2, delete=True): + def replicate( + self, + comm=None, + keys=None, + n=None, + workers=None, + branching_factor=2, + delete=True, + ): """ Replicate data throughout cluster This performs a tree copy of the data throughout the network @@ -2610,8 +2801,7 @@ def replicate(self, comm=None, keys=None, n=None, workers=None, tasks = {self.tasks[k] for k in keys} missing_data = [ts.key for ts in tasks if not ts.who_has] if missing_data: - raise Return({'status': 'missing-data', - 'keys': missing_data}) + raise Return({"status": "missing-data", "keys": missing_data}) # Delete extraneous data if delete: @@ -2619,22 +2809,25 @@ def replicate(self, comm=None, keys=None, n=None, workers=None, for ts in tasks: del_candidates = ts.who_has & workers if len(del_candidates) > n: - for ws in random.sample(del_candidates, - len(del_candidates) - n): + for ws in random.sample(del_candidates, len(del_candidates) - n): del_worker_tasks[ws].add(ts) - yield [self.rpc(addr=ws.address) - 
.delete_data(keys=[ts.key for ts in tasks], report=False) - for ws, tasks in del_worker_tasks.items()] + yield [ + self.rpc(addr=ws.address).delete_data( + keys=[ts.key for ts in tasks], report=False + ) + for ws, tasks in del_worker_tasks.items() + ] for ws, tasks in del_worker_tasks.items(): ws.has_what -= tasks for ts in tasks: ts.who_has.remove(ws) ws.nbytes -= ts.get_nbytes() - self.log_event(ws.address, - {'action': 'replicate-remove', - 'keys': [ts.key for ts in tasks]}) + self.log_event( + ws.address, + {"action": "replicate-remove", "keys": [ts.key for ts in tasks]}, + ) # Copy not-yet-filled data while tasks: @@ -2646,33 +2839,35 @@ def replicate(self, comm=None, keys=None, n=None, workers=None, tasks.remove(ts) continue - count = min(n_missing, - branching_factor * len(ts.who_has)) + count = min(n_missing, branching_factor * len(ts.who_has)) assert count > 0 for ws in random.sample(workers - ts.who_has, count): - gathers[ws.address][ts.key] = [wws.address - for wws in ts.who_has] + gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has] - results = yield {w: self.rpc(addr=w).gather(who_has=who_has) - for w, who_has in gathers.items()} + results = yield { + w: self.rpc(addr=w).gather(who_has=who_has) + for w, who_has in gathers.items() + } for w, v in results.items(): - if v['status'] == 'OK': + if v["status"] == "OK": self.add_keys(worker=w, keys=list(gathers[w])) else: - logger.warning("Communication failed during replication: %s", - v) - - self.log_event(w, {'action': 'replicate-add', - 'keys': gathers[w]}) - - self.log_event('all', {'action': 'replicate', - 'workers': list(workers), - 'key-count': len(keys), - 'branching-factor': branching_factor}) + logger.warning("Communication failed during replication: %s", v) + + self.log_event(w, {"action": "replicate-add", "keys": gathers[w]}) + + self.log_event( + "all", + { + "action": "replicate", + "workers": list(workers), + "key-count": len(keys), + "branching-factor": branching_factor, + }, + ) - def workers_to_close(self, memory_ratio=None, n=None, key=None, - minimum=None): + def workers_to_close(self, memory_ratio=None, n=None, key=None, minimum=None): """ Find workers that we can close with low cost @@ -2737,10 +2932,10 @@ def workers_to_close(self, memory_ratio=None, n=None, key=None, groups = groupby(key, self.workers.values()) - limit_bytes = {k: sum(ws.memory_limit for ws in v) - for k, v in groups.items()} - group_bytes = {k: sum(ws.nbytes for ws in v) - for k, v in groups.items()} + limit_bytes = { + k: sum(ws.memory_limit for ws in v) for k, v in groups.items() + } + group_bytes = {k: sum(ws.nbytes for ws in v) for k, v in groups.items()} limit = sum(limit_bytes.values()) total = sum(group_bytes.values()) @@ -2765,8 +2960,9 @@ def key(group): limit -= limit_bytes[group] - if ((n is not None and len(to_close) < n) or - (memory_ratio is not None and limit >= memory_ratio * total)): + if (n is not None and len(to_close) < n) or ( + memory_ratio is not None and limit >= memory_ratio * total + ): to_close.append(group) n_remain -= len(groups[group]) @@ -2780,8 +2976,9 @@ def key(group): return result @gen.coroutine - def retire_workers(self, comm=None, workers=None, remove=True, - close_workers=False, **kwargs): + def retire_workers( + self, comm=None, workers=None, remove=True, close_workers=False, **kwargs + ): """ Gracefully retire workers from cluster Parameters @@ -2815,9 +3012,11 @@ def retire_workers(self, comm=None, workers=None, remove=True, try: workers = self.workers_to_close(**kwargs) if workers: - 
workers = yield self.retire_workers(workers=workers, - remove=remove, - close_workers=close_workers) + workers = yield self.retire_workers( + workers=workers, + remove=remove, + close_workers=close_workers, + ) raise gen.Return(workers) except KeyError: # keys left during replicate pass @@ -2833,24 +3032,31 @@ def retire_workers(self, comm=None, workers=None, remove=True, other_workers = set(self.workers.values()) - workers if keys: if other_workers: - yield self.replicate(keys=keys, - workers=[ws.address for ws in other_workers], - n=1, delete=False) + yield self.replicate( + keys=keys, + workers=[ws.address for ws in other_workers], + n=1, + delete=False, + ) else: raise gen.Return([]) worker_keys = {ws.address: ws.identity() for ws in workers} if close_workers and worker_keys: - yield [self.close_worker(worker=w, safe=True) - for w in worker_keys] + yield [self.close_worker(worker=w, safe=True) for w in worker_keys] if remove: for w in worker_keys: self.remove_worker(address=w, safe=True) - self.log_event('all', {'action': 'retire-workers', - 'workers': worker_keys, - 'moved-keys': len(keys)}) - self.log_event(list(worker_keys), {'action': 'retired'}) + self.log_event( + "all", + { + "action": "retire-workers", + "workers": worker_keys, + "moved-keys": len(keys), + }, + ) + self.log_event(list(worker_keys), {"action": "retired"}) raise gen.Return(worker_keys) @@ -2862,24 +3068,25 @@ def add_keys(self, comm=None, worker=None, keys=()): reasons. However, it is sent by workers from time to time. """ if worker not in self.workers: - return 'not found' + return "not found" ws = self.workers[worker] for key in keys: ts = self.tasks.get(key) - if ts is not None and ts.state == 'memory': + if ts is not None and ts.state == "memory": if ts not in ws.has_what: ws.nbytes += ts.get_nbytes() ws.has_what.add(ts) ts.who_has.add(ws) else: - self.worker_send(worker, {'op': 'delete-data', - 'keys': [key], - 'report': False}) + self.worker_send( + worker, {"op": "delete-data", "keys": [key], "report": False} + ) - return 'OK' + return "OK" - def update_data(self, comm=None, who_has=None, nbytes=None, client=None, - serializers=None): + def update_data( + self, comm=None, who_has=None, nbytes=None, client=None, serializers=None + ): """ Learn that new data has entered the network from an external source @@ -2888,15 +3095,16 @@ def update_data(self, comm=None, who_has=None, nbytes=None, client=None, Scheduler.mark_key_in_memory """ with log_errors(): - who_has = {k: [self.coerce_address(vv) for vv in v] - for k, v in who_has.items()} + who_has = { + k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() + } logger.debug("Update data %s", who_has) for key, workers in who_has.items(): ts = self.tasks.get(key) if ts is None: ts = self.tasks[key] = TaskState(key, None) - ts.state = 'memory' + ts.state = "memory" if key in nbytes: ts.set_nbytes(nbytes[key]) for w in workers: @@ -2905,9 +3113,9 @@ def update_data(self, comm=None, who_has=None, nbytes=None, client=None, ws.nbytes += ts.get_nbytes() ws.has_what.add(ts) ts.who_has.add(ws) - self.report({'op': 'key-in-memory', - 'key': key, - 'workers': list(workers)}) + self.report( + {"op": "key-in-memory", "key": key, "workers": list(workers)} + ) if client: self.client_desires_keys(keys=list(who_has), client=client) @@ -2918,29 +3126,31 @@ def report_on_key(self, key=None, ts=None, client=None): try: ts = self.tasks[key] except KeyError: - self.report({'op': 'cancelled-key', - 'key': key}, - client=client) + self.report({"op": "cancelled-key", 
"key": key}, client=client) return else: key = ts.key - if ts.state == 'forgotten': - self.report({'op': 'cancelled-key', - 'key': key}, ts=ts, client=client) - elif ts.state == 'memory': - self.report({'op': 'key-in-memory', - 'key': key}, ts=ts, client=client) - elif ts.state == 'erred': + if ts.state == "forgotten": + self.report({"op": "cancelled-key", "key": key}, ts=ts, client=client) + elif ts.state == "memory": + self.report({"op": "key-in-memory", "key": key}, ts=ts, client=client) + elif ts.state == "erred": failing_ts = ts.exception_blame - self.report({'op': 'task-erred', - 'key': key, - 'exception': failing_ts.exception, - 'traceback': failing_ts.traceback}, - ts=ts, client=client) + self.report( + { + "op": "task-erred", + "key": key, + "exception": failing_ts.exception, + "traceback": failing_ts.traceback, + }, + ts=ts, + client=client, + ) @gen.coroutine - def feed(self, comm, function=None, setup=None, teardown=None, - interval='1s', **kwargs): + def feed( + self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs + ): """ Provides a data Comm to external requester @@ -2948,6 +3158,7 @@ def feed(self, comm, function=None, setup=None, teardown=None, eventually be phased out. It is mostly used by diagnostics. """ import pickle + interval = parse_timedelta(interval) with log_errors(): if function: @@ -2960,7 +3171,7 @@ def feed(self, comm, function=None, setup=None, teardown=None, if isinstance(state, gen.Future): state = yield state try: - while self.status == 'running': + while self.status == "running": if state is None: response = function(self) else: @@ -2976,36 +3187,41 @@ def feed(self, comm, function=None, setup=None, teardown=None, def get_processing(self, comm=None, workers=None): if workers is not None: workers = set(map(self.coerce_address, workers)) - return {w: [ts.key for ts in self.workers[w].processing] - for w in workers} + return {w: [ts.key for ts in self.workers[w].processing] for w in workers} else: - return {w: [ts.key for ts in ws.processing] - for w, ws in self.workers.items()} + return { + w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() + } def get_who_has(self, comm=None, keys=None): if keys is not None: - return {k: [ws.address for ws in self.tasks[k].who_has] - if k in self.tasks else [] - for k in keys} + return { + k: [ws.address for ws in self.tasks[k].who_has] + if k in self.tasks + else [] + for k in keys + } else: - return {key: [ws.address for ws in ts.who_has] - for key, ts in self.tasks.items()} + return { + key: [ws.address for ws in ts.who_has] for key, ts in self.tasks.items() + } def get_has_what(self, comm=None, workers=None): if workers is not None: workers = map(self.coerce_address, workers) - return {w: [ts.key for ts in self.workers[w].has_what] - if w in self.workers else [] - for w in workers} + return { + w: [ts.key for ts in self.workers[w].has_what] + if w in self.workers + else [] + for w in workers + } else: - return {w: [ts.key for ts in ws.has_what] - for w, ws in self.workers.items()} + return {w: [ts.key for ts in ws.has_what] for w, ws in self.workers.items()} def get_ncores(self, comm=None, workers=None): if workers is not None: workers = map(self.coerce_address, workers) - return {w: self.workers[w].ncores - for w in workers if w in self.workers} + return {w: self.workers[w].ncores for w in workers if w in self.workers} else: return {w: ws.ncores for w, ws in self.workers.items()} @@ -3017,9 +3233,9 @@ def get_call_stack(self, comm=None, keys=None): while stack: key = 
stack.pop() ts = self.tasks[key] - if ts.state == 'waiting': + if ts.state == "waiting": stack.extend(dts.key for dts in ts.dependencies) - elif ts.state == 'processing': + elif ts.state == "processing": processing.add(ts) workers = defaultdict(list) @@ -3033,8 +3249,9 @@ def get_call_stack(self, comm=None, keys=None): raise gen.Return({}) else: - response = yield {w: self.rpc(w).call_stack(keys=v) - for w, v in workers.items()} + response = yield { + w: self.rpc(w).call_stack(keys=v) for w, v in workers.items() + } response = {k: v for k, v in response.items() if v} raise gen.Return(response) @@ -3043,8 +3260,11 @@ def get_nbytes(self, comm=None, keys=None, summary=True): if keys is not None: result = {k: self.tasks[k].nbytes for k in keys} else: - result = {k: ts.nbytes for k, ts in self.tasks.items() - if ts.nbytes is not None} + result = { + k: ts.nbytes + for k, ts in self.tasks.items() + if ts.nbytes is not None + } if summary: out = defaultdict(lambda: 0) @@ -3059,9 +3279,7 @@ def get_comm_cost(self, ts, ws): Get the estimated communication cost (in s.) to compute the task on the given worker. """ - return (sum(dts.nbytes - for dts in ts.dependencies - ws.has_what) - / BANDWIDTH) + return sum(dts.nbytes for dts in ts.dependencies - ws.has_what) / BANDWIDTH def get_task_duration(self, ts, default=0.5): """ @@ -3083,7 +3301,8 @@ def run_function(self, stream, function, args=(), kwargs={}, wait=True): Client.run_on_scheduler: """ from .worker import run - self.log_event('all', {'action': 'run-function', 'function': function}) + + self.log_event("all", {"action": "run-function", "function": function}) return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) def set_metadata(self, stream=None, keys=None, value=None): @@ -3095,7 +3314,9 @@ def set_metadata(self, stream=None, keys=None, value=None): metadata = metadata[key] metadata[keys[-1]] = value except Exception as e: - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() def get_metadata(self, stream=None, keys=None, default=no_default): metadata = self.task_metadata @@ -3110,12 +3331,13 @@ def get_metadata(self, stream=None, keys=None, default=no_default): raise def get_task_status(self, stream=None, keys=None): - return {key: (self.tasks[key].state - if key in self.tasks else None) - for key in keys} + return { + key: (self.tasks[key].state if key in self.tasks else None) for key in keys + } def get_task_stream(self, comm=None, start=None, stop=None, count=None): from distributed.diagnostics.task_stream import TaskStreamPlugin + self.add_plugin(TaskStreamPlugin, idempotent=True) ts = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] return ts.collect(start=start, stop=stop, count=count) @@ -3128,7 +3350,7 @@ def register_worker_callbacks(self, comm, setup=None): self.worker_setups.append(setup) - responses = yield self.broadcast(msg=dict(op='run', function=setup)) + responses = yield self.broadcast(msg=dict(op="run", function=setup)) raise gen.Return(responses) ##################### @@ -3168,35 +3390,33 @@ def _add_to_memory(self, ts, ws, recommendations, type=None, **kwargs): deps = ts.dependents if len(deps) > 1: - deps = sorted(deps, key=operator.attrgetter('priority'), - reverse=True) + deps = sorted(deps, key=operator.attrgetter("priority"), reverse=True) for dts in deps: s = dts.waiting_on if ts in s: s.discard(ts) if not s: # new task ready to run - recommendations[dts.key] = 'processing' + recommendations[dts.key] = "processing" for dts in ts.dependencies: s = dts.waiters 
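get_comm_cost in the hunk above prices a candidate worker by the bytes of dependencies it does not yet hold, divided by an assumed network bandwidth. A standalone sketch of the same arithmetic (the BANDWIDTH constant here is an illustrative assumption, not the scheduler's configured value):

BANDWIDTH = 100e6  # assumed bytes/s, for illustration only

def comm_cost(dep_sizes, already_on_worker):
    # dep_sizes: mapping of dependency key -> nbytes
    # already_on_worker: set of dependency keys the candidate worker holds
    missing = sum(nb for key, nb in dep_sizes.items() if key not in already_on_worker)
    return missing / BANDWIDTH

# Two 50 MB dependencies with one already local: ~0.5 s of estimated transfer.
print(comm_cost({"x": 50e6, "y": 50e6}, {"x"}))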
s.discard(ts) if not s and not dts.who_wants: - recommendations[dts.key] = 'released' + recommendations[dts.key] = "released" if not ts.waiters and not ts.who_wants: - recommendations[ts.key] = 'released' + recommendations[ts.key] = "released" else: - msg = {'op': 'key-in-memory', - 'key': ts.key} + msg = {"op": "key-in-memory", "key": ts.key} if type is not None: - msg['type'] = type + msg["type"] = type self.report(msg) - ts.state = 'memory' + ts.state = "memory" - cs = self.clients['fire-and-forget'] + cs = self.clients["fire-and-forget"] if ts in cs.wants_what: - self.client_releases_keys(client='fire-and-forget', keys=[ts.key]) + self.client_releases_keys(client="fire-and-forget", keys=[ts.key]) def transition_released_waiting(self, key): try: @@ -3207,45 +3427,45 @@ def transition_released_waiting(self, key): assert not ts.waiting_on assert not ts.who_has assert not ts.processing_on - assert not any(dts.state == 'forgotten' for dts in ts.dependencies) + assert not any(dts.state == "forgotten" for dts in ts.dependencies) if ts.has_lost_dependencies: - return {key: 'forgotten'} + return {key: "forgotten"} - ts.state = 'waiting' + ts.state = "waiting" recommendations = OrderedDict() for dts in ts.dependencies: if dts.exception_blame: ts.exception_blame = dts.exception_blame - recommendations[key] = 'erred' + recommendations[key] = "erred" return recommendations for dts in ts.dependencies: dep = dts.key if not dts.who_has: ts.waiting_on.add(dts) - if dts.state == 'released': - recommendations[dep] = 'waiting' + if dts.state == "released": + recommendations[dep] = "waiting" else: dts.waiters.add(ts) - ts.waiters = {dts for dts in ts.dependents - if dts.state == 'waiting'} + ts.waiters = {dts for dts in ts.dependents if dts.state == "waiting"} if not ts.waiting_on: if self.workers: - recommendations[key] = 'processing' + recommendations[key] = "processing" else: self.unrunnable.add(ts) - ts.state = 'no-worker' + ts.state = "no-worker" return recommendations except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3262,7 +3482,7 @@ def transition_no_worker_waiting(self, key): self.unrunnable.remove(ts) if ts.has_lost_dependencies: - return {key: 'forgotten'} + return {key: "forgotten"} recommendations = OrderedDict() @@ -3270,25 +3490,26 @@ def transition_no_worker_waiting(self, key): dep = dts.key if not dts.who_has: ts.waiting_on.add(dep) - if dts.state == 'released': - recommendations[dep] = 'waiting' + if dts.state == "released": + recommendations[dep] = "waiting" else: dts.waiters.add(ts) - ts.state = 'waiting' + ts.state = "waiting" if not ts.waiting_on: if self.workers: - recommendations[key] = 'processing' + recommendations[key] = "processing" else: self.unrunnable.add(ts) - ts.state = 'no-worker' + ts.state = "no-worker" return recommendations except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3300,27 +3521,34 @@ def decide_worker(self, ts): if not valid_workers and not ts.loose_restrictions and self.workers: self.unrunnable.add(ts) - ts.state = 'no-worker' + ts.state = "no-worker" return None if ts.dependencies or valid_workers is not True: - worker = decide_worker(ts, self.workers.values(), valid_workers, - partial(self.worker_objective, ts)) + worker = decide_worker( + ts, + self.workers.values(), + valid_workers, + partial(self.worker_objective, ts), + ) elif self.idle: if len(self.idle) < 20: # smart but linear in small case - worker = min(self.idle, - key=operator.attrgetter('occupancy')) + worker 
= min(self.idle, key=operator.attrgetter("occupancy")) else: # dumb but fast in large case worker = self.idle[self.n_tasks % len(self.idle)] else: if len(self.workers) < 20: # smart but linear in small case - worker = min(self.workers.values(), - key=operator.attrgetter('occupancy')) + worker = min( + self.workers.values(), key=operator.attrgetter("occupancy") + ) else: # dumb but fast in large case worker = self.workers.values()[self.n_tasks % len(self.workers)] if self.validate: - assert worker is None or isinstance(worker, WorkerState), (type(worker), worker) + assert worker is None or isinstance(worker, WorkerState), ( + type(worker), + worker, + ) assert worker.address in self.workers return worker @@ -3336,8 +3564,7 @@ def transition_waiting_processing(self, key): assert not ts.processing_on assert not ts.has_lost_dependencies assert ts not in self.unrunnable - assert all(dts.who_has - for dts in ts.dependencies) + assert all(dts.who_has for dts in ts.dependencies) ws = self.decide_worker(ts) if ws is None: @@ -3351,7 +3578,7 @@ def transition_waiting_processing(self, key): ts.processing_on = ws ws.occupancy += duration + comm self.total_occupancy += duration + comm - ts.state = 'processing' + ts.state = "processing" self.consume_resources(ts, ws) self.check_idle_saturated(ws) self.n_tasks += 1 @@ -3368,6 +3595,7 @@ def transition_waiting_processing(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3379,7 +3607,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): if self.validate: assert not ts.processing_on assert ts.waiting_on - assert ts.state == 'waiting' + assert ts.state == "waiting" ts.waiting_on.clear() @@ -3402,11 +3630,13 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise - def transition_processing_memory(self, key, nbytes=None, type=None, - worker=None, startstops=None, **kwargs): + def transition_processing_memory( + self, key, nbytes=None, type=None, worker=None, startstops=None, **kwargs + ): try: ts = self.tasks[key] assert worker @@ -3419,20 +3649,24 @@ def transition_processing_memory(self, key, nbytes=None, type=None, assert not ts.waiting_on assert not ts.who_has, (ts, ts.who_has) assert not ts.exception_blame - assert ts.state == 'processing' + assert ts.state == "processing" ws = self.workers.get(worker) if ws is None: - return {key: 'released'} + return {key: "released"} if ws is not ts.processing_on: # someone else has this task - logger.info("Unexpected worker completed task, likely due to" - " work stealing. Expected: %s, Got: %s, Key: %s", - ts.processing_on, ws, key) + logger.info( + "Unexpected worker completed task, likely due to" + " work stealing. 
Expected: %s, Got: %s, Key: %s", + ts.processing_on, + ws, + key, + ) return {} if startstops: - L = [(b, c) for a, b, c in startstops if a == 'compute'] + L = [(b, c) for a, b, c in startstops if a == "compute"] if L: compute_start, compute_stop = L[0] else: # This is very rare @@ -3451,8 +3685,7 @@ def transition_processing_memory(self, key, nbytes=None, type=None, if not old_duration: avg_duration = new_duration else: - avg_duration = (0.5 * old_duration - + 0.5 * new_duration) + avg_duration = 0.5 * old_duration + 0.5 * new_duration self.task_duration[prefix] = avg_duration @@ -3486,6 +3719,7 @@ def transition_processing_memory(self, key, nbytes=None, type=None, logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3505,35 +3739,35 @@ def transition_memory_released(self, key, safe=False): if ts.who_wants: ts.exception_blame = ts ts.exception = "Worker holding Actor was lost" - return {ts.key: 'erred'} # don't try to recreate + return {ts.key: "erred"} # don't try to recreate recommendations = OrderedDict() for dts in ts.waiters: - if dts.state in ('no-worker', 'processing'): - recommendations[dts.key] = 'waiting' - elif dts.state == 'waiting': + if dts.state in ("no-worker", "processing"): + recommendations[dts.key] = "waiting" + elif dts.state == "waiting": dts.waiting_on.add(ts) # XXX factor this out? for ws in ts.who_has: ws.has_what.remove(ts) ws.nbytes -= ts.get_nbytes() - self.worker_send(ws.address, {'op': 'delete-data', - 'keys': [key], - 'report': False}) + self.worker_send( + ws.address, {"op": "delete-data", "keys": [key], "report": False} + ) ts.who_has.clear() - ts.state = 'released' + ts.state = "released" - self.report({'op': 'lost-data', 'key': key}) + self.report({"op": "lost-data", "key": key}) if not ts.run_spec: # pure data - recommendations[key] = 'forgotten' + recommendations[key] = "forgotten" elif ts.has_lost_dependencies: - recommendations[key] = 'forgotten' + recommendations[key] = "forgotten" elif ts.who_wants or ts.waiters: - recommendations[key] = 'waiting' + recommendations[key] = "waiting" if self.validate: assert not ts.waiting_on @@ -3543,6 +3777,7 @@ def transition_memory_released(self, key, safe=False): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3564,14 +3799,18 @@ def transition_released_erred(self, key): for dts in ts.dependents: dts.exception_blame = failing_ts if not dts.who_has: - recommendations[dts.key] = 'erred' - - self.report({'op': 'task-erred', - 'key': key, - 'exception': failing_ts.exception, - 'traceback': failing_ts.traceback}) + recommendations[dts.key] = "erred" + + self.report( + { + "op": "task-erred", + "key": key, + "exception": failing_ts.exception, + "traceback": failing_ts.traceback, + } + ) - ts.state = 'erred' + ts.state = "erred" # TODO: waiting data? 
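Both handle_long_running and transition_processing_memory above fold each newly measured compute time into a per-task-prefix estimate with an equal-weight moving average (avg = 0.5 * old + 0.5 * new). A small self-contained sketch of that update rule:

def update_duration(old_duration, new_duration):
    # Equal-weight running average; the first measurement seeds the estimate.
    if not old_duration:
        return new_duration
    return 0.5 * old_duration + 0.5 * new_duration

# The estimate tracks recent behaviour quickly:
est = 0
for measured in (1.0, 3.0, 3.0):
    est = update_duration(est, measured)
print(est)  # 2.5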
return recommendations @@ -3579,6 +3818,7 @@ def transition_released_erred(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3588,7 +3828,7 @@ def transition_erred_released(self, key): if self.validate: with log_errors(pdb=LOG_PDB): - assert all(dts.state != 'erred' for dts in ts.dependencies) + assert all(dts.state != "erred" for dts in ts.dependencies) assert ts.exception_blame assert not ts.who_has assert not ts.waiting_on @@ -3601,17 +3841,18 @@ def transition_erred_released(self, key): ts.traceback = None for dep in ts.dependents: - if dep.state == 'erred': - recommendations[dep.key] = 'waiting' + if dep.state == "erred": + recommendations[dep.key] = "waiting" - self.report({'op': 'task-retried', 'key': key}) - ts.state = 'released' + self.report({"op": "task-retried", "key": key}) + ts.state = "released" return recommendations except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3630,15 +3871,15 @@ def transition_waiting_released(self, key): if ts in s: s.discard(ts) if not s and not dts.who_wants: - recommendations[dts.key] = 'released' + recommendations[dts.key] = "released" ts.waiting_on.clear() - ts.state = 'released' + ts.state = "released" if ts.has_lost_dependencies: - recommendations[key] = 'forgotten' + recommendations[key] = "forgotten" elif not ts.exception_blame and (ts.who_wants or ts.waiters): - recommendations[key] = 'waiting' + recommendations[key] = "waiting" else: ts.waiters.clear() @@ -3647,6 +3888,7 @@ def transition_waiting_released(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3658,27 +3900,28 @@ def transition_processing_released(self, key): assert ts.processing_on assert not ts.who_has assert not ts.waiting_on - assert self.tasks[key].state == 'processing' + assert self.tasks[key].state == "processing" - self._remove_from_processing(ts, send_worker_msg={'op': 'release-task', - 'key': key}) + self._remove_from_processing( + ts, send_worker_msg={"op": "release-task", "key": key} + ) - ts.state = 'released' + ts.state = "released" recommendations = OrderedDict() if ts.has_lost_dependencies: - recommendations[key] = 'forgotten' + recommendations[key] = "forgotten" elif ts.waiters or ts.who_wants: - recommendations[key] = 'waiting' + recommendations[key] = "waiting" - if recommendations.get(key) != 'waiting': + if recommendations.get(key) != "waiting": for dts in ts.dependencies: - if dts.state != 'released': + if dts.state != "released": s = dts.waiters s.discard(ts) if not s and not dts.who_wants: - recommendations[dts.key] = 'released' + recommendations[dts.key] = "released" ts.waiters.clear() if self.validate: @@ -3689,11 +3932,13 @@ def transition_processing_released(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise - def transition_processing_erred(self, key, cause=None, exception=None, - traceback=None, **kwargs): + def transition_processing_erred( + self, key, cause=None, exception=None, traceback=None, **kwargs + ): try: ts = self.tasks[key] @@ -3723,26 +3968,30 @@ def transition_processing_erred(self, key, cause=None, exception=None, for dts in ts.dependents: dts.exception_blame = failing_ts - recommendations[dts.key] = 'erred' + recommendations[dts.key] = "erred" for dts in ts.dependencies: s = dts.waiters s.discard(ts) if not s and not dts.who_wants: - recommendations[dts.key] = 'released' + recommendations[dts.key] = "released" ts.waiters.clear() # do anything with this? 
- ts.state = 'erred' + ts.state = "erred" - self.report({'op': 'task-erred', - 'key': key, - 'exception': failing_ts.exception, - 'traceback': failing_ts.traceback}) + self.report( + { + "op": "task-erred", + "key": key, + "exception": failing_ts.exception, + "traceback": failing_ts.traceback, + } + ) - cs = self.clients['fire-and-forget'] + cs = self.clients["fire-and-forget"] if ts in cs.wants_what: - self.client_releases_keys(client='fire-and-forget', keys=[key]) + self.client_releases_keys(client="fire-and-forget", keys=[key]) if self.validate: assert not ts.processing_on @@ -3752,6 +4001,7 @@ def transition_processing_erred(self, key, cause=None, exception=None, logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3760,12 +4010,12 @@ def transition_no_worker_released(self, key): ts = self.tasks[key] if self.validate: - assert self.tasks[key].state == 'no-worker' + assert self.tasks[key].state == "no-worker" assert not ts.who_has assert not ts.waiting_on self.unrunnable.remove(ts) - ts.state = 'released' + ts.state = "released" for dts in ts.dependencies: dts.waiters.discard(ts) @@ -3777,12 +4027,13 @@ def transition_no_worker_released(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise def remove_key(self, key): ts = self.tasks.pop(key) - assert ts.state == 'forgotten' + assert ts.state == "forgotten" self.unrunnable.discard(ts) for cs in ts.who_wants: cs.wants_what.remove(ts) @@ -3794,15 +4045,15 @@ def remove_key(self, key): del self.task_metadata[key] def _propagate_forgotten(self, ts, recommendations): - ts.state = 'forgotten' + ts.state = "forgotten" key = ts.key for dts in ts.dependents: dts.has_lost_dependencies = True dts.dependencies.remove(ts) dts.waiting_on.discard(ts) - if dts.state not in ('memory', 'erred'): + if dts.state not in ("memory", "erred"): # Cannot compute task anymore - recommendations[dts.key] = 'forgotten' + recommendations[dts.key] = "forgotten" ts.dependents.clear() ts.waiters.clear() @@ -3813,7 +4064,7 @@ def _propagate_forgotten(self, ts, recommendations): if not dts.dependents and not dts.who_wants: # Task not needed anymore assert dts is not ts - recommendations[dts.key] = 'forgotten' + recommendations[dts.key] = "forgotten" ts.dependencies.clear() ts.waiting_on.clear() @@ -3822,9 +4073,9 @@ def _propagate_forgotten(self, ts, recommendations): ws.nbytes -= ts.get_nbytes() w = ws.address if w in self.workers: # in case worker has died - self.worker_send(w, {'op': 'delete-data', - 'keys': [key], - 'report': False}) + self.worker_send( + w, {"op": "delete-data", "keys": [key], "report": False} + ) ts.who_has.clear() def transition_memory_forgotten(self, key): @@ -3832,7 +4083,7 @@ def transition_memory_forgotten(self, key): ts = self.tasks[key] if self.validate: - assert ts.state == 'memory' + assert ts.state == "memory" assert not ts.processing_on assert not ts.waiting_on if not ts.run_spec: @@ -3863,6 +4114,7 @@ def transition_memory_forgotten(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3871,7 +4123,7 @@ def transition_released_forgotten(self, key): ts = self.tasks[key] if self.validate: - assert ts.state in ('released', 'erred') + assert ts.state in ("released", "erred") assert not ts.who_has assert not ts.processing_on assert not ts.waiting_on, (ts, ts.waiting_on) @@ -3898,6 +4150,7 @@ def transition_released_forgotten(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -3933,30 +4186,36 @@ def transition(self, key, finish, *args, 
**kwargs): if (start, finish) in self._transitions: func = self._transitions[start, finish] recommendations = func(key, *args, **kwargs) - elif 'released' not in (start, finish): - func = self._transitions['released', finish] + elif "released" not in (start, finish): + func = self._transitions["released", finish] assert not args and not kwargs - a = self.transition(key, 'released') + a = self.transition(key, "released") if key in a: - func = self._transitions['released', a[key]] + func = self._transitions["released", a[key]] b = func(key) a = a.copy() a.update(b) recommendations = a - start = 'released' + start = "released" else: - raise RuntimeError("Impossible transition from %r to %r" - % (start, finish)) + raise RuntimeError( + "Impossible transition from %r to %r" % (start, finish) + ) finish2 = ts.state - self.transition_log.append((key, start, finish2, recommendations, - time())) + self.transition_log.append((key, start, finish2, recommendations, time())) if self.validate: - logger.debug("Transitioned %r %s->%s (actual: %s). Consequence: %s", - key, start, finish2, ts.state, dict(recommendations)) + logger.debug( + "Transitioned %r %s->%s (actual: %s). Consequence: %s", + key, + start, + finish2, + ts.state, + dict(recommendations), + ) if self.plugins: # Temporarily put back forgotten key for plugin to retrieve it - if ts.state == 'forgotten': + if ts.state == "forgotten": try: ts.dependents = dependents ts.dependencies = dependencies @@ -3968,15 +4227,15 @@ def transition(self, key, finish, *args, **kwargs): plugin.transition(key, start, finish2, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) - if ts.state == 'forgotten': + if ts.state == "forgotten": del self.tasks[ts.key] return recommendations except Exception as e: - logger.exception("Error transitioning %r from %r to %r", - key, start, finish) + logger.exception("Error transitioning %r from %r to %r", key, start, finish) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -4001,8 +4260,9 @@ def transitions(self, recommendations): def story(self, *keys): """ Get all transitions that touch one of the input keys """ keys = set(keys) - return [t for t in self.transition_log - if t[0] in keys or keys.intersection(t[3])] + return [ + t for t in self.transition_log if t[0] in keys or keys.intersection(t[3]) + ] transition_story = story @@ -4013,18 +4273,18 @@ def reschedule(self, key=None, worker=None): elsewhere """ ts = self.tasks[key] - if ts.state != 'processing': + if ts.state != "processing": return if worker and ts.processing_on.address != worker: return - self.transitions({key: 'released'}) + self.transitions({key: "released"}) ############################## # Assigning Tasks to Workers # ############################## def check_idle_saturated(self, ws, occ=None): - if self.total_ncores == 0 or ws.status == 'closed': + if self.total_ncores == 0 or ws.status == "closed": return if occ is None: occ = ws.occupancy @@ -4065,8 +4325,7 @@ def valid_workers(self, ts): # may not be connected when host_restrictions is populated hr = [self.coerce_hostname(h) for h in ts.host_restrictions] # XXX need HostState? 
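# Hostname restrictions are resolved to concrete worker addresses via
# self.host_info and unioned into ``s``; resource restrictions are handled
# separately below by intersecting the workers able to supply each resource.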
- ss = [self.host_info[h]['addresses'] - for h in hr if h in self.host_info] + ss = [self.host_info[h]["addresses"] for h in hr if h in self.host_info] ss = set.union(*ss) if ss else set() if s is True: s = ss @@ -4074,9 +4333,14 @@ def valid_workers(self, ts): s |= ss if ts.resource_restrictions: - w = {resource: {w for w, supplied in self.resources[resource].items() - if supplied >= required} - for resource, required in ts.resource_restrictions.items()} + w = { + resource: { + w + for w, supplied in self.resources[resource].items() + if supplied >= required + } + for resource, required in ts.resource_restrictions.items() + } ww = set.intersection(*w.values()) @@ -4112,7 +4376,7 @@ def add_resources(self, stream=None, worker=None, resources=None): for resource, quantity in ws.resources.items(): ws.used_resources[resource] = 0 self.resources[resource][worker] = quantity - return 'OK' + return "OK" def remove_resources(self, worker): ws = self.workers[worker] @@ -4132,8 +4396,7 @@ def coerce_address(self, addr, resolve=True): if isinstance(addr, tuple): addr = unparse_host_port(*addr) if not isinstance(addr, six.string_types): - raise TypeError("addresses should be strings or tuples, got %r" - % (addr,)) + raise TypeError("addresses should be strings or tuples, got %r" % (addr,)) if resolve: addr = resolve_address(addr) @@ -4163,7 +4426,7 @@ def workers_list(self, workers): out = set() for w in workers: - if ':' in w: + if ":" in w: out.add(w) else: out.update({ww for ww in self.workers if w in ww}) # TODO: quadratic @@ -4175,11 +4438,10 @@ def start_ipython(self, comm=None): Returns Jupyter connection info dictionary. """ from ._ipython_utils import start_ipython + if self._ipython_kernel is None: self._ipython_kernel = start_ipython( - ip=self.ip, - ns={'scheduler': self}, - log=logger, + ip=self.ip, ns={"scheduler": self}, log=logger ) return self._ipython_kernel.get_connection_info() @@ -4189,9 +4451,9 @@ def worker_objective(self, ts, ws): Minimize expected start time. If a tie then break with data storage. 
""" - comm_bytes = sum([dts.get_nbytes() - for dts in ts.dependencies - if ws not in dts.who_has]) + comm_bytes = sum( + [dts.get_nbytes() for dts in ts.dependencies if ws not in dts.who_has] + ) stack_time = ws.occupancy / ws.ncores start_time = comm_bytes / BANDWIDTH + stack_time @@ -4201,43 +4463,61 @@ def worker_objective(self, ts, ws): return (start_time, ws.nbytes) @gen.coroutine - def get_profile(self, comm=None, workers=None, merge_workers=True, - start=None, stop=None, key=None): + def get_profile( + self, + comm=None, + workers=None, + merge_workers=True, + start=None, + stop=None, + key=None, + ): if workers is None: workers = self.workers else: workers = set(self.workers) & set(workers) - result = yield {w: self.rpc(w).profile(start=start, stop=stop, key=key) - for w in workers} + result = yield { + w: self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers + } if merge_workers: result = profile.merge(*result.values()) raise gen.Return(result) @gen.coroutine - def get_profile_metadata(self, comm=None, workers=None, merge_workers=True, - start=None, stop=None, profile_cycle_interval=None): - dt = profile_cycle_interval or dask.config.get('distributed.worker.profile.cycle') - dt = parse_timedelta(dt, default='ms') + def get_profile_metadata( + self, + comm=None, + workers=None, + merge_workers=True, + start=None, + stop=None, + profile_cycle_interval=None, + ): + dt = profile_cycle_interval or dask.config.get( + "distributed.worker.profile.cycle" + ) + dt = parse_timedelta(dt, default="ms") if workers is None: workers = self.workers else: workers = set(self.workers) & set(workers) - result = yield {w: self.rpc(w).profile_metadata(start=start, stop=stop) - for w in workers} + result = yield { + w: self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers + } - counts = [v['counts'] for v in result.values()] + counts = [v["counts"] for v in result.values()] counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt) counts = [(time, sum(pluck(1, group))) for time, group in counts] keys = set() for v in result.values(): - for t, d in v['keys']: + for t, d in v["keys"]: for k in d: keys.add(k) keys = {k: [] for k in keys} - groups1 = [v['keys'] for v in result.values()] + groups1 = [v["keys"] for v in result.values()] groups2 = list(merge_sorted(*groups1, key=first)) last = 0 @@ -4250,7 +4530,7 @@ def get_profile_metadata(self, comm=None, workers=None, merge_workers=True, for k, v in d.items(): keys[k][-1][1] += v - raise gen.Return({'counts': counts, 'keys': keys}) + raise gen.Return({"counts": counts, "keys": keys}) def get_logs(self, comm=None, n=None): deque_handler = self._deque_handler @@ -4263,8 +4543,7 @@ def get_logs(self, comm=None, n=None): @gen.coroutine def get_worker_logs(self, comm=None, n=None, workers=None): - results = yield self.broadcast(msg={'op': 'get_logs', 'n': n}, - workers=workers) + results = yield self.broadcast(msg={"op": "get_logs", "n": n}, workers=workers) raise gen.Return(results) ########### @@ -4288,7 +4567,7 @@ def reevaluate_occupancy(self, worker_index=0): """ DELAY = 0.1 try: - if self.status == 'closed': + if self.status == "closed": return last = time() @@ -4311,8 +4590,9 @@ def reevaluate_occupancy(self, worker_index=0): next_time = timedelta(seconds=duration * 5) # 25ms gap break - self.loop.add_timeout(next_time, self.reevaluate_occupancy, - worker_index=worker_index) + self.loop.add_timeout( + next_time, self.reevaluate_occupancy, worker_index=worker_index + ) except Exception: logger.error("Error 
in reevaluate occupancy", exc_info=True) @@ -4335,8 +4615,8 @@ def _reevaluate_occupancy_worker(self, ws): self.check_idle_saturated(ws) # significant increase in duration - if (new > old * 1.3) and ('stealing' in self.extensions): - steal = self.extensions['stealing'] + if (new > old * 1.3) and ("stealing" in self.extensions): + steal = self.extensions["stealing"] for ts in ws.processing: steal.remove_key_from_stealable(ts) steal.put_key_in_stealable(ts) @@ -4345,8 +4625,11 @@ def check_worker_ttl(self): now = time() for ws in self.workers.values(): if ws.last_seen < now - self.worker_ttl: - logger.warning("Worker failed to heartbeat within %s seconds. " - "Closing: %s", self.worker_ttl, ws) + logger.warning( + "Worker failed to heartbeat within %s seconds. " "Closing: %s", + self.worker_ttl, + ws, + ) self.remove_worker(address=ws.address) @@ -4371,8 +4654,7 @@ def decide_worker(ts, all_workers, valid_workers, objective): if ts.actor: candidates = all_workers else: - candidates = frequencies([ws for dts in deps - for ws in dts.who_has]) + candidates = frequencies([ws for dts in deps for ws in dts.who_has]) if valid_workers is True: if not candidates: candidates = all_workers @@ -4398,71 +4680,107 @@ def validate_task_state(ts): """ Validate the given TaskState. """ - assert ts.state in ALL_TASK_STATES or ts.state == 'forgotten', ts + assert ts.state in ALL_TASK_STATES or ts.state == "forgotten", ts if ts.waiting_on: - assert ts.waiting_on.issubset(ts.dependencies), \ - ("waiting not subset of dependencies", str(ts.waiting_on), str(ts.dependencies)) + assert ts.waiting_on.issubset(ts.dependencies), ( + "waiting not subset of dependencies", + str(ts.waiting_on), + str(ts.dependencies), + ) if ts.waiters: - assert ts.waiters.issubset(ts.dependents), \ - ("waiters not subset of dependents", str(ts.waiters), str(ts.dependents)) + assert ts.waiters.issubset(ts.dependents), ( + "waiters not subset of dependents", + str(ts.waiters), + str(ts.dependents), + ) for dts in ts.waiting_on: - assert not dts.who_has, \ - ("waiting on in-memory dep", str(ts), str(dts)) - assert dts.state != 'released', \ - ("waiting on released dep", str(ts), str(dts)) + assert not dts.who_has, ("waiting on in-memory dep", str(ts), str(dts)) + assert dts.state != "released", ("waiting on released dep", str(ts), str(dts)) for dts in ts.dependencies: - assert ts in dts.dependents, \ - ("not in dependency's dependents", str(ts), str(dts), str(dts.dependents)) - if ts.state in ('waiting', 'processing'): - assert dts in ts.waiting_on or dts.who_has, \ - ("dep missing", str(ts), str(dts)) - assert dts.state != 'forgotten' + assert ts in dts.dependents, ( + "not in dependency's dependents", + str(ts), + str(dts), + str(dts.dependents), + ) + if ts.state in ("waiting", "processing"): + assert dts in ts.waiting_on or dts.who_has, ( + "dep missing", + str(ts), + str(dts), + ) + assert dts.state != "forgotten" for dts in ts.waiters: - assert dts.state in ('waiting', 'processing'), \ - ("waiter not in play", str(ts), str(dts)) + assert dts.state in ("waiting", "processing"), ( + "waiter not in play", + str(ts), + str(dts), + ) for dts in ts.dependents: - assert ts in dts.dependencies, \ - ("not in dependent's dependencies", str(ts), str(dts), str(dts.dependencies)) - assert dts.state != 'forgotten' + assert ts in dts.dependencies, ( + "not in dependent's dependencies", + str(ts), + str(dts), + str(dts.dependencies), + ) + assert dts.state != "forgotten" - assert (ts.processing_on is not None) == (ts.state == 'processing') - assert 
bool(ts.who_has) == (ts.state == 'memory'), (ts, ts.who_has) + assert (ts.processing_on is not None) == (ts.state == "processing") + assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has) - if ts.state == 'processing': - assert all(dts.who_has for dts in ts.dependencies), \ - ("task processing without all deps", str(ts), str(ts.dependencies)) + if ts.state == "processing": + assert all(dts.who_has for dts in ts.dependencies), ( + "task processing without all deps", + str(ts), + str(ts.dependencies), + ) assert not ts.waiting_on if ts.who_has: - assert ts.waiters or ts.who_wants, \ - ("unneeded task in memory", str(ts), str(ts.who_has)) + assert ts.waiters or ts.who_wants, ( + "unneeded task in memory", + str(ts), + str(ts.who_has), + ) assert not any(ts in dts.waiting_on for dts in ts.dependents) for ws in ts.who_has: - assert ts in ws.has_what, \ - ("not in who_has' has_what", str(ts), str(ws), str(ws.has_what)) + assert ts in ws.has_what, ( + "not in who_has' has_what", + str(ts), + str(ws), + str(ws.has_what), + ) if ts.who_wants: for cs in ts.who_wants: - assert ts in cs.wants_what, \ - ("not in who_wants' wants_what", str(ts), str(cs), str(cs.wants_what)) + assert ts in cs.wants_what, ( + "not in who_wants' wants_what", + str(ts), + str(cs), + str(cs.wants_what), + ) if ts.actor: - if ts.state == 'memory': + if ts.state == "memory": assert sum([ts in ws.actors for ws in ts.who_has]) == 1 - if ts.state == 'processing': + if ts.state == "processing": assert ts in ts.processing_on.actors def validate_worker_state(ws): for ts in ws.has_what: - assert ws in ts.who_has, \ - ("not in has_what' who_has", str(ws), str(ts), str(ts.who_has)) + assert ws in ts.who_has, ( + "not in has_what' who_has", + str(ws), + str(ts), + str(ts.who_has), + ) for ts in ws.actors: - assert ts.state in ('memory', 'processing') + assert ts.state in ("memory", "processing") def validate_state(tasks, workers, clients): @@ -4480,14 +4798,18 @@ def validate_state(tasks, workers, clients): for cs in clients.values(): for ts in cs.wants_what: - assert cs in ts.who_wants, \ - ("not in wants_what' who_wants", str(cs), str(ts), str(ts.who_wants)) + assert cs in ts.who_wants, ( + "not in wants_what' who_wants", + str(cs), + str(ts), + str(ts.who_wants), + ) _round_robin = [0] -fast_tasks = {'rechunk-split', 'shuffle-split'} +fast_tasks = {"rechunk-split", "shuffle-split"} def heartbeat_interval(n): diff --git a/distributed/security.py b/distributed/security.py index 0a40396a54a..e86c0602860 100644 --- a/distributed/security.py +++ b/distributed/security.py @@ -8,24 +8,23 @@ import dask -_roles = ['client', 'scheduler', 'worker'] +_roles = ["client", "scheduler", "worker"] -_tls_per_role_fields = ['key', 'cert'] +_tls_per_role_fields = ["key", "cert"] -_tls_fields = ['ca_file', 'ciphers'] +_tls_fields = ["ca_file", "ciphers"] -_misc_fields = ['require_encryption'] +_misc_fields = ["require_encryption"] -_fields = set(_misc_fields + - ['tls_%s' % field for field in _tls_fields] + - ['tls_%s_%s' % (role, field) - for role in _roles - for field in _tls_per_role_fields] - ) +_fields = set( + _misc_fields + + ["tls_%s" % field for field in _tls_fields] + + ["tls_%s_%s" % (role, field) for role in _roles for field in _tls_per_role_fields] +) def _field_to_config_key(field): - return field.replace('_', '-') + return field.replace("_", "-") class Security(object): @@ -61,65 +60,65 @@ def _init_from_dict(self, d): """ Initialize Security from nested dict. 
""" - self._init_fields_from_dict(d, '', _misc_fields, {}) - self._init_fields_from_dict(d, 'tls', _tls_fields, _tls_per_role_fields) + self._init_fields_from_dict(d, "", _misc_fields, {}) + self._init_fields_from_dict(d, "tls", _tls_fields, _tls_per_role_fields) - def _init_fields_from_dict(self, d, category, - fields, per_role_fields): + def _init_fields_from_dict(self, d, category, fields, per_role_fields): if category: d = d.get(category, {}) - category_prefix = category + '_' + category_prefix = category + "_" else: - category_prefix = '' + category_prefix = "" for field in fields: k = _field_to_config_key(field) if k in d: - setattr(self, '%s%s' % (category_prefix, field), d[k]) + setattr(self, "%s%s" % (category_prefix, field), d[k]) for role in _roles: dd = d.get(role, {}) for field in per_role_fields: k = _field_to_config_key(field) if k in dd: - setattr(self, '%s%s_%s' % (category_prefix, role, field), dd[k]) + setattr(self, "%s%s_%s" % (category_prefix, role, field), dd[k]) def __repr__(self): items = sorted((k, getattr(self, k)) for k in _fields) - return ("Security(" + - ", ".join("%s=%r" % (k, v) for k, v in items if v is not None) + - ")") + return ( + "Security(" + + ", ".join("%s=%r" % (k, v) for k, v in items if v is not None) + + ")" + ) def get_tls_config_for_role(self, role): """ Return the TLS configuration for the given role, as a flat dict. """ - return self._get_config_for_role('tls', role, _tls_fields, _tls_per_role_fields) + return self._get_config_for_role("tls", role, _tls_fields, _tls_per_role_fields) def _get_config_for_role(self, category, role, fields, per_role_fields): if role not in _roles: raise ValueError("unknown role %r" % (role,)) d = {} for field in fields: - k = '%s_%s' % (category, field) + k = "%s_%s" % (category, field) d[field] = getattr(self, k) for field in per_role_fields: - k = '%s_%s_%s' % (category, role, field) + k = "%s_%s_%s" % (category, role, field) d[field] = getattr(self, k) return d def _get_tls_context(self, tls, purpose): - if tls.get('ca_file') and tls.get('cert'): + if tls.get("ca_file") and tls.get("cert"): try: - ctx = ssl.create_default_context(purpose=purpose, - cafile=tls['ca_file']) + ctx = ssl.create_default_context(purpose=purpose, cafile=tls["ca_file"]) except AttributeError: raise RuntimeError("TLS functionality requires Python 2.7.9+") ctx.verify_mode = ssl.CERT_REQUIRED # We expect a dedicated CA for the cluster and people using # IP addresses rather than hostnames ctx.check_hostname = False - ctx.load_cert_chain(tls['cert'], tls.get('key')) - if tls.get('ciphers'): - ctx.set_ciphers(tls.get('ciphers')) + ctx.load_cert_chain(tls["cert"], tls.get("key")) + if tls.get("ciphers"): + ctx.set_ciphers(tls.get("ciphers")) return ctx def get_connection_args(self, role): @@ -131,8 +130,8 @@ def get_connection_args(self, role): tls = self.get_tls_config_for_role(role) # Ensure backwards compatibility (ssl.Purpose is Python 2.7.9+ only) purpose = ssl.Purpose.SERVER_AUTH if hasattr(ssl, "Purpose") else None - d['ssl_context'] = self._get_tls_context(tls, purpose) - d['require_encryption'] = self.require_encryption + d["ssl_context"] = self._get_tls_context(tls, purpose) + d["require_encryption"] = self.require_encryption return d def get_listen_args(self, role): @@ -144,6 +143,6 @@ def get_listen_args(self, role): tls = self.get_tls_config_for_role(role) # Ensure backwards compatibility (ssl.Purpose is Python 2.7.9+ only) purpose = ssl.Purpose.CLIENT_AUTH if hasattr(ssl, "Purpose") else None - d['ssl_context'] = 
self._get_tls_context(tls, purpose) - d['require_encryption'] = self.require_encryption + d["ssl_context"] = self._get_tls_context(tls, purpose) + d["require_encryption"] = self.require_encryption return d diff --git a/distributed/sizeof.py b/distributed/sizeof.py index cbf65638fa7..0bc094e35a7 100644 --- a/distributed/sizeof.py +++ b/distributed/sizeof.py @@ -15,6 +15,5 @@ def safe_sizeof(obj, default_size=1e6): try: return sizeof(obj) except Exception: - logger.warning('Sizeof calculation failed. Defaulting to 1MB', - exc_info=True) + logger.warning("Sizeof calculation failed. Defaulting to 1MB", exc_info=True) return int(default_size) diff --git a/distributed/stealing.py b/distributed/stealing.py index b7a73613a56..d361305b105 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -LOG_PDB = dask.config.get('distributed.admin.pdb-on-err') +LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") class WorkStealing(SchedulerPlugin): @@ -43,25 +43,25 @@ def __init__(self, scheduler): for worker in scheduler.workers: self.add_worker(worker=worker) - pc = PeriodicCallback(callback=self.balance, - callback_time=100, - io_loop=self.scheduler.loop) + pc = PeriodicCallback( + callback=self.balance, callback_time=100, io_loop=self.scheduler.loop + ) self._pc = pc - self.scheduler.periodic_callbacks['stealing'] = pc + self.scheduler.periodic_callbacks["stealing"] = pc self.scheduler.plugins.append(self) - self.scheduler.extensions['stealing'] = self - self.scheduler.events['stealing'] = deque(maxlen=100000) + self.scheduler.extensions["stealing"] = self + self.scheduler.events["stealing"] = deque(maxlen=100000) self.count = 0 # { task state: } self.in_flight = dict() # { worker state: occupancy } self.in_flight_occupancy = defaultdict(lambda: 0) - self.scheduler.stream_handlers['steal-response'] = self.move_task_confirm + self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm @property def log(self): - return self.scheduler.events['stealing'] + return self.scheduler.events["stealing"] def add_worker(self, scheduler=None, worker=None): self.stealable[worker] = [set() for i in range(15)] @@ -72,17 +72,18 @@ def remove_worker(self, scheduler=None, worker=None): def teardown(self): self._pc.stop() - def transition(self, key, start, finish, compute_start=None, - compute_stop=None, *args, **kwargs): + def transition( + self, key, start, finish, compute_start=None, compute_stop=None, *args, **kwargs + ): ts = self.scheduler.tasks[key] - if finish == 'processing': + if finish == "processing": self.put_key_in_stealable(ts) - if start == 'processing': + if start == "processing": self.remove_key_from_stealable(ts) - if finish == 'memory': + if finish == "memory": for tts in self.stealable_unknown_durations.pop(ts.prefix, ()): - if tts not in self.in_flight and tts.state == 'processing': + if tts not in self.in_flight and tts.state == "processing": self.put_key_in_stealable(tts) else: self.in_flight.pop(ts, None) @@ -91,7 +92,7 @@ def put_key_in_stealable(self, ts): ws = ts.processing_on worker = ws.address cost_multiplier, level = self.steal_time_ratio(ts) - self.log.append(('add-stealable', ts.key, worker, level)) + self.log.append(("add-stealable", ts.key, worker, level)) if cost_multiplier is not None: self.stealable_all[level].add(ts) self.stealable[worker][level].add(ts) @@ -103,7 +104,7 @@ def remove_key_from_stealable(self, ts): return worker, level = result - self.log.append(('remove-stealable', ts.key, worker, 
level)) + self.log.append(("remove-stealable", ts.key, worker, level)) try: self.stealable[worker][level].remove(ts) except KeyError: @@ -123,9 +124,9 @@ def steal_time_ratio(self, ts): For example a result of zero implies a task without dependencies. level: The location within a stealable list to place this value """ - if (not ts.loose_restrictions - and (ts.host_restrictions or ts.worker_restrictions - or ts.resource_restrictions)): + if not ts.loose_restrictions and ( + ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions + ): return None, None # don't steal if not ts.dependencies: # no dependencies fast path @@ -158,28 +159,36 @@ def move_task_request(self, ts, victim, thief): if self.scheduler.validate: if victim is not ts.processing_on: import pdb + pdb.set_trace() key = ts.key self.remove_key_from_stealable(ts) - logger.debug("Request move %s, %s: %2f -> %s: %2f", key, - victim, victim.occupancy, - thief, thief.occupancy) + logger.debug( + "Request move %s, %s: %2f -> %s: %2f", + key, + victim, + victim.occupancy, + thief, + thief.occupancy, + ) victim_duration = victim.processing[ts] - thief_duration = ( - self.scheduler.get_task_duration(ts) + - self.scheduler.get_comm_cost(ts, thief) - ) + thief_duration = self.scheduler.get_task_duration( + ts + ) + self.scheduler.get_comm_cost(ts, thief) self.scheduler.stream_comms[victim.address].send( - {'op': 'steal-request', 'key': key}) + {"op": "steal-request", "key": key} + ) - self.in_flight[ts] = {'victim': victim, - 'thief': thief, - 'victim_duration': victim_duration, - 'thief_duration': thief_duration} + self.in_flight[ts] = { + "victim": victim, + "thief": thief, + "victim_duration": victim_duration, + "thief_duration": thief_duration, + } self.in_flight_occupancy[victim] -= victim_duration self.in_flight_occupancy[thief] += thief_duration @@ -189,6 +198,7 @@ def move_task_request(self, ts, victim, thief): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -203,42 +213,48 @@ def move_task_confirm(self, key=None, worker=None, state=None): d = self.in_flight.pop(ts) except KeyError: return - thief = d['thief'] - victim = d['victim'] - logger.debug("Confirm move %s, %s -> %s. State: %s", - key, victim, thief, state) + thief = d["thief"] + victim = d["victim"] + logger.debug( + "Confirm move %s, %s -> %s. 
State: %s", key, victim, thief, state + ) - self.in_flight_occupancy[thief] -= d['thief_duration'] - self.in_flight_occupancy[victim] += d['victim_duration'] + self.in_flight_occupancy[thief] -= d["thief_duration"] + self.in_flight_occupancy[victim] += d["victim_duration"] if not self.in_flight: self.in_flight_occupancy = defaultdict(lambda: 0) - if ts.state != 'processing' or ts.processing_on is not victim: + if ts.state != "processing" or ts.processing_on is not victim: old_thief = thief.occupancy new_thief = sum(thief.processing.values()) old_victim = victim.occupancy new_victim = sum(victim.processing.values()) thief.occupancy = new_thief victim.occupancy = new_victim - self.scheduler.total_occupancy += new_thief - old_thief + new_victim - old_victim + self.scheduler.total_occupancy += ( + new_thief - old_thief + new_victim - old_victim + ) return # One of the pair has left, punt and reschedule - if (thief.address not in self.scheduler.workers or - victim.address not in self.scheduler.workers): + if ( + thief.address not in self.scheduler.workers + or victim.address not in self.scheduler.workers + ): self.scheduler.reschedule(key) return # Victim had already started execution, reverse stealing - if state in ('memory', 'executing', 'long-running', None): - self.log.append(('already-computing', - key, victim.address, thief.address)) + if state in ("memory", "executing", "long-running", None): + self.log.append( + ("already-computing", key, victim.address, thief.address) + ) self.scheduler.check_idle_saturated(thief) self.scheduler.check_idle_saturated(victim) # Victim was waiting, has given up task, enact steal - elif state in ('waiting', 'ready'): + elif state in ("waiting", "ready"): self.remove_key_from_stealable(ts) ts.processing_on = thief duration = victim.processing.pop(ts) @@ -247,23 +263,23 @@ def move_task_confirm(self, key=None, worker=None, state=None): if not victim.processing: self.scheduler.total_occupancy -= victim.occupancy victim.occupancy = 0 - thief.processing[ts] = d['thief_duration'] - thief.occupancy += d['thief_duration'] - self.scheduler.total_occupancy += d['thief_duration'] + thief.processing[ts] = d["thief_duration"] + thief.occupancy += d["thief_duration"] + self.scheduler.total_occupancy += d["thief_duration"] self.put_key_in_stealable(ts) try: self.scheduler.send_task_to_worker(thief.address, key) except CommClosedError: self.scheduler.remove_worker(thief.address) - self.log.append(('confirm', - key, victim.address, thief.address)) + self.log.append(("confirm", key, victim.address, thief.address)) else: raise ValueError("Unexpected task state: %s" % state) except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise finally: @@ -286,11 +302,20 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): occ_idl = combined_occupancy(idl) occ_sat = combined_occupancy(sat) - if (occ_idl + cost_multiplier * duration <= occ_sat - duration / 2): + if occ_idl + cost_multiplier * duration <= occ_sat - duration / 2: self.move_task_request(ts, sat, idl) - log.append((start, level, ts.key, duration, - sat.address, occ_sat, - idl.address, occ_idl)) + log.append( + ( + start, + level, + ts.key, + duration, + sat.address, + occ_sat, + idl.address, + occ_idl, + ) + ) s.check_idle_saturated(sat, occ=occ_sat) s.check_idle_saturated(idl, occ=occ_idl) @@ -306,9 +331,11 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): if not s.saturated: saturated = topk(10, s.workers.values(), key=combined_occupancy) - saturated = 
[ws for ws in saturated - if combined_occupancy(ws) > 0.2 - and len(ws.processing) > ws.ncores] + saturated = [ + ws + for ws in saturated + if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.ncores + ] elif len(s.saturated) < 20: saturated = sorted(saturated, key=combined_occupancy, reverse=True) if len(idle) < 20: @@ -323,8 +350,7 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): continue for ts in list(stealable): - if (ts not in self.key_stealable or - ts.processing_on is not sat): + if ts not in self.key_stealable or ts.processing_on is not sat: stealable.discard(ts) continue i += 1 @@ -337,8 +363,7 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): stealable.discard(ts) continue - maybe_move_task(level, ts, sat, idl, - duration, cost_multiplier) + maybe_move_task(level, ts, sat, idl, duration, cost_multiplier) if self.cost_multipliers[level] < 20: # don't steal from public at cost stealable = self.stealable_all[level] @@ -362,15 +387,14 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): idl = idle[i % len(idle)] duration = sat.processing[ts] - maybe_move_task(level, ts, sat, idl, - duration, cost_multiplier) + maybe_move_task(level, ts, sat, idl, duration, cost_multiplier) if log: self.log.append(log) self.count += 1 stop = time() if s.digests: - s.digests['steal-duration'].add(stop - start) + s.digests["steal-duration"].add(stop - start) def restart(self, scheduler): for stealable in self.stealable.values(): @@ -394,4 +418,4 @@ def story(self, *keys): return out -fast_tasks = {'shuffle-split'} +fast_tasks = {"shuffle-split"} diff --git a/distributed/submit.py b/distributed/submit.py index 2d7d62ac1f9..bdbe3251a9d 100644 --- a/distributed/submit.py +++ b/distributed/submit.py @@ -18,20 +18,26 @@ from distributed.utils import get_ip -logger = logging.getLogger('distributed.remote') +logger = logging.getLogger("distributed.remote") class RemoteClient(Server): - def __init__(self, ip=None, local_dir=tempfile.mkdtemp(prefix='client-'), - loop=None, security=None, **kwargs): + def __init__( + self, + ip=None, + local_dir=tempfile.mkdtemp(prefix="client-"), + loop=None, + security=None, + **kwargs + ): self.ip = ip or get_ip() self.loop = loop or IOLoop.current() self.local_dir = local_dir - handlers = {'upload_file': self.upload_file, 'execute': self.execute} + handlers = {"upload_file": self.upload_file, "execute": self.execute} self.security = security or Security() assert isinstance(self.security, Security) - self.listen_args = self.security.get_listen_args('scheduler') + self.listen_args = self.security.get_listen_args("scheduler") super(RemoteClient, self).__init__(handlers, io_loop=self.loop, **kwargs) @@ -46,22 +52,21 @@ def start(self, port=0): @gen.coroutine def execute(self, stream=None, filename=None): script_path = os.path.join(self.local_dir, filename) - cmd = '{0} {1}'.format(sys.executable, script_path) - process = subprocess.Popen(cmd, shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + cmd = "{0} {1}".format(sys.executable, script_path) + process = subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) out, err = process.communicate() return_code = process.returncode - raise gen.Return({'stdout': out, 'stderr': err, - 'returncode': return_code}) + raise gen.Return({"stdout": out, "stderr": err, "returncode": return_code}) def upload_file(self, stream, filename=None, file_payload=None): out_filename = os.path.join(self.local_dir, filename) if 
isinstance(file_payload, unicode): file_payload = file_payload.encode() - with open(out_filename, 'wb') as f: + with open(out_filename, "wb") as f: f.write(file_payload) - return {'status': 'OK', 'nbytes': len(file_payload)} + return {"status": "OK", "nbytes": len(file_payload)} @gen.coroutine def _close(self): @@ -70,8 +75,8 @@ def _close(self): def _remote(host, port, loop=IOLoop.current(), client=RemoteClient): host = host or get_ip() - if ':' in host and port == 8788: - host, port = host.rsplit(':', 1) + if ":" in host and port == 8788: + host, port = host.rsplit(":", 1) port = int(port) ip = socket.gethostbyname(host) remote_client = client(ip=ip, loop=loop) @@ -86,8 +91,8 @@ def _remote(host, port, loop=IOLoop.current(), client=RemoteClient): def _submit(remote_client_address, filepath, connection_args=None): rc = rpc(remote_client_address, connection_args=connection_args) remote_file = os.path.basename(filepath) - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: bytes_read = f.read() yield rc.upload_file(filename=remote_file, file_payload=bytes_read) result = yield rc.execute(filename=remote_file) - raise gen.Return((result['stdout'], result['stderr'])) + raise gen.Return((result["stdout"], result["stderr"])) diff --git a/distributed/system_monitor.py b/distributed/system_monitor.py index 3c68e114b7f..30efc3ceb87 100644 --- a/distributed/system_monitor.py +++ b/distributed/system_monitor.py @@ -16,9 +16,7 @@ def __init__(self, n=10000): self.memory = deque(maxlen=n) self.count = 0 - self.quantities = {'cpu': self.cpu, - 'memory': self.memory, - 'time': self.time} + self.quantities = {"cpu": self.cpu, "memory": self.memory, "time": self.time} try: ioc = psutil.net_io_counters() @@ -28,14 +26,14 @@ def __init__(self, n=10000): self.last_time = time() self.read_bytes = deque(maxlen=n) self.write_bytes = deque(maxlen=n) - self.quantities['read_bytes'] = self.read_bytes - self.quantities['write_bytes'] = self.write_bytes + self.quantities["read_bytes"] = self.read_bytes + self.quantities["write_bytes"] = self.write_bytes self._last_io_counters = ioc self._collect_net_io_counters = True if not WINDOWS: self.num_fds = deque(maxlen=n) - self.quantities['num_fds'] = self.num_fds + self.quantities["num_fds"] = self.num_fds self.update() @@ -56,10 +54,7 @@ def update(self): self.time.append(now) self.count += 1 - result = {'cpu': cpu, - 'memory': memory, - 'time': now, - 'count': self.count} + result = {"cpu": cpu, "memory": memory, "time": now, "count": self.count} if self._collect_net_io_counters: try: @@ -75,20 +70,22 @@ def update(self): self._last_io_counters = ioc self.read_bytes.append(read_bytes) self.write_bytes.append(write_bytes) - result['read_bytes'] = read_bytes - result['write_bytes'] = write_bytes + result["read_bytes"] = read_bytes + result["write_bytes"] = write_bytes if not WINDOWS: num_fds = self.proc.num_fds() self.num_fds.append(num_fds) - result['num_fds'] = num_fds + result["num_fds"] = num_fds return result def __repr__(self): - return '' % ( - self.cpu[-1], self.memory[-1] / 1e6, - -1 if WINDOWS else self.num_fds[-1]) + return "" % ( + self.cpu[-1], + self.memory[-1] / 1e6, + -1 if WINDOWS else self.num_fds[-1], + ) def range_query(self, start): if start == self.count: diff --git a/distributed/tests/make_tls_certs.py b/distributed/tests/make_tls_certs.py index 8ffd62e876d..0c1c5876134 100644 --- a/distributed/tests/make_tls_certs.py +++ b/distributed/tests/make_tls_certs.py @@ -78,29 +78,51 @@ def make_cert_key(hostname, sign=False): req_file, 
cert_file, key_file = tempnames try: req = req_template.format(hostname=hostname) - with open(req_file, 'w') as f: + with open(req_file, "w") as f: f.write(req) - args = ['req', '-new', '-days', '3650', '-nodes', - '-newkey', 'rsa:2048', '-keyout', key_file, - '-config', req_file] + args = [ + "req", + "-new", + "-days", + "3650", + "-nodes", + "-newkey", + "rsa:2048", + "-keyout", + key_file, + "-config", + req_file, + ] if sign: with tempfile.NamedTemporaryFile(delete=False) as f: tempnames.append(f.name) reqfile = f.name - args += ['-out', reqfile] + args += ["-out", reqfile] else: - args += ['-x509', '-out', cert_file] - subprocess.check_call(['openssl'] + args) + args += ["-x509", "-out", cert_file] + subprocess.check_call(["openssl"] + args) if sign: - args = ['ca', '-config', req_file, '-out', cert_file, '-outdir', 'cadir', - '-policy', 'policy_anything', '-batch', '-infiles', reqfile] - subprocess.check_call(['openssl'] + args) - - with open(cert_file, 'r') as f: + args = [ + "ca", + "-config", + req_file, + "-out", + cert_file, + "-outdir", + "cadir", + "-policy", + "policy_anything", + "-batch", + "-infiles", + reqfile, + ] + subprocess.check_call(["openssl"] + args) + + with open(cert_file, "r") as f: cert = f.read() - with open(key_file, 'r') as f: + with open(key_file, "r") as f: key = f.read() return cert, key finally: @@ -108,7 +130,7 @@ def make_cert_key(hostname, sign=False): os.remove(name) -TMP_CADIR = 'cadir' +TMP_CADIR = "cadir" def unmake_ca(): @@ -117,53 +139,82 @@ def unmake_ca(): def make_ca(): os.mkdir(TMP_CADIR) - with open(os.path.join('cadir', 'index.txt'), 'a+') as f: + with open(os.path.join("cadir", "index.txt"), "a+") as f: pass # empty file # with open(os.path.join('cadir','crl.txt'),'a+') as f: - # f.write("00") - with open(os.path.join('cadir', 'index.txt.attr'), 'w+') as f: - f.write('unique_subject = no') + # f.write("00") + with open(os.path.join("cadir", "index.txt.attr"), "w+") as f: + f.write("unique_subject = no") with tempfile.NamedTemporaryFile("w") as t: - t.write(req_template.format(hostname='our-ca-server')) + t.write(req_template.format(hostname="our-ca-server")) t.flush() with tempfile.NamedTemporaryFile() as f: - args = ['req', '-new', '-days', '3650', '-extensions', 'v3_ca', '-nodes', - '-newkey', 'rsa:2048', '-keyout', 'tls-ca-key.pem', - '-out', f.name, - '-subj', '/C=XY/L=Dask-distributed/O=Dask CA/CN=our-ca-server'] - subprocess.check_call(['openssl'] + args) - args = ['ca', '-config', t.name, '-create_serial', - '-out', 'tls-ca-cert.pem', '-batch', '-outdir', TMP_CADIR, - '-keyfile', 'tls-ca-key.pem', '-days', '3650', - '-selfsign', '-extensions', 'v3_ca', '-infiles', f.name] - subprocess.check_call(['openssl'] + args) - #args = ['ca', '-config', t.name, '-gencrl', '-out', 'revocation.crl'] - #subprocess.check_call(['openssl'] + args) - - -if __name__ == '__main__': + args = [ + "req", + "-new", + "-days", + "3650", + "-extensions", + "v3_ca", + "-nodes", + "-newkey", + "rsa:2048", + "-keyout", + "tls-ca-key.pem", + "-out", + f.name, + "-subj", + "/C=XY/L=Dask-distributed/O=Dask CA/CN=our-ca-server", + ] + subprocess.check_call(["openssl"] + args) + args = [ + "ca", + "-config", + t.name, + "-create_serial", + "-out", + "tls-ca-cert.pem", + "-batch", + "-outdir", + TMP_CADIR, + "-keyfile", + "tls-ca-key.pem", + "-days", + "3650", + "-selfsign", + "-extensions", + "v3_ca", + "-infiles", + f.name, + ] + subprocess.check_call(["openssl"] + args) + # args = ['ca', '-config', t.name, '-gencrl', '-out', 'revocation.crl'] + # 
subprocess.check_call(['openssl'] + args) + + +if __name__ == "__main__": os.chdir(here) - cert, key = make_cert_key('localhost') - with open('tls-self-signed-cert.pem', 'w') as f: + cert, key = make_cert_key("localhost") + with open("tls-self-signed-cert.pem", "w") as f: f.write(cert) - with open('tls-self-signed-key.pem', 'w') as f: + with open("tls-self-signed-key.pem", "w") as f: f.write(key) # For certificate matching tests make_ca() - with open('tls-ca-cert.pem', 'r') as f: + with open("tls-ca-cert.pem", "r") as f: ca_cert = f.read() - cert, key = make_cert_key('localhost', sign=True) - with open('tls-cert.pem', 'w') as f: + cert, key = make_cert_key("localhost", sign=True) + with open("tls-cert.pem", "w") as f: f.write(cert) - with open('tls-cert-chain.pem', 'w') as f: + with open("tls-cert-chain.pem", "w") as f: f.write(cert) f.write(ca_cert) - with open('tls-key.pem', 'w') as f: + with open("tls-key.pem", "w") as f: f.write(key) - with open('tls-key-cert.pem', 'w') as f: + with open("tls-key-cert.pem", "w") as f: f.write(key) f.write(cert) diff --git a/distributed/tests/py3_test_asyncio.py b/distributed/tests/py3_test_asyncio.py index 3c8629c2ba3..90e20268617 100644 --- a/distributed/tests/py3_test_asyncio.py +++ b/distributed/tests/py3_test_asyncio.py @@ -1,7 +1,7 @@ # flake8: noqa import pytest -asyncio = pytest.importorskip('asyncio') +asyncio = pytest.importorskip("asyncio") import functools from time import time @@ -50,7 +50,7 @@ async def test_coro_test(): @coro_test async def test_asyncio_start_close(): async with AioClient(processes=False, dashboard_address=False) as c: - assert c.status == 'running' + assert c.status == "running" # AioClient has installed its AioLoop shim. assert isinstance(IOLoop.current(instance=False), BaseAsyncIOLoop) @@ -58,7 +58,7 @@ async def test_asyncio_start_close(): assert result == 11 await c.close() - assert c.status == 'closed' + assert c.status == "closed" # assert IOLoop.current(instance=False) is None @@ -153,25 +153,23 @@ async def test_asyncio_gather(): assert result == 11 result = await c.gather([x]) assert result == [11] - result = await c.gather({'x': x, 'y': [y]}) - assert result == {'x': 11, 'y': [12]} + result = await c.gather({"x": x, "y": [y]}) + assert result == {"x": 11, "y": [12]} @coro_test async def test_asyncio_get(): async with AioClient(processes=False) as c: - result = await c.get({'x': (inc, 1)}, 'x') + result = await c.get({"x": (inc, 1)}, "x") assert result == 2 - result = await c.get({'x': (inc, 1)}, ['x']) + result = await c.get({"x": (inc, 1)}, ["x"]) assert result == [2] result = await c.get({}, []) assert result == [] - result = await c.get({('x', 1): (inc, 1), - ('x', 2): (inc, ('x', 1))}, - ('x', 2)) + result = await c.get({("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}, ("x", 2)) assert result == 3 @@ -228,7 +226,7 @@ async def test_asyncio_cancel(): await c.cancel([x]) assert x.cancelled() - assert 'cancel' in str(x) + assert "cancel" in str(x) s.validate_state() start = time() @@ -244,7 +242,7 @@ async def test_asyncio_cancel(): @coro_test async def test_asyncio_cancel_tuple_key(): async with AioClient(processes=False) as c: - x = c.submit(inc, 1, key=('x', 0, 1)) + x = c.submit(inc, 1, key=("x", 0, 1)) await x await c.cancel(x) with pytest.raises(CancelledError): @@ -298,7 +296,7 @@ async def aioinc(x, delay=0.02): async def aiothrows(x, delay=0.02): await asyncio.sleep(delay) - raise RuntimeError('hello') + raise RuntimeError("hello") async with AioClient(processes=False) as c: results = await 
c.run_coroutine(aioinc, 1, delay=0.05) @@ -317,7 +315,7 @@ async def aiothrows(x, delay=0.02): @coro_test async def test_asyncio_restart(): async with AioClient(processes=False) as c: - assert c.status == 'running' + assert c.status == "running" x = c.submit(inc, 1) assert x.key in c.refcount @@ -327,6 +325,7 @@ async def test_asyncio_restart(): key = x.key del x import gc + gc.collect() assert key not in c.refcount @@ -343,8 +342,8 @@ async def test_asyncio_variable(): async with AioClient(processes=False) as c: s = c.cluster.scheduler - x = Variable('x') - xx = Variable('x') + x = Variable("x") + xx = Variable("x") assert x.client is c future = c.submit(inc, 1) diff --git a/distributed/tests/py3_test_client.py b/distributed/tests/py3_test_client.py index d75b2bd0801..b5d10f8d553 100644 --- a/distributed/tests/py3_test_client.py +++ b/distributed/tests/py3_test_client.py @@ -60,7 +60,7 @@ async def f(): yield f() assert set(results) == set(range(1, 11)) - assert not s.counters['op'].components[0]['gather'] + assert not s.counters["op"].components[0]["gather"] @gen_cluster(client=True) @@ -102,16 +102,16 @@ async def f(): loop.run_sync(f) assert result == 11 - assert client.status == 'closed' - assert cluster.status == 'closed' + assert client.status == "closed" + assert cluster.status == "closed" def test_locks(loop): async def f(): async with Client(processes=False, asynchronous=True) as c: assert c.asynchronous - async with Lock('x'): - lock2 = Lock('x') + async with Lock("x"): + lock2 = Lock("x") result = await lock2.acquire(timeout=0.1) assert result is False @@ -124,7 +124,7 @@ async def ff(): return 1 with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: assert sync(loop, ff) == 1 assert c.sync(ff) == 1 @@ -132,8 +132,8 @@ async def ff(): @pytest.mark.xfail(reason="known intermittent failure") @gen_cluster(client=True) async def test_dont_hold_on_to_large_messages(c, s, a, b): - np = pytest.importorskip('numpy') - da = pytest.importorskip('dask.array') + np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") x = np.random.random(1000000) xr = weakref.ref(x) @@ -146,6 +146,7 @@ async def test_dont_hold_on_to_large_messages(c, s, a, b): if time() > start + 5: # Help diagnosing from types import FrameType + x = xr() if x is not None: del x @@ -154,8 +155,12 @@ async def test_dont_hold_on_to_large_messages(c, s, a, b): print("refs to x:", rc, refs, gc.isenabled()) frames = [r for r in refs if isinstance(r, FrameType)] for i, f in enumerate(frames): - print("frames #%d:" % i, - f.f_code.co_name, f.f_code.co_filename, sorted(f.f_locals)) + print( + "frames #%d:" % i, + f.f_code.co_name, + f.f_code.co_filename, + sorted(f.f_locals), + ) pytest.fail("array should have been destroyed") await gen.sleep(0.200) @@ -165,41 +170,41 @@ async def test_dont_hold_on_to_large_messages(c, s, a, b): async def test_run_scheduler_async_def(c, s, a, b): async def f(dask_scheduler): await gen.sleep(0.01) - dask_scheduler.foo = 'bar' + dask_scheduler.foo = "bar" await c.run_on_scheduler(f) - assert s.foo == 'bar' + assert s.foo == "bar" async def f(dask_worker): await gen.sleep(0.01) - dask_worker.foo = 'bar' + dask_worker.foo = "bar" await c.run(f) - assert a.foo == 'bar' - assert b.foo == 'bar' + assert a.foo == "bar" + assert b.foo == "bar" @gen_cluster(client=True) async def test_run_scheduler_async_def_wait(c, s, a, b): async def f(dask_scheduler): await gen.sleep(0.01) - dask_scheduler.foo = 'bar' + 
dask_scheduler.foo = "bar" await c.run_on_scheduler(f, wait=False) - while not hasattr(s, 'foo'): + while not hasattr(s, "foo"): await gen.sleep(0.01) - assert s.foo == 'bar' + assert s.foo == "bar" async def f(dask_worker): await gen.sleep(0.01) - dask_worker.foo = 'bar' + dask_worker.foo = "bar" await c.run(f, wait=False) - while not hasattr(a, 'foo') or not hasattr(b, 'foo'): + while not hasattr(a, "foo") or not hasattr(b, "foo"): await gen.sleep(0.01) - assert a.foo == 'bar' - assert b.foo == 'bar' + assert a.foo == "bar" + assert b.foo == "bar" diff --git a/distributed/tests/py3_test_pubsub.py b/distributed/tests/py3_test_pubsub.py index b7cde193d37..172c8734819 100644 --- a/distributed/tests/py3_test_pubsub.py +++ b/distributed/tests/py3_test_pubsub.py @@ -10,7 +10,7 @@ @gen_cluster(client=True) def test_basic(c, s, a, b): async def publish(): - pub = Pub('a') + pub = Pub("a") i = 0 while True: @@ -19,7 +19,7 @@ async def publish(): i += 1 def f(_): - sub = Sub('a') + sub = Sub("a") return list(toolz.take(5, sub)) c.run_coroutine(publish, workers=[a.address]) diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index 942d3f1e761..fba0f50cbfe 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -47,21 +47,21 @@ def get(self, key): return self.data[key] -@pytest.mark.parametrize('direct_to_workers', [True, False]) +@pytest.mark.parametrize("direct_to_workers", [True, False]) def test_client_actions(direct_to_workers): - @gen_cluster(client=True) def test(c, s, a, b): - c = yield Client(s.address, asynchronous=True, - direct_to_workers=direct_to_workers) + c = yield Client( + s.address, asynchronous=True, direct_to_workers=direct_to_workers + ) counter = c.submit(Counter, workers=[a.address], actor=True) assert isinstance(counter, Future) counter = yield counter assert counter._address - assert hasattr(counter, 'increment') - assert hasattr(counter, 'add') - assert hasattr(counter, 'n') + assert hasattr(counter, "increment") + assert hasattr(counter, "add") + assert hasattr(counter, "n") n = yield counter.n assert n == 0 @@ -86,9 +86,8 @@ def test(c, s, a, b): test() -@pytest.mark.parametrize('separate_thread', [False, True]) +@pytest.mark.parametrize("separate_thread", [False, True]) def test_worker_actions(separate_thread): - @gen_cluster(client=True) def test(c, s, a, b): counter = c.submit(Counter, workers=[a.address], actor=True) @@ -121,15 +120,17 @@ def test_Actor(c, s, a, b): assert counter._cls == Counter - assert hasattr(counter, 'n') - assert hasattr(counter, 'increment') - assert hasattr(counter, 'add') + assert hasattr(counter, "n") + assert hasattr(counter, "increment") + assert hasattr(counter, "add") - assert not hasattr(counter, 'abc') + assert not hasattr(counter, "abc") -@pytest.mark.xfail(reason="Tornado can pass things out of order" + - "Should rely on sending small messages rather than rpc") +@pytest.mark.xfail( + reason="Tornado can pass things out of order" + + "Should rely on sending small messages rather than rpc" +) @gen_cluster(client=True) def test_linear_access(c, s, a, b): start = time() @@ -159,7 +160,7 @@ class Foo(object): x = 0 def __init__(self): - raise ValueError('bar') + raise ValueError("bar") with pytest.raises(ValueError) as info: future = yield c.submit(Foo, actor=True) @@ -250,11 +251,11 @@ def test_sync(client): assert future.result() == future.result() - assert 'ActorFuture' in repr(future) - assert 'distributed.actor' not in repr(future) + assert "ActorFuture" in repr(future) + 
assert "distributed.actor" not in repr(future) -@gen_cluster(client=True, config={'distributed.comm.timeouts.connect': '1s'}) +@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "1s"}) def test_failed_worker(c, s, a, b): future = c.submit(Counter, actor=True, workers=[a.address]) yield wait(future) @@ -280,21 +281,21 @@ def bench(c, s, a, b): @gen_cluster(client=True) def test_numpy_roundtrip(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") server = yield c.submit(ParameterServer, actor=True) x = np.random.random(1000) - yield server.put('x', x) + yield server.put("x", x) - y = yield server.get('x') + y = yield server.get("x") assert (x == y).all() @gen_cluster(client=True) def test_numpy_roundtrip_getattr(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") counter = yield c.submit(Counter, actor=True) @@ -311,10 +312,10 @@ def test_numpy_roundtrip_getattr(c, s, a, b): def test_repr(c, s, a, b): counter = yield c.submit(Counter, actor=True) - assert 'Counter' in repr(counter) - assert 'Actor' in repr(counter) + assert "Counter" in repr(counter) + assert "Actor" in repr(counter) assert counter.key in repr(counter) - assert 'distributed.actor' not in repr(counter) + assert "distributed.actor" not in repr(counter) @gen_cluster(client=True) @@ -324,7 +325,7 @@ def test_dir(c, s, a, b): d = set(dir(counter)) for attr in dir(Counter): - if not attr.startswith('_'): + if not attr.startswith("_"): assert attr in d @@ -346,7 +347,7 @@ def add(n, counter): yield done -@gen_cluster(client=True, ncores=[('127.0.0.1', 5)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 5)] * 2) def test_thread_safety(c, s, a, b): class Unsafe(object): def __init__(self): @@ -381,7 +382,7 @@ class Foo(object): def __init__(self, x): pass - b = c.submit(operator.mul, 'b', 1000000) + b = c.submit(operator.mul, "b", 1000000) yield wait(b) [ws] = s.tasks[b.key].who_has @@ -393,13 +394,13 @@ def __init__(self, x): assert s.tasks[x.key].who_has != s.tasks[y.key].who_has # second load balanced -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 5) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 5) def test_load_balance_map(c, s, *workers): class Foo(object): def __init__(self, x, y=None): pass - b = c.submit(operator.mul, 'b', 1000000) + b = c.submit(operator.mul, "b", 1000000) yield wait(b) actors = c.map(Foo, range(10), y=b, actor=True) @@ -408,10 +409,11 @@ def __init__(self, x, y=None): assert all(len(w.actors) == 2 for w in workers) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4, Worker=Nanny) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4, Worker=Nanny) def bench_param_server(c, s, *workers): import dask.array as da import numpy as np + x = da.random.random((500000, 1000), chunks=(1000, 1000)) x = x.persist() yield wait(x) @@ -439,6 +441,7 @@ def f(block, ps=None): return np.array([[stop - start]]) from distributed.utils import format_time + start = time() ps = yield c.submit(ParameterServer, x.shape[1], actor=True) y = x.map_blocks(f, ps=ps, dtype=x.dtype) @@ -448,10 +451,9 @@ def f(block, ps=None): print(format_time(end - start)) -@pytest.mark.xfail(reason='unknown') +@pytest.mark.xfail(reason="unknown") @gen_cluster(client=True) def test_compute(c, s, a, b): - @dask.delayed def f(n, counter): assert isinstance(counter, Actor) @@ -502,8 +504,11 @@ def check(dask_worker): assert time() < start + 2 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)], - 
config={'distributed.worker.profile.interval': '1ms'}) +@gen_cluster( + client=True, + ncores=[("127.0.0.1", 1)], + config={"distributed.worker.profile.interval": "1ms"}, +) def test_actors_in_profile(c, s, a): class Sleeper(object): def sleep(self, time): @@ -513,8 +518,10 @@ def sleep(self, time): for i in range(5): yield sleeper.sleep(0.200) - if (list(a.profile_recent['children'])[0].startswith('sleep') or - 'Sleeper.sleep' in a.profile_keys): + if ( + list(a.profile_recent["children"])[0].startswith("sleep") + or "Sleeper.sleep" in a.profile_keys + ): return assert False, list(a.profile_keys) diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index aa2bfaca765..8e66b58dd4e 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -168,8 +168,8 @@ def test_as_completed_error_async(c, s, a, b): result = {first, second} assert result == {x, y} - assert x.status == 'error' - assert y.status == 'finished' + assert x.status == "error" + assert y.status == "finished" def test_as_completed_error(client): @@ -180,8 +180,8 @@ def test_as_completed_error(client): result = set(ac) assert result == {x, y} - assert x.status == 'error' - assert y.status == 'finished' + assert x.status == "error" + assert y.status == "finished" def test_as_completed_with_results(client): @@ -193,7 +193,7 @@ def test_as_completed_with_results(client): y.cancel() with pytest.raises(RuntimeError) as exc: res = list(ac) - assert str(exc.value) == 'hello!' + assert str(exc.value) == "hello!" @gen_cluster(client=True) @@ -208,7 +208,7 @@ def test_as_completed_with_results_async(c, s, a, b): first = yield ac.__anext__() second = yield ac.__anext__() third = yield ac.__anext__() - assert str(exc.value) == 'hello!' + assert str(exc.value) == "hello!" 
def test_as_completed_with_results_no_raise(client): @@ -222,9 +222,9 @@ def test_as_completed_with_results_no_raise(client): dd = {r[0]: r[1:] for r in res} assert set(dd.keys()) == {y, x, z} - assert x.status == 'error' - assert y.status == 'cancelled' - assert z.status == 'finished' + assert x.status == "error" + assert y.status == "cancelled" + assert z.status == "finished" assert isinstance(dd[y][0], CancelledError) assert isinstance(dd[x][0][1], RuntimeError) @@ -246,9 +246,9 @@ def test_as_completed_with_results_no_raise_async(c, s, a, b): dd = {r[0]: r[1:] for r in res} assert set(dd.keys()) == {y, x, z} - assert x.status == 'error' - assert y.status == 'cancelled' - assert z.status == 'finished' + assert x.status == "error" + assert y.status == "cancelled" + assert z.status == "finished" assert isinstance(dd[y][0], CancelledError) assert isinstance(dd[x][0][1], RuntimeError) diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index a30f7654a20..1e7a5d2804f 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -116,6 +116,7 @@ def test_simple(): if wr1() is not None: # Help diagnosing from types import FrameType + p = wr1() if p is not None: rc = sys.getrefcount(p) @@ -124,8 +125,12 @@ def test_simple(): print("refs to proc:", rc, refs) frames = [r for r in refs if isinstance(r, FrameType)] for i, f in enumerate(frames): - print("frames #%d:" % i, - f.f_code.co_name, f.f_code.co_filename, sorted(f.f_locals)) + print( + "frames #%d:" % i, + f.f_code.co_name, + f.f_code.co_filename, + sorted(f.f_locals), + ) pytest.fail("AsyncProcess should have been destroyed") t1 = time() while wr2() is not None: @@ -139,7 +144,7 @@ def test_simple(): def test_exitcode(): q = mp_context.Queue() - proc = AsyncProcess(target=exit, kwargs={'q': q}) + proc = AsyncProcess(target=exit, kwargs={"q": q}) proc.daemon = True assert not proc.is_alive() assert proc.exitcode is None @@ -154,7 +159,7 @@ def test_exitcode(): assert proc.exitcode == 5 -@pytest.mark.skipif(os.name == 'nt', reason="POSIX only") +@pytest.mark.skipif(os.name == "nt", reason="POSIX only") @gen_test() def test_signal(): proc = AsyncProcess(target=exit_with_signal, args=(signal.SIGINT,)) @@ -274,11 +279,12 @@ def test_child_main_thread(): q._writer.close() -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="num_fds not supported on windows") +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="num_fds not supported on windows" +) @gen_test() def test_num_fds(): - psutil = pytest.importorskip('psutil') + psutil = pytest.importorskip("psutil") # Warm up proc = AsyncProcess(target=exit_now) @@ -324,7 +330,7 @@ def _worker_process(worker_ready, child_pipe): # The parent exiting should cause this process to os._exit from a monitor # thread. This sleep should never return. - shorter_timeout = 2.5 # timeout shorter than that in the spawning test. + shorter_timeout = 2.5 # timeout shorter than that in the spawning test. sleep(shorter_timeout) # Unreachable if functioning correctly. @@ -336,11 +342,11 @@ def _parent_process(child_pipe): The child_alive pipe is held open for as long as the child is alive, and can be used to determine if it exited correctly. 
""" + def parent_process_coroutine(): worker_ready = mp_context.Event() - worker = AsyncProcess(target=_worker_process, - args=(worker_ready, child_pipe)) + worker = AsyncProcess(target=_worker_process, args=(worker_ready, child_pipe)) yield worker.start() @@ -394,7 +400,7 @@ def test_asyncprocess_child_teardown_on_parent_exit(): # when the child is ready to enter the sleep, so all of the slow things # (process startup, etc) should have happened by now, even on a busy # system. A short timeout should therefore be appropriate. - short_timeout = 5. + short_timeout = 5.0 # Poll is used to allow other tests to proceed after this one in case of # test failure. try: @@ -402,7 +408,7 @@ def test_asyncprocess_child_teardown_on_parent_exit(): except EnvironmentError: # Windows can raise BrokenPipeError. EnvironmentError is caught for # Python2/3 portability. - assert sys.platform.startswith('win'), "should only raise on windows" + assert sys.platform.startswith("win"), "should only raise on windows" # Broken pipe implies closed, which is readable. readable = True @@ -415,11 +421,11 @@ def test_asyncprocess_child_teardown_on_parent_exit(): # This won't block due to the above 'assert readable'. result = children_alive.recv() except EOFError: - pass # Test passes. + pass # Test passes. except EnvironmentError: # Windows can raise BrokenPipeError. EnvironmentError is caught for # Python2/3 portability. - assert sys.platform.startswith('win'), "should only raise on windows" + assert sys.platform.startswith("win"), "should only raise on windows" # Test passes. else: # Oops, children_alive read something. It should be closed. If diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index 386de5957cb..2f22134f7ae 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -28,7 +28,7 @@ def handle_comm(self, comm): return def listen(self): - listener = listen('', self.handle_comm) + listener = listen("", self.handle_comm) listener.start() self.address = listener.contact_address self.stop = listener.stop @@ -57,17 +57,17 @@ def test_BatchedSend(): yield gen.sleep(0.020) - b.send('hello') - b.send('hello') - b.send('world') + b.send("hello") + b.send("hello") + b.send("world") yield gen.sleep(0.020) - b.send('HELLO') - b.send('HELLO') + b.send("HELLO") + b.send("HELLO") result = yield comm.read() - assert result == ('hello', 'hello', 'world') + assert result == ("hello", "hello", "world") result = yield comm.read() - assert result == ('HELLO', 'HELLO') + assert result == ("HELLO", "HELLO") assert b.byte_count > 1 @@ -79,12 +79,12 @@ def test_send_before_start(): b = BatchedSend(interval=10) - b.send('hello') - b.send('world') + b.send("hello") + b.send("world") b.start(comm) result = yield comm.read() - assert result == ('hello', 'world') + assert result == ("hello", "world") @gen_test() @@ -95,12 +95,12 @@ def test_send_after_stream_start(): b = BatchedSend(interval=10) b.start(comm) - b.send('hello') - b.send('world') + b.send("hello") + b.send("world") result = yield comm.read() if len(result) < 2: result += yield comm.read() - assert result == ('hello', 'world') + assert result == ("hello", "world") @gen_test() @@ -112,8 +112,8 @@ def test_send_before_close(): b.start(comm) cnt = int(e.count) - b.send('hello') - yield b.close() # close immediately after sending + b.send("hello") + yield b.close() # close immediately after sending assert not b.buffer start = time() @@ -122,7 +122,7 @@ def test_send_before_close(): assert time() < start + 5 with 
pytest.raises(CommClosedError): - b.send('123') + b.send("123") @gen_test() @@ -137,8 +137,8 @@ def test_close_closed(): comm.close() # external closing yield b.close() - assert 'closed' in repr(b) - assert 'closed' in str(b) + assert "closed" in repr(b) + assert "closed" in str(b) @gen_test() @@ -191,18 +191,19 @@ def recv(): @gen.coroutine def run_traffic_jam(nsends, nbytes): # This test eats `nsends * nbytes` bytes in RAM - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") from distributed.protocol import to_serialize - data = bytes(np.random.randint(0, 255, size=(nbytes,)).astype('u1').data) + + data = bytes(np.random.randint(0, 255, size=(nbytes,)).astype("u1").data) with echo_server() as e: comm = yield connect(e.address) b = BatchedSend(interval=0.01) b.start(comm) - msg = {'x': to_serialize(data)} + msg = {"x": to_serialize(data)} for i in range(nsends): - b.send(assoc(msg, 'i', i)) + b.send(assoc(msg, "i", i)) if np.random.random() > 0.5: yield gen.sleep(0.001) @@ -214,7 +215,7 @@ def run_traffic_jam(nsends, nbytes): # loses some of our messages L = yield gen.with_timeout(timedelta(seconds=5), comm.read()) count += 1 - results.extend(r['i'] for r in L) + results.extend(r["i"] for r in L) assert count == b.batch_count == e.count assert b.message_count == nsends @@ -241,25 +242,25 @@ def test_serializers(): with echo_server() as e: comm = yield connect(e.address) - b = BatchedSend(interval='10ms', serializers=['msgpack']) + b = BatchedSend(interval="10ms", serializers=["msgpack"]) b.start(comm) - b.send({'x': to_serialize(123)}) - b.send({'x': to_serialize('hello')}) + b.send({"x": to_serialize(123)}) + b.send({"x": to_serialize("hello")}) yield gen.sleep(0.100) - b.send({'x': to_serialize(lambda x: x + 1)}) + b.send({"x": to_serialize(lambda x: x + 1)}) - with captured_logger('distributed.protocol') as sio: + with captured_logger("distributed.protocol") as sio: yield gen.sleep(0.100) value = sio.getvalue() - assert 'serialize' in value - assert 'type' in value - assert 'function' in value + assert "serialize" in value + assert "type" in value + assert "function" in value msg = yield comm.read() - assert list(msg) == [{'x': 123}, {'x': 'hello'}] + assert list(msg) == [{"x": 123}, {"x": "hello"}] with pytest.raises(gen.TimeoutError): msg = yield gen.with_timeout(timedelta(milliseconds=100), comm.read()) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index f9594a672b9..634834bf671 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -20,8 +20,7 @@ import zipfile import pytest -from toolz import (identity, isdistinct, concat, pluck, valmap, - partial, first, merge) +from toolz import identity, isdistinct, concat, pluck, valmap, partial, first, merge from tornado import gen from tornado.ioloop import IOLoop @@ -29,29 +28,71 @@ from dask import delayed from dask.optimization import SubgraphCallable import dask.bag as db -from distributed import (Worker, Nanny, fire_and_forget, LocalCluster, - get_client, secede, get_worker, Executor, profile, - TimeoutError) +from distributed import ( + Worker, + Nanny, + fire_and_forget, + LocalCluster, + get_client, + secede, + get_worker, + Executor, + profile, + TimeoutError, +) from distributed.comm import CommClosedError -from distributed.client import (Client, Future, wait, as_completed, tokenize, - _get_global_client, default_client, - futures_of, - temp_default_client) +from distributed.client import ( + Client, + Future, + wait, + as_completed, + tokenize, 
+ _get_global_client, + default_client, + futures_of, + temp_default_client, +) from distributed.compatibility import PY3, Iterator from distributed.metrics import time from distributed.scheduler import Scheduler, KilledWorker from distributed.sizeof import sizeof -from distributed.utils import (ignoring, mp_context, sync, tmp_text, tokey, - tmpfile) -from distributed.utils_test import (cluster, slow, slowinc, slowadd, slowdec, - randominc, inc, dec, div, throws, geninc, asyncinc, - gen_cluster, gen_test, double, popen, - captured_logger, varying, map_varying, - wait_for, async_wait_for, pristine_loop) -from distributed.utils_test import (client as c, client_secondary as c2,# noqa F401 - cluster_fixture, loop, loop_in_thread,# noqa F401 - nodebug, s, a, b) # noqa F401 +from distributed.utils import ignoring, mp_context, sync, tmp_text, tokey, tmpfile +from distributed.utils_test import ( + cluster, + slow, + slowinc, + slowadd, + slowdec, + randominc, + inc, + dec, + div, + throws, + geninc, + asyncinc, + gen_cluster, + gen_test, + double, + popen, + captured_logger, + varying, + map_varying, + wait_for, + async_wait_for, + pristine_loop, +) +from distributed.utils_test import ( # noqa: F401 + client as c, + client_secondary as c2, + cluster_fixture, + loop, + loop_in_thread, + nodebug, + s, + a, + b, +) @gen_cluster(client=True, timeout=None) @@ -129,23 +170,25 @@ def test_map_empty(c, s, a, b): @gen_cluster(client=True) def test_map_keynames(c, s, a, b): - futures = c.map(inc, range(4), key='INC') - assert all(f.key.startswith('INC') for f in futures) + futures = c.map(inc, range(4), key="INC") + assert all(f.key.startswith("INC") for f in futures) assert isdistinct(f.key for f in futures) - futures2 = c.map(inc, [5, 6, 7, 8], key='INC') + futures2 = c.map(inc, [5, 6, 7, 8], key="INC") assert [f.key for f in futures] != [f.key for f in futures2] - keys = ['inc-1', 'inc-2', 'inc-3', 'inc-4'] + keys = ["inc-1", "inc-2", "inc-3", "inc-4"] futures = c.map(inc, range(4), key=keys) assert [f.key for f in futures] == keys @gen_cluster(client=True) def test_map_retries(c, s, a, b): - args = [[ZeroDivisionError("one"), 2, 3], - [4, 5, 6], - [ZeroDivisionError("seven"), ZeroDivisionError("eight"), 9]] + args = [ + [ZeroDivisionError("one"), 2, 3], + [4, 5, 6], + [ZeroDivisionError("seven"), ZeroDivisionError("eight"), 9], + ] x, y, z = c.map(*map_varying(args), retries=2) assert (yield x) == 2 @@ -287,7 +330,7 @@ def test_persist_retries(c, s, a, b): @gen_cluster(client=True) def test_retries_dask_array(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.ones((10, 10), chunks=(3, 3)) future = c.compute(x.sum(), retries=2) y = yield future @@ -305,7 +348,7 @@ def test_future_repr(c, s, a, b): @gen_cluster(client=True) def test_future_tuple_repr(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") y = da.arange(10, chunks=(5,)).persist() f = futures_of(y)[0] for func in [repr, lambda x: x._repr_html_()]: @@ -371,7 +414,7 @@ def test_Future_release_sync(c): def test_short_tracebacks(loop, c): - tblib = pytest.importorskip('tblib') + tblib = pytest.importorskip("tblib") future = c.submit(div, 1, 0) try: future.result() @@ -382,7 +425,7 @@ def test_short_tracebacks(loop, c): while tb is not None: n += 1 - tb = tb['tb_next'] + tb = tb["tb_next"] assert n < 5 @@ -435,7 +478,9 @@ def test_gc(s, a, b): yield x assert s.tasks[x.key].who_has x.__del__() - yield async_wait_for(lambda: x.key not in s.tasks or not 
s.tasks[x.key].who_has, timeout=0.3) + yield async_wait_for( + lambda: x.key not in s.tasks or not s.tasks[x.key].who_has, timeout=0.3 + ) yield c.close() @@ -474,8 +519,8 @@ def test_gather(c, s, a, b): assert result == 11 result = yield c.gather([x]) assert result == [11] - result = yield c.gather({'x': x, 'y': [y]}) - assert result == {'x': 11, 'y': [12]} + result = yield c.gather({"x": x, "y": [y]}) + assert result == {"x": 11, "y": [12]} @gen_cluster(client=True) @@ -498,7 +543,7 @@ def test_gather_sync(c): with pytest.raises(ZeroDivisionError): c.gather([x, y]) - [xx] = c.gather([x, y], errors='skip') + [xx] = c.gather([x, y], errors="skip") assert xx == 2 @@ -510,18 +555,18 @@ def test_gather_strict(c, s, a, b): with pytest.raises(ZeroDivisionError): yield c.gather([x, y]) - [xx] = yield c.gather([x, y], errors='skip') + [xx] = yield c.gather([x, y], errors="skip") assert xx == 2 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_gather_skip(c, s, a): x = c.submit(div, 1, 0, priority=10) y = c.submit(slowinc, 1, delay=0.5) - with captured_logger(logging.getLogger('distributed.scheduler')) as sched: - with captured_logger(logging.getLogger('distributed.client')) as client: - L = yield c.gather([x, y], errors='skip') + with captured_logger(logging.getLogger("distributed.scheduler")) as sched: + with captured_logger(logging.getLogger("distributed.client")) as client: + L = yield c.gather([x, y], errors="skip") assert L == [2] assert not client.getvalue() @@ -537,12 +582,12 @@ def test_limit_concurrent_gathering(c, s, a, b): @gen_cluster(client=True, timeout=None) def test_get(c, s, a, b): - future = c.get({'x': (inc, 1)}, 'x', sync=False) + future = c.get({"x": (inc, 1)}, "x", sync=False) assert isinstance(future, Future) result = yield future assert result == 2 - futures = c.get({'x': (inc, 1)}, ['x'], sync=False) + futures = c.get({"x": (inc, 1)}, ["x"], sync=False) assert isinstance(futures[0], Future) result = yield futures assert result == [2] @@ -550,22 +595,25 @@ def test_get(c, s, a, b): result = yield c.get({}, [], sync=False) assert result == [] - result = yield c.get({('x', 1): (inc, 1), ('x', 2): (inc, ('x', 1))}, - ('x', 2), sync=False) + result = yield c.get( + {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}, ("x", 2), sync=False + ) assert result == 3 def test_get_sync(c): - assert c.get({'x': (inc, 1)}, 'x') == 2 + assert c.get({"x": (inc, 1)}, "x") == 2 def test_no_future_references(c): from weakref import WeakSet + ws = WeakSet() futures = c.map(inc, range(10)) ws.update(futures) del futures import gc + gc.collect() start = time() while list(ws): @@ -606,7 +654,7 @@ def test_wait(c, s, a, b): assert done == {x, y, z} assert not_done == set() - assert x.status == y.status == 'finished' + assert x.status == y.status == "finished" @gen_cluster(client=True) @@ -615,13 +663,13 @@ def test_wait_first_completed(c, s, a, b): y = c.submit(slowinc, 1) z = c.submit(inc, 2) - done, not_done = yield wait([x, y, z], return_when='FIRST_COMPLETED') + done, not_done = yield wait([x, y, z], return_when="FIRST_COMPLETED") assert done == {z} assert not_done == {x, y} - assert z.status == 'finished' - assert x.status == 'pending' - assert y.status == 'pending' + assert z.status == "finished" + assert x.status == "pending" + assert y.status == "pending" @gen_cluster(client=True, timeout=2) @@ -638,7 +686,7 @@ def test_wait_sync(c): done, not_done = wait([x, y]) assert done == {x, y} assert not_done == set() - assert x.status 
== y.status == 'finished' + assert x.status == y.status == "finished" future = c.submit(sleep, 0.3) with pytest.raises(gen.TimeoutError): @@ -683,7 +731,7 @@ def test_garbage_collection(c, s, a, b): def test_garbage_collection_with_scatter(c, s, a, b): [future] = yield c.scatter([1]) assert future.key in c.futures - assert future.status == 'finished' + assert future.status == "finished" assert s.who_wants[future.key] == {c.id} key = future.key @@ -708,6 +756,7 @@ def test_recompute_released_key(c, s, a, b): xkey = x.key del x import gc + gc.collect() yield gen.moment assert c.refcount[xkey] == 0 @@ -726,6 +775,7 @@ def test_recompute_released_key(c, s, a, b): @gen_cluster(client=True) def test_long_tasks_dont_trigger_timeout(c, s, a, b): from time import sleep + x = c.submit(sleep, 3) yield x @@ -812,9 +862,10 @@ def test_tokenize_on_futures(c, s, a, b): assert tok == tokenize(y) -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([('127.0.0.1', 1), ('127.0.0.2', 2)], client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) def test_restrictions_submit(c, s, a, b): x = c.submit(inc, 1, workers={a.ip}) y = c.submit(inc, x, workers={b.ip}) @@ -840,9 +891,10 @@ def test_restrictions_ip_port(c, s, a, b): assert y.key in b.data -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([('127.0.0.1', 1), ('127.0.0.2', 2)], client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) def test_restrictions_map(c, s, a, b): L = c.map(inc, range(5), workers={a.ip}) yield wait(L) @@ -852,9 +904,7 @@ def test_restrictions_map(c, s, a, b): for x in L: assert s.host_restrictions[x.key] == {a.ip} - L = c.map(inc, [10, 11, 12], workers=[{a.ip}, - {a.ip, b.ip}, - {b.ip}]) + L = c.map(inc, [10, 11, 12], workers=[{a.ip}, {a.ip, b.ip}, {b.ip}]) yield wait(L) assert s.host_restrictions[L[0].key] == {a.ip} @@ -865,28 +915,29 @@ def test_restrictions_map(c, s, a, b): c.map(inc, [10, 11, 12], workers=[{a.ip}]) -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([('127.0.0.1', 1), ('127.0.0.2', 2)], client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) def test_restrictions_get(c, s, a, b): - dsk = {'x': 1, 'y': (inc, 'x'), 'z': (inc, 'y')} - restrictions = {'y': {a.ip}, 'z': {b.ip}} + dsk = {"x": 1, "y": (inc, "x"), "z": (inc, "y")} + restrictions = {"y": {a.ip}, "z": {b.ip}} - futures = c.get(dsk, ['y', 'z'], restrictions, sync=False) + futures = c.get(dsk, ["y", "z"], restrictions, sync=False) result = yield futures assert result == [2, 3] - assert 'y' in a.data - assert 'z' in b.data + assert "y" in a.data + assert "z" in b.data @gen_cluster(client=True) def dont_test_bad_restrictions_raise_exception(c, s, a, b): - z = c.submit(inc, 2, workers={'bad-address'}) + z = c.submit(inc, 2, workers={"bad-address"}) try: yield z assert False except ValueError as e: - assert 'bad-address' in str(e) + assert "bad-address" in str(e) assert z.key in str(e) @@ -903,15 +954,12 @@ def test_remove_worker(c, s, a, b): assert result == list(map(inc, range(20))) 
-@gen_cluster(ncores=[('127.0.0.1', 1)], client=True) +@gen_cluster(ncores=[("127.0.0.1", 1)], client=True) def test_errors_dont_block(c, s, w): - L = [c.submit(inc, 1), - c.submit(throws, 1), - c.submit(inc, 2), - c.submit(throws, 2)] + L = [c.submit(inc, 1), c.submit(throws, 1), c.submit(inc, 2), c.submit(throws, 2)] start = time() - while not (L[0].status == L[2].status == 'finished'): + while not (L[0].status == L[2].status == "finished"): assert time() < start + 5 yield gen.sleep(0.01) @@ -978,9 +1026,9 @@ def test_two_consecutive_clients_share_results(s, a, b): @gen_cluster(client=True) def test_submit_then_get_with_Future(c, s, a, b): x = c.submit(slowinc, 1) - dsk = {'y': (inc, x)} + dsk = {"y": (inc, x)} - result = yield c.get(dsk, 'y', sync=False) + result = yield c.get(dsk, "y", sync=False) assert result == 3 @@ -988,17 +1036,18 @@ def test_submit_then_get_with_Future(c, s, a, b): def test_aliases(c, s, a, b): x = c.submit(inc, 1) - dsk = {'y': x} - result = yield c.get(dsk, 'y', sync=False) + dsk = {"y": x} + result = yield c.get(dsk, "y", sync=False) assert result == 2 @gen_cluster(client=True) def test_aliases_2(c, s, a, b): dsk_keys = [ - ({'x': (inc, 1), 'y': 'x', 'z': 'x', 'w': (add, 'y', 'z')}, ['y', 'w']), - ({'x': 'y', 'y': 1}, ['x']), - ({'x': 1, 'y': 'x', 'z': 'y', 'w': (inc, 'z')}, ['w'])] + ({"x": (inc, 1), "y": "x", "z": "x", "w": (add, "y", "z")}, ["y", "w"]), + ({"x": "y", "y": 1}, ["x"]), + ({"x": 1, "y": "x", "z": "y", "w": (inc, "z")}, ["w"]), + ] for dsk, keys in dsk_keys: result = yield c.get(dsk, keys, sync=False) assert list(result) == list(dask.get(dsk, keys)) @@ -1007,14 +1056,13 @@ def test_aliases_2(c, s, a, b): @gen_cluster(client=True) def test__scatter(c, s, a, b): - d = yield c.scatter({'y': 20}) - assert isinstance(d['y'], Future) - assert a.data.get('y') == 20 or b.data.get('y') == 20 - y_who_has = s.get_who_has(keys=['y'])['y'] - assert (a.address in y_who_has or - b.address in y_who_has) - assert s.get_nbytes(summary=False) == {'y': sizeof(20)} - yy = yield c.gather([d['y']]) + d = yield c.scatter({"y": 20}) + assert isinstance(d["y"], Future) + assert a.data.get("y") == 20 or b.data.get("y") == 20 + y_who_has = s.get_who_has(keys=["y"])["y"] + assert a.address in y_who_has or b.address in y_who_has + assert s.get_nbytes(summary=False) == {"y": sizeof(20)} + yy = yield c.gather([d["y"]]) assert yy == [20] [x] = yield c.scatter([10]) @@ -1023,12 +1071,14 @@ def test__scatter(c, s, a, b): xx = yield c.gather([x]) x_who_has = s.get_who_has(keys=[x.key])[x.key] assert s.tasks[x.key].who_has - assert (s.workers[a.address] in s.tasks[x.key].who_has or - s.workers[b.address] in s.tasks[x.key].who_has) - assert s.get_nbytes(summary=False) == {'y': sizeof(20), x.key: sizeof(10)} + assert ( + s.workers[a.address] in s.tasks[x.key].who_has + or s.workers[b.address] in s.tasks[x.key].who_has + ) + assert s.get_nbytes(summary=False) == {"y": sizeof(20), x.key: sizeof(10)} assert xx == [10] - z = c.submit(add, x, d['y']) # submit works on Future + z = c.submit(add, x, d["y"]) # submit works on Future result = yield z assert result == 10 + 20 result = yield c.gather([z, x]) @@ -1037,9 +1087,9 @@ def test__scatter(c, s, a, b): @gen_cluster(client=True) def test__scatter_types(c, s, a, b): - d = yield c.scatter({'x': 1}) + d = yield c.scatter({"x": 1}) assert isinstance(d, dict) - assert list(d) == ['x'] + assert list(d) == ["x"] for seq in [[1], (1,), {1}, frozenset([1])]: L = yield c.scatter(seq) @@ -1082,7 +1132,7 @@ class MyObj(object): 
@normalize_token.register(MyObj) def f(x): L.append(x) - return 'x' + return "x" obj = MyObj() @@ -1092,9 +1142,9 @@ def f(x): @gen_cluster(client=True) def test_scatter_singletons(c, s, a, b): - np = pytest.importorskip('numpy') - pd = pytest.importorskip('pandas') - for x in [1, np.ones(5), pd.DataFrame({'x': [1, 2, 3]})]: + np = pytest.importorskip("numpy") + pd = pytest.importorskip("pandas") + for x in [1, np.ones(5), pd.DataFrame({"x": [1, 2, 3]})]: future = yield c.scatter(x) result = yield future assert str(result) == str(x) @@ -1103,7 +1153,7 @@ def test_scatter_singletons(c, s, a, b): @gen_cluster(client=True) def test_scatter_typename(c, s, a, b): future = yield c.scatter(123) - assert future.key.startswith('int') + assert future.key.startswith("int") @gen_cluster(client=True) @@ -1118,22 +1168,23 @@ def test_scatter_hash(c, s, a, b): @gen_cluster(client=True) def test_get_releases_data(c, s, a, b): - [x] = yield c.get({'x': (inc, 1)}, ['x'], sync=False) + [x] = yield c.get({"x": (inc, 1)}, ["x"], sync=False) import gc + gc.collect() start = time() - while c.refcount['x']: + while c.refcount["x"]: yield gen.sleep(0.01) assert time() < start + 2 def test_Current(s, a, b): - with Client(s['address']) as c: + with Client(s["address"]) as c: assert Client.current() is c with pytest.raises(ValueError): Client.current() - with Client(s['address']) as c: + with Client(s["address"]) as c: assert Client.current() is c @@ -1142,10 +1193,10 @@ def test_global_clients(loop): with pytest.raises(ValueError): default_client() with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: assert _get_global_client() is c assert default_client() is c - with Client(s['address'], loop=loop) as f: + with Client(s["address"], loop=loop) as f: assert _get_global_client() is f assert default_client() is f assert default_client(c) is c @@ -1176,13 +1227,13 @@ def test_get_nbytes(c, s, a, b): y = c.submit(inc, x) yield y - assert s.get_nbytes(summary=False) == {x.key: sizeof(1), - y.key: sizeof(2)} + assert s.get_nbytes(summary=False) == {x.key: sizeof(1), y.key: sizeof(2)} -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([('127.0.0.1', 1), ('127.0.0.2', 2)], client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) def test_nbytes_determines_worker(c, s, a, b): x = c.submit(identity, 1, workers=[a.ip]) y = c.submit(identity, tuple(range(100)), workers=[b.ip]) @@ -1207,29 +1258,35 @@ def test_if_intermediates_clear_on_error(c, s, a, b): @gen_cluster(client=True) def test_pragmatic_move_small_data_to_large_data(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") lists = c.map(np.ones, [10000] * 10, pure=False) sums = c.map(np.sum, lists) total = c.submit(sum, sums) def f(x, y): return None - s.task_duration['f'] = 0.001 + + s.task_duration["f"] = 0.001 results = c.map(f, lists, [total] * 10) yield wait([total]) yield wait(results) - assert sum(s.tasks[r.key].who_has.issubset(s.tasks[l.key].who_has) - for l, r in zip(lists, results)) >= 9 + assert ( + sum( + s.tasks[r.key].who_has.issubset(s.tasks[l.key].who_has) + for l, r in zip(lists, results) + ) + >= 9 + ) @gen_cluster(client=True) def test_get_with_non_list_key(c, s, a, b): - dsk = {('x', 0): (inc, 1), 5: (inc, 2)} + dsk = {("x", 0): (inc, 1), 5: (inc, 2)} - x = yield 
c.get(dsk, ('x', 0), sync=False) + x = yield c.get(dsk, ("x", 0), sync=False) y = yield c.get(dsk, 5, sync=False) assert x == 2 assert y == 3 @@ -1237,15 +1294,15 @@ def test_get_with_non_list_key(c, s, a, b): @gen_cluster(client=True) def test_get_with_error(c, s, a, b): - dsk = {'x': (div, 1, 0), 'y': (inc, 'x')} + dsk = {"x": (div, 1, 0), "y": (inc, "x")} with pytest.raises(ZeroDivisionError): - yield c.get(dsk, 'y', sync=False) + yield c.get(dsk, "y", sync=False) def test_get_with_error_sync(c): - dsk = {'x': (div, 1, 0), 'y': (inc, 'x')} + dsk = {"x": (div, 1, 0), "y": (inc, "x")} with pytest.raises(ZeroDivisionError): - c.get(dsk, 'y') + c.get(dsk, "y") @gen_cluster(client=True) @@ -1259,10 +1316,10 @@ def test_directed_scatter(c, s, a, b): def test_directed_scatter_sync(c, s, a, b, loop): - futures = c.scatter([1, 2, 3], workers=[b['address']]) + futures = c.scatter([1, 2, 3], workers=[b["address"]]) has_what = sync(loop, c.scheduler.has_what) - assert len(has_what[b['address']]) == len(futures) - assert len(has_what[a['address']]) == 0 + assert len(has_what[b["address"]]) == len(futures) + assert len(has_what[a["address"]]) == 0 def test_iterator_scatter(c): @@ -1283,6 +1340,7 @@ def test_iterator_scatter(c): def test_queue_scatter(c): from distributed.compatibility import Queue + q = Queue() for d in range(10): q.put(d) @@ -1295,6 +1353,7 @@ def test_queue_scatter(c): def test_queue_scatter_gather_maxsize(c): from distributed.compatibility import Queue + q = Queue(maxsize=3) out = c.scatter(q, maxsize=10) assert out.maxsize == 10 @@ -1314,6 +1373,7 @@ def test_queue_scatter_gather_maxsize(c): def test_queue_gather(c): from distributed.compatibility import Queue + q = Queue() qin = list(range(10)) @@ -1346,7 +1406,7 @@ def test_iterator_gather(c, c2): i_out = list(ff) assert i_out == i_in - i_in = ['a', 'b', 'c', StopIteration('f'), StopIteration, 'd', 'c'] + i_in = ["a", "b", "c", StopIteration("f"), StopIteration, "d", "c"] g = (d for d in i_in) futures = c.scatter(g) @@ -1365,26 +1425,26 @@ def test_scatter_direct(c, s, a, b): future = yield c.scatter(123, direct=True) assert future.key in a.data or future.key in b.data assert s.tasks[future.key].who_has - assert future.status == 'finished' + assert future.status == "finished" result = yield future assert result == 123 - assert not s.counters['op'].components[0]['scatter'] + assert not s.counters["op"].components[0]["scatter"] result = yield future - assert not s.counters['op'].components[0]['gather'] + assert not s.counters["op"].components[0]["gather"] result = yield c.gather(future) - assert not s.counters['op'].components[0]['gather'] + assert not s.counters["op"].components[0]["gather"] @gen_cluster(client=True) def test_scatter_direct_numpy(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = np.ones(5) future = yield c.scatter(x, direct=True) result = yield future assert np.allclose(x, result) - assert not s.counters['op'].components[0]['scatter'] + assert not s.counters["op"].components[0]["scatter"] @gen_cluster(client=True) @@ -1392,31 +1452,35 @@ def test_scatter_direct_broadcast(c, s, a, b): future2 = yield c.scatter(456, direct=True, broadcast=True) assert future2.key in a.data assert future2.key in b.data - assert s.tasks[future2.key].who_has == {s.workers[a.address], - s.workers[b.address]} + assert s.tasks[future2.key].who_has == {s.workers[a.address], s.workers[b.address]} result = yield future2 assert result == 456 - assert not s.counters['op'].components[0]['scatter'] + 
assert not s.counters["op"].components[0]["scatter"] -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_scatter_direct_balanced(c, s, *workers): futures = yield c.scatter([1, 2, 3], direct=True) assert sorted([len(w.data) for w in workers]) == [0, 1, 1, 1] -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_scatter_direct_broadcast_target(c, s, *workers): - futures = yield c.scatter([123, 456], direct=True, - workers=workers[0].address) + futures = yield c.scatter([123, 456], direct=True, workers=workers[0].address) assert futures[0].key in workers[0].data assert futures[1].key in workers[0].data - futures = yield c.scatter([123, 456], direct=True, broadcast=True, - workers=[w.address for w in workers[:3]]) - assert (f.key in w.data and w.address in s.tasks[f.key].who_has - for f in futures - for w in workers[:3]) + futures = yield c.scatter( + [123, 456], + direct=True, + broadcast=True, + workers=[w.address for w in workers[:3]], + ) + assert ( + f.key in w.data and w.address in s.tasks[f.key].who_has + for f in futures + for w in workers[:3] + ) @gen_cluster(client=True, ncores=[]) @@ -1425,7 +1489,7 @@ def test_scatter_direct_empty(c, s): yield c.scatter(123, direct=True, timeout=0.1) -@gen_cluster(client=True, timeout=None, ncores=[('127.0.0.1', 1)] * 5) +@gen_cluster(client=True, timeout=None, ncores=[("127.0.0.1", 1)] * 5) def test_scatter_direct_spread_evenly(c, s, *workers): futures = [] for i in range(10): @@ -1435,8 +1499,8 @@ def test_scatter_direct_spread_evenly(c, s, *workers): assert all(w.data for w in workers) -@pytest.mark.parametrize('direct', [True, False]) -@pytest.mark.parametrize('broadcast', [True, False]) +@pytest.mark.parametrize("direct", [True, False]) +@pytest.mark.parametrize("broadcast", [True, False]) def test_scatter_gather_sync(c, direct, broadcast): futures = c.scatter([1, 2, 3], direct=direct, broadcast=broadcast) results = c.gather(futures, direct=direct) @@ -1467,18 +1531,17 @@ def test_traceback(c, s, a, b): tb = yield x.traceback() if sys.version_info[0] >= 3: - assert any('x / y' in line - for line in pluck(3, traceback.extract_tb(tb))) + assert any("x / y" in line for line in pluck(3, traceback.extract_tb(tb))) @gen_cluster(client=True) def test_get_traceback(c, s, a, b): try: - yield c.get({'x': (div, 1, 0)}, 'x', sync=False) + yield c.get({"x": (div, 1, 0)}, "x", sync=False) except ZeroDivisionError: exc_type, exc_value, exc_traceback = sys.exc_info() L = traceback.format_tb(exc_traceback) - assert any('x / y' in line for line in L) + assert any("x / y" in line for line in L) @gen_cluster(client=True) @@ -1489,22 +1552,25 @@ def test_gather_traceback(c, s, a, b): except ZeroDivisionError: exc_type, exc_value, exc_traceback = sys.exc_info() L = traceback.format_tb(exc_traceback) - assert any('x / y' in line for line in L) + assert any("x / y" in line for line in L) def test_traceback_sync(c): x = c.submit(div, 1, 0) tb = x.traceback() if sys.version_info[0] >= 3: - assert any('x / y' in line - for line in concat(traceback.extract_tb(tb)) - if isinstance(line, str)) + assert any( + "x / y" in line + for line in concat(traceback.extract_tb(tb)) + if isinstance(line, str) + ) y = c.submit(inc, x) tb2 = y.traceback() assert set(pluck(3, traceback.extract_tb(tb2))).issuperset( - set(pluck(3, traceback.extract_tb(tb)))) + set(pluck(3, traceback.extract_tb(tb))) + ) z = c.submit(div, 1, 2) tb = z.traceback() 
@@ -1515,11 +1581,12 @@ def test_traceback_sync(c): def test_upload_file(c, s, a, b): def g(): import myfile + return myfile.f() try: for value in [123, 456]: - with tmp_text('myfile.py', 'def f():\n return {}'.format(value)) as fn: + with tmp_text("myfile.py", "def f():\n return {}".format(value)) as fn: yield c.upload_file(fn) x = c.submit(g, pure=False) @@ -1527,13 +1594,13 @@ def g(): assert result == value finally: # Ensure that this test won't impact the others - if 'myfile' in sys.modules: - del sys.modules['myfile'] + if "myfile" in sys.modules: + del sys.modules["myfile"] @gen_cluster(client=True) def test_upload_file_no_extension(c, s, a, b): - with tmp_text('myfile', '') as fn: + with tmp_text("myfile", "") as fn: yield c.upload_file(fn) @@ -1541,26 +1608,29 @@ def test_upload_file_no_extension(c, s, a, b): def test_upload_file_zip(c, s, a, b): def g(): import myfile + return myfile.f() try: for value in [123, 456]: - with tmp_text('myfile.py', 'def f():\n return {}'.format(value)) as fn_my_file: - with zipfile.ZipFile('myfile.zip', 'w') as z: + with tmp_text( + "myfile.py", "def f():\n return {}".format(value) + ) as fn_my_file: + with zipfile.ZipFile("myfile.zip", "w") as z: z.write(fn_my_file, arcname=os.path.basename(fn_my_file)) - yield c.upload_file('myfile.zip') + yield c.upload_file("myfile.zip") x = c.submit(g, pure=False) result = yield x assert result == value finally: # Ensure that this test won't impact the others - if os.path.exists('myfile.zip'): - os.remove('myfile.zip') - if 'myfile' in sys.modules: - del sys.modules['myfile'] + if os.path.exists("myfile.zip"): + os.remove("myfile.zip") + if "myfile" in sys.modules: + del sys.modules["myfile"] for path in sys.path: - if os.path.basename(path) == 'myfile.zip': + if os.path.basename(path) == "myfile.zip": sys.path.remove(path) break @@ -1569,26 +1639,27 @@ def g(): def test_upload_large_file(c, s, a, b): assert a.local_dir assert b.local_dir - with tmp_text('myfile', 'abc') as fn: - with tmp_text('myfile2', 'def') as fn2: - yield c._upload_large_file(fn, remote_filename='x') + with tmp_text("myfile", "abc") as fn: + with tmp_text("myfile2", "def") as fn2: + yield c._upload_large_file(fn, remote_filename="x") yield c._upload_large_file(fn2) for w in [a, b]: - assert os.path.exists(os.path.join(w.local_dir, 'x')) - assert os.path.exists(os.path.join(w.local_dir, 'myfile2')) - with open(os.path.join(w.local_dir, 'x')) as f: - assert f.read() == 'abc' - with open(os.path.join(w.local_dir, 'myfile2')) as f: - assert f.read() == 'def' + assert os.path.exists(os.path.join(w.local_dir, "x")) + assert os.path.exists(os.path.join(w.local_dir, "myfile2")) + with open(os.path.join(w.local_dir, "x")) as f: + assert f.read() == "abc" + with open(os.path.join(w.local_dir, "myfile2")) as f: + assert f.read() == "def" def test_upload_file_sync(c): def g(): import myfile + return myfile.x - with tmp_text('myfile.py', 'x = 123') as fn: + with tmp_text("myfile.py", "x = 123") as fn: c.upload_file(fn) x = c.submit(g) assert x.result() == 123 @@ -1596,13 +1667,13 @@ def g(): @gen_cluster(client=True) def test_upload_file_exception(c, s, a, b): - with tmp_text('myfile.py', 'syntax-error!') as fn: + with tmp_text("myfile.py", "syntax-error!") as fn: with pytest.raises(SyntaxError): yield c.upload_file(fn) def test_upload_file_exception_sync(c): - with tmp_text('myfile.py', 'syntax-error!') as fn: + with tmp_text("myfile.py", "syntax-error!") as fn: with pytest.raises(SyntaxError): c.upload_file(fn) @@ -1633,6 +1704,7 @@ def 
test_multiple_clients(s, a, b): @gen_cluster(client=True) def test_async_compute(c, s, a, b): from dask.delayed import delayed + x = delayed(1) y = delayed(inc)(x) z = delayed(dec)(x) @@ -1651,10 +1723,11 @@ def test_async_compute(c, s, a, b): @gen_cluster(client=True) def test_async_compute_with_scatter(c, s, a, b): - d = yield c.scatter({('x', 1): 1, ('y', 1): 2}) - x, y = d[('x', 1)], d[('y', 1)] + d = yield c.scatter({("x", 1): 1, ("y", 1): 2}) + x, y = d[("x", 1)], d[("y", 1)] from dask.delayed import delayed + z = delayed(add)(delayed(inc)(x), delayed(inc)(y)) zz = c.compute(z) @@ -1716,13 +1789,14 @@ def test_client_with_scheduler(c, s, a, b): AA, BB, xx = yield c.gather([A, B, x]) assert (AA, BB, xx) == (1, 2, 2) - result = yield c.get({'x': (inc, 1), 'y': (add, 'x', 10)}, 'y', sync=False) + result = yield c.get({"x": (inc, 1), "y": (add, "x", 10)}, "y", sync=False) assert result == 12 -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([('127.0.0.1', 1), ('127.0.0.2', 2)], client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) def test_allow_restrictions(c, s, a, b): aws = s.workers[a.address] bws = s.workers[a.address] @@ -1742,13 +1816,13 @@ def test_allow_restrictions(c, s, a, b): assert all(s.tasks[f.key].who_has == {aws} for f in L) assert {f.key for f in L}.issubset(s.loose_restrictions) - x = c.submit(inc, 15, workers='127.0.0.3', allow_other_workers=True) + x = c.submit(inc, 15, workers="127.0.0.3", allow_other_workers=True) yield x assert s.tasks[x.key].who_has assert x.key in s.loose_restrictions - L = c.map(inc, range(15, 25), workers='127.0.0.3', allow_other_workers=True) + L = c.map(inc, range(15, 25), workers="127.0.0.3", allow_other_workers=True) yield wait(L) assert all(s.tasks[f.key].who_has for f in L) assert {f.key for f in L}.issubset(s.loose_restrictions) @@ -1760,21 +1834,21 @@ def test_allow_restrictions(c, s, a, b): c.map(inc, [1], allow_other_workers=True) with pytest.raises(TypeError): - c.submit(inc, 20, workers='127.0.0.1', allow_other_workers='Hello!') + c.submit(inc, 20, workers="127.0.0.1", allow_other_workers="Hello!") with pytest.raises(TypeError): - c.map(inc, [20], workers='127.0.0.1', allow_other_workers='Hello!') + c.map(inc, [20], workers="127.0.0.1", allow_other_workers="Hello!") -@pytest.mark.skipif('True', reason='because') +@pytest.mark.skipif("True", reason="because") def test_bad_address(): try: - Client('123.123.123.123:1234', timeout=0.1) + Client("123.123.123.123:1234", timeout=0.1) except (IOError, gen.TimeoutError) as e: assert "connect" in str(e).lower() try: - Client('127.0.0.1:1234', timeout=0.1) + Client("127.0.0.1:1234", timeout=0.1) except (IOError, gen.TimeoutError) as e: assert "connect" in str(e).lower() @@ -1782,7 +1856,7 @@ def test_bad_address(): @gen_cluster(client=True) def test_long_error(c, s, a, b): def bad(x): - raise ValueError('a' * 100000) + raise ValueError("a" * 100000) x = c.submit(bad, 10) @@ -1792,9 +1866,11 @@ def bad(x): assert len(str(e)) < 100000 tb = yield x.traceback() - assert all(len(line) < 100000 - for line in concat(traceback.extract_tb(tb)) - if isinstance(line, str)) + assert all( + len(line) < 100000 + for line in concat(traceback.extract_tb(tb)) + if isinstance(line, str) + ) @gen_cluster(client=True) @@ -1828,6 +1904,7 @@ def __getstate__(self): def __setstate__(self, state): print("This should 
never have been deserialized, closing") import sys + sys.exit(0) @@ -1840,10 +1917,10 @@ def test_badly_serialized_input(c, s, a, b): L = yield c.gather(futures) assert list(L) == list(map(inc, range(10))) - assert future.status == 'error' + assert future.status == "error" -@pytest.mark.skipif('True', reason="") +@pytest.mark.skipif("True", reason="") def test_badly_serialized_input_stderr(capsys, c): o = BadlySerializedObject() future = c.submit(inc, o) @@ -1852,24 +1929,24 @@ def test_badly_serialized_input_stderr(capsys, c): while True: sleep(0.01) out, err = capsys.readouterr() - if 'hello!' in err: + if "hello!" in err: break assert time() - start < 20 - assert future.status == 'error' + assert future.status == "error" def test_repr(loop): funcs = [str, repr, lambda x: x._repr_html_()] with cluster(nworkers=3) as (s, [a, b, c]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: for func in funcs: text = func(c) assert c.scheduler.address in text - assert '3' in text + assert "3" in text for func in funcs: text = func(c) - assert 'not connected' in text + assert "not connected" in text @gen_cluster(client=True) @@ -1879,8 +1956,9 @@ def test_repr_async(c, s, a, b): @gen_test() def test_repr_localcluster(): - cluster = yield LocalCluster(processes=False, dashboard_address=None, - asynchronous=True) + cluster = yield LocalCluster( + processes=False, dashboard_address=None, asynchronous=True + ) client = yield Client(cluster, asynchronous=True) try: text = client._repr_html_() @@ -2008,7 +2086,7 @@ def test_repr_sync(c): assert c.scheduler.address in s assert c.scheduler.address in r assert str(2) in s # nworkers - assert 'cores' in s + assert "cores" in s @gen_cluster(client=True) @@ -2038,9 +2116,11 @@ def test_multi_client(s, a, b): yield wait([x, y]) - assert s.wants_what == {c.id: {x.key, y.key}, - f.id: {y.key}, - 'fire-and-forget': set()} + assert s.wants_what == { + c.id: {x.key, y.key}, + f.id: {y.key}, + "fire-and-forget": set(), + } assert s.who_wants == {x.key: {c.id}, y.key: {c.id, f.id}} yield c.close() @@ -2072,8 +2152,7 @@ def long_running_client_connection(address): @gen_cluster() def test_cleanup_after_broken_client_connection(s, a, b): - proc = mp_context.Process(target=long_running_client_connection, - args=(s.address,)) + proc = mp_context.Process(target=long_running_client_connection, args=(s.address,)) proc.daemon = True proc.start() @@ -2110,9 +2189,7 @@ def test_multi_garbage_collection(s, a, b): yield gen.sleep(0.01) assert time() < start + 5 - assert s.wants_what == {c.id: {y.key}, - f.id: {y.key}, - 'fire-and-forget': set()} + assert s.wants_what == {c.id: {y.key}, f.id: {y.key}, "fire-and-forget": set()} assert s.who_wants == {y.key: {c.id, f.id}} y.__del__() @@ -2123,9 +2200,7 @@ def test_multi_garbage_collection(s, a, b): yield gen.sleep(0.1) assert y.key in a.data or y.key in b.data - assert s.wants_what == {c.id: {y.key}, - f.id: set(), - 'fire-and-forget': set()} + assert s.wants_what == {c.id: {y.key}, f.id: set(), "fire-and-forget": set()} assert s.who_wants == {y.key: {c.id}} y2.__del__() @@ -2147,7 +2222,7 @@ def test__broadcast(c, s, a, b): assert a.data == b.data == {x.key: 1, y.key: 2} -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test__broadcast_integer(c, s, *workers): x, y = yield c.scatter([1, 2], broadcast=2) assert len(s.tasks[x.key].who_has) == 2 @@ -2156,8 +2231,8 @@ def test__broadcast_integer(c, s, *workers): 
@gen_cluster(client=True) def test__broadcast_dict(c, s, a, b): - d = yield c.scatter({'x': 1}, broadcast=True) - assert a.data == b.data == {'x': 1} + d = yield c.scatter({"x": 1}, broadcast=True) + assert a.data == b.data == {"x": 1} def test_broadcast(c, s, a, b): @@ -2166,21 +2241,23 @@ def test_broadcast(c, s, a, b): has_what = sync(c.loop, c.scheduler.has_what) assert {k: set(v) for k, v in has_what.items()} == { - a['address']: {x.key, y.key}, - b['address']: {x.key, y.key}} + a["address"]: {x.key, y.key}, + b["address"]: {x.key, y.key}, + } - [z] = c.scatter([3], broadcast=True, workers=[a['address']]) + [z] = c.scatter([3], broadcast=True, workers=[a["address"]]) has_what = sync(c.loop, c.scheduler.has_what) assert {k: set(v) for k, v in has_what.items()} == { - a['address']: {x.key, y.key, z.key}, - b['address']: {x.key, y.key}} + a["address"]: {x.key, y.key, z.key}, + b["address"]: {x.key, y.key}, + } @gen_cluster(client=True) def test_proxy(c, s, a, b): - msg = yield c.scheduler.proxy(msg={'op': 'identity'}, worker=a.address) - assert msg['id'] == a.identity()['id'] + msg = yield c.scheduler.proxy(msg={"op": "identity"}, worker=a.address) + assert msg["id"] == a.identity()["id"] @gen_cluster(client=True) @@ -2194,7 +2271,7 @@ def test__cancel(c, s, a, b): yield c.cancel([x]) assert x.cancelled() - assert 'cancel' in str(x) + assert "cancel" in str(x) s.validate_state() start = time() @@ -2208,7 +2285,7 @@ def test__cancel(c, s, a, b): @gen_cluster(client=True) def test__cancel_tuple_key(c, s, a, b): - x = c.submit(inc, 1, key=('x', 0, 1)) + x = c.submit(inc, 1, key=("x", 0, 1)) result = yield x yield c.cancel(x) @@ -2249,7 +2326,7 @@ def test__cancel_multi_client(s, a, b): @gen_cluster(client=True) def test__cancel_collection(c, s, a, b): L = c.map(double, [[1], [2], [3]]) - x = db.Bag({('b', i): f for i, f in enumerate(L)}, 'b', 3) + x = db.Bag({("b", i): f for i, f in enumerate(L)}, "b", 3) yield c.cancel(x) yield c.cancel([x]) @@ -2258,9 +2335,9 @@ def test__cancel_collection(c, s, a, b): def test_cancel(c): - x = c.submit(slowinc, 1, key='x') - y = c.submit(slowinc, x, key='y') - z = c.submit(slowinc, y, key='z') + x = c.submit(slowinc, 1, key="x") + y = c.submit(slowinc, x, key="y") + z = c.submit(slowinc, y, key="z") c.cancel([y]) @@ -2280,7 +2357,7 @@ def test_future_type(c, s, a, b): x = c.submit(inc, 1) yield wait([x]) assert x.type == int - assert 'int' in str(x) + assert "int" in str(x) @gen_cluster(client=True) @@ -2292,14 +2369,15 @@ def test_traceback_clean(c, s, a, b): f = e exc_type, exc_value, tb = sys.exc_info() while tb: - assert 'scheduler' not in tb.tb_frame.f_code.co_filename - assert 'worker' not in tb.tb_frame.f_code.co_filename + assert "scheduler" not in tb.tb_frame.f_code.co_filename + assert "worker" not in tb.tb_frame.f_code.co_filename tb = tb.tb_next @gen_cluster(client=True) def test_map_queue(c, s, a, b): from distributed.compatibility import Queue, isqueue + q_1 = Queue(maxsize=2) q_2 = c.map(inc, q_1) assert isqueue(q_2) @@ -2318,14 +2396,16 @@ def test_map_queue(c, s, a, b): assert result == (1 + 1) * 2 -@pytest.mark.skipif(sys.version_info >= (3, 7), - reason="replace StopIteration with return") +@pytest.mark.skipif( + sys.version_info >= (3, 7), reason="replace StopIteration with return" +) @gen_cluster(client=True) def test_map_iterator_with_return(c, s, a, b): def g(): yield 1 yield 2 raise StopIteration(3) # py2.7 compat. 
+ f1 = c.map(lambda x: x, g()) assert isinstance(f1, Iterator) @@ -2404,11 +2484,11 @@ def test_map_differnet_lengths(c, s, a, b): def test_Future_exception_sync_2(loop, capsys): with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: assert dask.base.get_scheduler() == c.get out, err = capsys.readouterr() - assert len(out.strip().split('\n')) == 1 + assert len(out.strip().split("\n")) == 1 assert dask.base.get_scheduler() != c.get @@ -2416,6 +2496,7 @@ def test_Future_exception_sync_2(loop, capsys): @gen_cluster(timeout=60, client=True) def test_async_persist(c, s, a, b): from dask.delayed import delayed, Delayed + x = delayed(1) y = delayed(inc)(x) z = delayed(dec)(x) @@ -2447,7 +2528,7 @@ def test_async_persist(c, s, a, b): @gen_cluster(client=True) def test__persist(c, s, a, b): - pytest.importorskip('dask.array') + pytest.importorskip("dask.array") import dask.array as da x = da.ones((10, 10), chunks=(5, 10)) @@ -2467,8 +2548,9 @@ def test__persist(c, s, a, b): def test_persist(c): - pytest.importorskip('dask.array') + pytest.importorskip("dask.array") import dask.array as da + x = da.ones((10, 10), chunks=(5, 10)) y = 2 * (x + 1) assert len(y.dask) == 6 @@ -2502,7 +2584,7 @@ def deep(n): @gen_cluster(client=True) def test_wait_on_collections(c, s, a, b): L = c.map(double, [[1], [2], [3]]) - x = db.Bag({('b', i): f for i, f in enumerate(L)}, 'b', 3) + x = db.Bag({("b", i): f for i, f in enumerate(L)}, "b", 3) yield wait(x) assert all(f.key in a.data or f.key in b.data for f in L) @@ -2516,19 +2598,21 @@ def test_futures_of_get(c, s, a, b): assert set(futures_of(x)) == {x} assert set(futures_of([x, y, z])) == {x, y, z} assert set(futures_of([x, [y], [[z]]])) == {x, y, z} - assert set(futures_of({'x': x, 'y': [y]})) == {x, y} + assert set(futures_of({"x": x, "y": [y]})) == {x, y} - b = db.Bag({('b', i): f for i, f in enumerate([x, y, z])}, 'b', 3) + b = db.Bag({("b", i): f for i, f in enumerate([x, y, z])}, "b", 3) assert set(futures_of(b)) == {x, y, z} - sg = SubgraphCallable({'x': x, 'y': y, 'z': z, - 'out': (add, (add, (add, x, y), z), 'in')}, - 'out', ('in',)) + sg = SubgraphCallable( + {"x": x, "y": y, "z": z, "out": (add, (add, (add, x, y), z), "in")}, + "out", + ("in",), + ) assert set(futures_of(sg)) == {x, y, z} def test_futures_of_class(): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") assert futures_of([da.Array]) == [] @@ -2541,7 +2625,7 @@ def test_futures_of_cancelled_raises(c, s, a, b): yield x with pytest.raises(CancelledError): - yield c.get({'x': (inc, x), 'y': (inc, 2)}, ['x', 'y'], sync=False) + yield c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False) with pytest.raises(CancelledError): c.submit(inc, x) @@ -2552,20 +2636,20 @@ def test_futures_of_cancelled_raises(c, s, a, b): with pytest.raises(CancelledError): c.map(add, [1], y=x) - assert 'y' not in s.tasks + assert "y" not in s.tasks @pytest.mark.skip -@gen_cluster(ncores=[('127.0.0.1', 1)], client=True) +@gen_cluster(ncores=[("127.0.0.1", 1)], client=True) def test_dont_delete_recomputed_results(c, s, w): - x = c.submit(inc, 1) # compute first time + x = c.submit(inc, 1) # compute first time yield wait([x]) - x.__del__() # trigger garbage collection + x.__del__() # trigger garbage collection yield gen.moment - xx = c.submit(inc, 1) # compute second time + xx = c.submit(inc, 1) # compute second time start = time() - while xx.key not in w.data: # data shows up + while xx.key not in w.data: # data shows up 
yield gen.sleep(0.01) assert time() < start + 1 @@ -2584,7 +2668,7 @@ def test_fatally_serialized_input(c, s): yield gen.sleep(0.01) -@pytest.mark.skip(reason='Use fast random selection now') +@pytest.mark.skip(reason="Use fast random selection now") @gen_cluster(client=True) def test_balance_tasks_by_stacks(c, s, a, b): x = c.submit(inc, 1) @@ -2614,7 +2698,7 @@ def test_run_handles_picklable_data(c, s, a, b): yield wait(futures) def func(): - return {}, set(), [], (), 1, 'hello', b'100' + return {}, set(), [], (), 1, "hello", b"100" results = yield c.run_on_scheduler(func) assert results == func() @@ -2628,11 +2712,10 @@ def func(x, y=10): return x + y result = c.run(func, 1, y=2) - assert result == {a['address']: 3, - b['address']: 3} + assert result == {a["address"]: 3, b["address"]: 3} - result = c.run(func, 1, y=2, workers=[a['address']]) - assert result == {a['address']: 3} + result = c.run(func, 1, y=2, workers=[a["address"]]) + assert result == {a["address"]: 3} @gen_cluster(client=True) @@ -2657,11 +2740,10 @@ def test_run_coroutine(c, s, a, b): def test_run_coroutine_sync(c, s, a, b): result = c.run(geninc, 2, delay=0.01) - assert result == {a['address']: 3, - b['address']: 3} + assert result == {a["address"]: 3, b["address"]: 3} - result = c.run(geninc, 2, workers=[a['address']]) - assert result == {a['address']: 3} + result = c.run(geninc, 2, workers=[a["address"]]) + assert result == {a["address"]: 3} t1 = time() result = c.run(geninc, 2, delay=10, wait=False) @@ -2676,15 +2758,15 @@ def raise_exception(exc_type, exc_msg): for exc_type in [ValueError, RuntimeError]: with pytest.raises(exc_type) as excinfo: - c.run(raise_exception, exc_type, 'informative message') - assert 'informative message' in str(excinfo.value) + c.run(raise_exception, exc_type, "informative message") + assert "informative message" in str(excinfo.value) def test_diagnostic_ui(loop): with cluster() as (s, [a, b]): - a_addr = a['address'] - b_addr = b['address'] - with Client(s['address'], loop=loop) as c: + a_addr = a["address"] + b_addr = b["address"] + with Client(s["address"], loop=loop) as c: d = c.ncores() assert d == {a_addr: 1, b_addr: 1} @@ -2692,7 +2774,7 @@ def test_diagnostic_ui(loop): assert d == {a_addr: 1} d = c.ncores(a_addr) assert d == {a_addr: 1} - d = c.ncores(a['address']) + d = c.ncores(a["address"]) assert d == {a_addr: 1} x = c.submit(inc, 1) @@ -2726,10 +2808,8 @@ def test_diagnostic_nbytes_sync(c): doubles = c.map(double, [1, 2, 3]) wait(incs + doubles) - assert c.nbytes(summary=False) == {k.key: sizeof(1) - for k in incs + doubles} - assert c.nbytes(summary=True) == {'inc': sizeof(1) * 3, - 'double': sizeof(1) * 3} + assert c.nbytes(summary=False) == {k.key: sizeof(1) for k in incs + doubles} + assert c.nbytes(summary=True) == {"inc": sizeof(1) * 3, "double": sizeof(1) * 3} @gen_cluster(client=True) @@ -2738,31 +2818,29 @@ def test_diagnostic_nbytes(c, s, a, b): doubles = c.map(double, [1, 2, 3]) yield wait(incs + doubles) - assert s.get_nbytes(summary=False) == {k.key: sizeof(1) - for k in incs + doubles} - assert s.get_nbytes(summary=True) == {'inc': sizeof(1) * 3, - 'double': sizeof(1) * 3} + assert s.get_nbytes(summary=False) == {k.key: sizeof(1) for k in incs + doubles} + assert s.get_nbytes(summary=True) == {"inc": sizeof(1) * 3, "double": sizeof(1) * 3} @gen_test() def test_worker_aliases(): s = Scheduler(validate=True) s.start(0) - a = Worker(s.ip, s.port, name='alice') - b = Worker(s.ip, s.port, name='bob') + a = Worker(s.ip, s.port, name="alice") + b = Worker(s.ip, 
s.port, name="bob") w = Worker(s.ip, s.port, name=3) yield [a, b, w] c = yield Client((s.ip, s.port), asynchronous=True) - L = c.map(inc, range(10), workers='alice') + L = c.map(inc, range(10), workers="alice") future = yield c.scatter(123, workers=3) yield wait(L) assert len(a.data) == 10 assert len(b.data) == 0 assert dict(w.data) == {future.key: 123} - for i, alias in enumerate([3, [3], 'alice']): + for i, alias in enumerate([3, [3], "alice"]): result = yield c.submit(lambda x: x + 1, i, workers=alias) assert result == i + 1 @@ -2809,16 +2887,17 @@ def test_persist_get(c, s, a, b): assert result == ((1 + 1) + (2 + 2)) + 10 -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="num_fds not supported on windows") +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="num_fds not supported on windows" +) def test_client_num_fds(loop): - psutil = pytest.importorskip('psutil') + psutil = pytest.importorskip("psutil") with cluster() as (s, [a, b]): proc = psutil.Process() - with Client(s['address'], loop=loop) as c: # first client to start loop - before = proc.num_fds() # measure + with Client(s["address"], loop=loop) as c: # first client to start loop + before = proc.num_fds() # measure for i in range(4): - with Client(s['address'], loop=loop): # start more clients + with Client(s["address"], loop=loop): # start more clients pass start = time() while proc.num_fds() > before: @@ -2837,14 +2916,14 @@ def test_startup_close_startup(s, a, b): def test_startup_close_startup_sync(loop): with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: sleep(0.1) - with Client(s['address']) as c: + with Client(s["address"]) as c: pass - with Client(s['address']) as c: + with Client(s["address"]) as c: pass sleep(0.1) - with Client(s['address']) as c: + with Client(s["address"]) as c: pass @@ -2854,14 +2933,15 @@ def f(): class BadlySerializedException(Exception): def __reduce__(self): raise TypeError() - raise BadlySerializedException('hello world') + + raise BadlySerializedException("hello world") x = c.submit(f) try: result = yield x except Exception as e: - assert 'hello world' in str(e) + assert "hello world" in str(e) else: assert False @@ -2885,11 +2965,10 @@ def test_rebalance(c, s, a, b): assert len(a.data) == 1 assert {ts.key for ts in aws.has_what} == set(a.data) - assert (aws not in s.tasks[x.key].who_has or - aws not in s.tasks[y.key].who_has) + assert aws not in s.tasks[x.key].who_has or aws not in s.tasks[y.key].who_has -@gen_cluster(ncores=[('127.0.0.1', 1)] * 4, client=True) +@gen_cluster(ncores=[("127.0.0.1", 1)] * 4, client=True) def test_rebalance_workers(e, s, a, b, c, d): w, x, y, z = yield e.scatter([1, 2, 3, 4], workers=[a.address]) assert len(a.data) == 4 @@ -2921,7 +3000,7 @@ def test_rebalance_execution(c, s, a, b): def test_rebalance_sync(c, s, a, b): - futures = c.map(inc, range(10), workers=[a['address']]) + futures = c.map(inc, range(10), workers=[a["address"]]) c.rebalance(futures) has_what = c.has_what() @@ -2944,31 +3023,32 @@ def test_receive_lost_key(c, s, a, b): yield a._close() start = time() - while x.status == 'finished': + while x.status == "finished": assert time() < start + 5 yield gen.sleep(0.01) -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([('127.0.0.1', 1), ('127.0.0.2', 2)], client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) 
+@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) def test_unrunnable_task_runs(c, s, a, b): x = c.submit(inc, 1, workers=[a.ip]) result = yield x yield a._close() start = time() - while x.status == 'finished': + while x.status == "finished": assert time() < start + 5 yield gen.sleep(0.01) assert s.tasks[x.key] in s.unrunnable - assert s.get_task_status(keys=[x.key]) == {x.key: 'no-worker'} + assert s.get_task_status(keys=[x.key]) == {x.key: "no-worker"} w = yield Worker(s.ip, s.port, loop=s.loop) start = time() - while x.status != 'finished': + while x.status != "finished": assert time() < start + 2 yield gen.sleep(0.01) @@ -2990,16 +3070,16 @@ def test_add_worker_after_tasks(c, s): yield n._close() -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([('127.0.0.1', 1), ('127.0.0.2', 2)], client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) def test_workers_register_indirect_data(c, s, a, b): [x] = yield c.scatter([1], workers=a.address) y = c.submit(inc, x, workers=b.ip) yield y assert b.data[x.key] == 1 - assert s.tasks[x.key].who_has == {s.workers[a.address], - s.workers[b.address]} + assert s.tasks[x.key].who_has == {s.workers[a.address], s.workers[b.address]} assert s.workers[b.address].has_what == {s.tasks[x.key], s.tasks[y.key]} s.validate_state() @@ -3015,7 +3095,7 @@ def test_submit_on_cancelled_future(c, s, a, b): y = c.submit(inc, x) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_replicate(c, s, *workers): [a, b] = yield c.scatter([1, 2]) yield s.replicate(keys=[a.key, b.key], n=5) @@ -3030,7 +3110,7 @@ def test_replicate(c, s, *workers): @gen_cluster(client=True) def test_replicate_tuple_keys(c, s, a, b): - x = delayed(inc)(1, dask_key_name=('x', 1)) + x = delayed(inc)(1, dask_key_name=("x", 1)) f = c.persist(x) yield c.replicate(f, n=5) s.validate_state() @@ -3040,12 +3120,13 @@ def test_replicate_tuple_keys(c, s, a, b): s.validate_state() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_replicate_workers(c, s, *workers): [a, b] = yield c.scatter([1, 2], workers=[workers[0].address]) - yield s.replicate(keys=[a.key, b.key], n=5, - workers=[w.address for w in workers[:5]]) + yield s.replicate( + keys=[a.key, b.key], n=5, workers=[w.address for w in workers[:5]] + ) assert len(s.tasks[a.key].who_has) == 5 assert len(s.tasks[b.key].who_has) == 5 @@ -3069,8 +3150,9 @@ def test_replicate_workers(c, s, *workers): assert len(s.tasks[b.key].who_has) == 10 s.validate_state() - yield s.replicate(keys=[a.key, b.key], n=1, - workers=[w.address for w in workers[:5]]) + yield s.replicate( + keys=[a.key, b.key], n=1, workers=[w.address for w in workers[:5]] + ) assert sum(a.key in w.data for w in workers[:5]) == 1 assert sum(b.key in w.data for w in workers[:5]) == 1 assert sum(a.key in w.data for w in workers[5:]) == 5 @@ -3089,7 +3171,7 @@ def __getstate__(self): return self.n -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_replicate_tree_branching(c, s, *workers): obj = CountSerialization() [future] = yield c.scatter([obj]) @@ -3099,7 +3181,7 @@ def test_replicate_tree_branching(c, s, *workers): assert max_count > 1 -@gen_cluster(client=True, 
ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_client_replicate(c, s, *workers): x = c.submit(inc, 1) y = c.submit(inc, 2) @@ -3120,25 +3202,27 @@ def test_client_replicate(c, s, *workers): assert len(s.tasks[y.key].who_has) == 10 -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1), - ('127.0.0.2', 1), - ('127.0.0.2', 1)], timeout=None) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster( + client=True, + ncores=[("127.0.0.1", 1), ("127.0.0.2", 1), ("127.0.0.2", 1)], + timeout=None, +) def test_client_replicate_host(client, s, a, b, c): aws = s.workers[a.address] bws = s.workers[b.address] cws = s.workers[c.address] - x = client.submit(inc, 1, workers='127.0.0.2') + x = client.submit(inc, 1, workers="127.0.0.2") yield wait([x]) - assert (s.tasks[x.key].who_has == {bws} or - s.tasks[x.key].who_has == {cws}) + assert s.tasks[x.key].who_has == {bws} or s.tasks[x.key].who_has == {cws} - yield client.replicate([x], workers=['127.0.0.2']) + yield client.replicate([x], workers=["127.0.0.2"]) assert s.tasks[x.key].who_has == {bws, cws} - yield client.replicate([x], workers=['127.0.0.1']) + yield client.replicate([x], workers=["127.0.0.1"]) assert s.tasks[x.key].who_has == {aws, bws, cws} @@ -3156,21 +3240,22 @@ def test_client_replicate_sync(c): assert y.result() == 3 -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="Windows timer too coarse-grained") -@gen_cluster(client=True, ncores=[('127.0.0.1', 4)] * 1) +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="Windows timer too coarse-grained" +) +@gen_cluster(client=True, ncores=[("127.0.0.1", 4)] * 1) def test_task_load_adapts_quickly(c, s, a): future = c.submit(slowinc, 1, delay=0.2) # slow yield wait(future) - assert 0.15 < s.task_duration['slowinc'] < 0.4 + assert 0.15 < s.task_duration["slowinc"] < 0.4 futures = c.map(slowinc, range(10), delay=0) # very fast yield wait(futures) - assert 0 < s.task_duration['slowinc'] < 0.1 + assert 0 < s.task_duration["slowinc"] < 0.1 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_even_load_after_fast_functions(c, s, a, b): x = c.submit(inc, 1, workers=a.address) # very fast y = c.submit(inc, 2, workers=b.address) # very fast @@ -3184,7 +3269,7 @@ def test_even_load_after_fast_functions(c, s, a, b): # assert abs(len(a.data) - len(b.data)) <= 3 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_even_load_on_startup(c, s, a, b): x, y = c.map(inc, [1, 2]) yield wait([x, y]) @@ -3192,7 +3277,7 @@ def test_even_load_on_startup(c, s, a, b): @pytest.mark.skip -@gen_cluster(client=True, ncores=[('127.0.0.1', 2)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 2) def test_contiguous_load(c, s, a, b): w, x, y, z = c.map(inc, [1, 2, 3, 4]) yield wait([w, x, y, z]) @@ -3202,7 +3287,7 @@ def test_contiguous_load(c, s, a, b): assert {y.key, z.key} in groups -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_balanced_with_submit(c, s, *workers): L = [c.submit(slowinc, i) for i in range(4)] yield wait(L) @@ -3210,7 +3295,7 @@ def test_balanced_with_submit(c, s, *workers): assert len(w.data) == 1 -@gen_cluster(client=True, ncores=[('127.0.0.1', 
1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_balanced_with_submit_and_resident_data(c, s, *workers): [x] = yield c.scatter([10], broadcast=True) L = [c.submit(slowinc, x, pure=False) for i in range(4)] @@ -3219,34 +3304,38 @@ def test_balanced_with_submit_and_resident_data(c, s, *workers): assert len(w.data) == 2 -@gen_cluster(client=True, ncores=[('127.0.0.1', 20)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 20)] * 2) def test_scheduler_saturates_cores(c, s, a, b): for delay in [0, 0.01, 0.1]: futures = c.map(slowinc, range(100), delay=delay) futures = c.map(slowinc, futures, delay=delay / 10) while not s.tasks: if s.tasks: - assert all(len(p) >= 20 - for w in s.workers.values() - for p in w.processing.values()) + assert all( + len(p) >= 20 + for w in s.workers.values() + for p in w.processing.values() + ) yield gen.sleep(0.01) -@gen_cluster(client=True, ncores=[('127.0.0.1', 20)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 20)] * 2) def test_scheduler_saturates_cores_random(c, s, a, b): for delay in [0, 0.01, 0.1]: futures = c.map(randominc, range(100), scale=0.1) while not s.tasks: if s.tasks: - assert all(len(p) >= 20 - for w in s.workers.values() - for p in w.processing.values()) + assert all( + len(p) >= 20 + for w in s.workers.values() + for p in w.processing.values() + ) yield gen.sleep(0.01) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_cancel_clears_processing(c, s, *workers): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = c.submit(slowinc, 1, delay=0.2) while not s.tasks: yield gen.sleep(0.01) @@ -3263,36 +3352,36 @@ def test_cancel_clears_processing(c, s, *workers): def test_default_get(): with cluster() as (s, [a, b]): pre_get = dask.base.get_scheduler() - pytest.raises(KeyError, dask.config.get, 'shuffle') - with Client(s['address'], set_as_default=True) as c: + pytest.raises(KeyError, dask.config.get, "shuffle") + with Client(s["address"], set_as_default=True) as c: assert dask.base.get_scheduler() == c.get - assert dask.config.get('shuffle') == 'tasks' + assert dask.config.get("shuffle") == "tasks" assert dask.base.get_scheduler() == pre_get - pytest.raises(KeyError, dask.config.get, 'shuffle') + pytest.raises(KeyError, dask.config.get, "shuffle") - c = Client(s['address'], set_as_default=False) + c = Client(s["address"], set_as_default=False) assert dask.base.get_scheduler() == pre_get - pytest.raises(KeyError, dask.config.get, 'shuffle') + pytest.raises(KeyError, dask.config.get, "shuffle") c.close() - c = Client(s['address'], set_as_default=True) - assert dask.config.get('shuffle') == 'tasks' + c = Client(s["address"], set_as_default=True) + assert dask.config.get("shuffle") == "tasks" assert dask.base.get_scheduler() == c.get c.close() assert dask.base.get_scheduler() == pre_get - pytest.raises(KeyError, dask.config.get, 'shuffle') + pytest.raises(KeyError, dask.config.get, "shuffle") - with Client(s['address']) as c: + with Client(s["address"]) as c: assert dask.base.get_scheduler() == c.get - with Client(s['address'], set_as_default=False) as c: + with Client(s["address"], set_as_default=False) as c: assert dask.base.get_scheduler() != c.get assert dask.base.get_scheduler() != c.get - with Client(s['address'], set_as_default=True) as c1: + with Client(s["address"], set_as_default=True) as c1: assert dask.base.get_scheduler() == c1.get - with Client(s['address'], set_as_default=True) as c2: + with 
Client(s["address"], set_as_default=True) as c2: assert dask.base.get_scheduler() == c2.get assert dask.base.get_scheduler() == c1.get assert dask.base.get_scheduler() == pre_get @@ -3303,8 +3392,9 @@ def test_get_processing(c, s, a, b): processing = yield c.processing() assert processing == valmap(tuple, s.processing) - futures = c.map(slowinc, range(10), delay=0.1, workers=[a.address], - allow_other_workers=True) + futures = c.map( + slowinc, range(10), delay=0.1, workers=[a.address], allow_other_workers=True + ) yield gen.sleep(0.2) @@ -3353,7 +3443,7 @@ def assert_dict_key_equal(expected, actual): assert list(ev) == list(av) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_get_foo_lost_keys(c, s, u, v, w): x = c.submit(inc, 1, workers=[u.address]) y = yield c.scatter(3, workers=[v.address]) @@ -3366,7 +3456,7 @@ def test_get_foo_lost_keys(c, s, u, v, w): d = yield c.scheduler.has_what(workers=[ua, va]) assert_dict_key_equal(d, {ua: [x.key], va: [y.key]}) d = yield c.scheduler.who_has() - assert_dict_key_equal(d, {x.key: [ua], y.key: [va]}) + assert_dict_key_equal(d, {x.key: [ua], y.key: [va]}) d = yield c.scheduler.who_has(keys=[x.key, y.key]) assert_dict_key_equal(d, {x.key: [ua], y.key: [va]}) @@ -3397,14 +3487,14 @@ def test_get_processing_sync(c, s, a, b): processing = c.processing() assert not any(v for v in processing.values()) - futures = c.map(slowinc, range(10), delay=0.1, - workers=[a['address']], - allow_other_workers=False) + futures = c.map( + slowinc, range(10), delay=0.1, workers=[a["address"]], allow_other_workers=False + ) sleep(0.2) - aa = a['address'] - bb = b['address'] + aa = a["address"] + bb = b["address"] processing = c.processing() assert set(c.processing(aa)) == {aa} @@ -3423,7 +3513,7 @@ def test_close_idempotent(c): def test_get_returns_early(c): start = time() with ignoring(RuntimeError): - result = c.get({'x': (throws, 1), 'y': (sleep, 1)}, ['x', 'y']) + result = c.get({"x": (throws, 1), "y": (sleep, 1)}, ["x", "y"]) assert time() < start + 0.5 # Futures should be released and forgotten wait_for(lambda: not c.futures, timeout=0.1) @@ -3434,7 +3524,7 @@ def test_get_returns_early(c): x.result() with ignoring(RuntimeError): - result = c.get({'x': (throws, 1), x.key: (inc, 1)}, ['x', x.key]) + result = c.get({"x": (throws, 1), x.key: (inc, 1)}, ["x", x.key]) assert x.key in c.futures @@ -3450,6 +3540,7 @@ def test_Client_clears_references_after_restart(c, s, a, b): key = x.key del x import gc + gc.collect() yield gen.moment @@ -3458,7 +3549,7 @@ def test_Client_clears_references_after_restart(c, s, a, b): def test_get_stops_work_after_error(c): with pytest.raises(RuntimeError): - c.get({'x': (throws, 1), 'y': (sleep, 1.5)}, ['x', 'y']) + c.get({"x": (throws, 1), "y": (sleep, 1.5)}, ["x", "y"]) start = time() while any(c.processing().values()): @@ -3479,7 +3570,7 @@ def test_as_completed_results(c): assert set(pluck(0, seq2)) == set(seq) -@pytest.mark.parametrize('with_results', [True, False]) +@pytest.mark.parametrize("with_results", [True, False]) def test_as_completed_batches(c, with_results): n = 50 futures = c.map(slowinc, range(n), delay=0.01) @@ -3509,11 +3600,11 @@ def test_status(): s.start(0) c = yield Client((s.ip, s.port), asynchronous=True) - assert c.status == 'running' + assert c.status == "running" x = c.submit(inc, 1) yield c.close() - assert c.status == 'closed' + assert c.status == "closed" yield s.close() @@ -3551,13 +3642,19 @@ def 
test_scatter_raises_if_no_workers(c, s): @slow def test_reconnect(loop): - w = Worker('127.0.0.1', 9393, loop=loop) + w = Worker("127.0.0.1", 9393, loop=loop) w.start() - scheduler_cli = ['dask-scheduler', '--host', '127.0.0.1', - '--port', '9393', '--no-bokeh'] + scheduler_cli = [ + "dask-scheduler", + "--host", + "127.0.0.1", + "--port", + "9393", + "--no-bokeh", + ] with popen(scheduler_cli) as s: - c = Client('127.0.0.1:9393', loop=loop) + c = Client("127.0.0.1:9393", loop=loop) start = time() while len(c.ncores()) != 1: sleep(0.1) @@ -3567,20 +3664,20 @@ def test_reconnect(loop): assert x.result() == 2 start = time() - while c.status != 'connecting': + while c.status != "connecting": assert time() < start + 5 sleep(0.01) with pytest.raises(Exception): c.ncores() - assert x.status == 'cancelled' + assert x.status == "cancelled" with pytest.raises(CancelledError): x.result() with popen(scheduler_cli) as s: start = time() - while c.status != 'running': + while c.status != "running": sleep(0.1) assert time() < start + 5 start = time() @@ -3608,17 +3705,16 @@ def test_reconnect(loop): @slow -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="num_fds not supported on windows") -@pytest.mark.skipif(sys.version_info[0] == 2, - reason="Semaphore.acquire doesn't support timeout option") -@pytest.mark.xfail(reason='TODO: intermittent failures') -@pytest.mark.parametrize("worker,count,repeat", [ - (Worker, 100, 5), - (Nanny, 10, 20) -]) +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="num_fds not supported on windows" +) +@pytest.mark.skipif( + sys.version_info[0] == 2, reason="Semaphore.acquire doesn't support timeout option" +) +@pytest.mark.xfail(reason="TODO: intermittent failures") +@pytest.mark.parametrize("worker,count,repeat", [(Worker, 100, 5), (Nanny, 10, 20)]) def test_open_close_many_workers(loop, worker, count, repeat): - psutil = pytest.importorskip('psutil') + psutil = pytest.importorskip("psutil") proc = psutil.Process() with cluster(nworkers=0, active_rpc_timeout=20) as (s, _): @@ -3631,7 +3727,7 @@ def test_open_close_many_workers(loop, worker, count, repeat): def start_worker(sleep, duration, repeat=1): for i in range(repeat): yield gen.sleep(sleep) - w = worker(s['address'], loop=loop) + w = worker(s["address"], loop=loop) running[w] = None yield w addr = w.worker_address @@ -3643,10 +3739,11 @@ def start_worker(sleep, duration, repeat=1): done.release() for i in range(count): - loop.add_callback(start_worker, random.random() / 5, random.random() / 5, - repeat=repeat) + loop.add_callback( + start_worker, random.random() / 5, random.random() / 5, repeat=repeat + ) - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: sleep(1) for i in range(count): @@ -3689,7 +3786,7 @@ def test_idempotence(s, a, b): # Error a = c.submit(div, 1, 0) yield wait(a) - assert a.status == 'error' + assert a.status == "error" log = list(s.transition_log) b = f.submit(div, 1, 0) @@ -3714,12 +3811,12 @@ def test_idempotence(s, a, b): def test_scheduler_info(c): info = c.scheduler_info() assert isinstance(info, dict) - assert len(info['workers']) == 2 + assert len(info["workers"]) == 2 def test_write_scheduler_file(c): info = c.scheduler_info() - with tmpfile('json') as scheduler_file: + with tmpfile("json") as scheduler_file: c.write_scheduler_file(scheduler_file) with Client(scheduler_file=scheduler_file) as c2: info2 = c2.scheduler_info() @@ -3732,30 +3829,30 @@ def test_write_scheduler_file(c): def test_get_versions(c): - requests = 
pytest.importorskip('requests') + requests = pytest.importorskip("requests") v = c.get_versions() - assert v['scheduler'] is not None - assert v['client'] is not None - assert len(v['workers']) == 2 - for k, v in v['workers'].items(): + assert v["scheduler"] is not None + assert v["client"] is not None + assert len(v["workers"]) == 2 + for k, v in v["workers"].items(): assert v is not None c.get_versions(check=True) # smoke test for versions # that this does not raise - v = c.get_versions(packages=['requests']) - assert dict(v['client']['packages']['optional'])['requests'] == requests.__version__ + v = c.get_versions(packages=["requests"]) + assert dict(v["client"]["packages"]["optional"])["requests"] == requests.__version__ def test_threaded_get_within_distributed(c): import dask.multiprocessing - for get in [dask.local.get_sync, - dask.multiprocessing.get, - dask.threaded.get]: + + for get in [dask.local.get_sync, dask.multiprocessing.get, dask.threaded.get]: + def f(): - return get({'x': (lambda: 1,)}, 'x') + return get({"x": (lambda: 1,)}, "x") future = c.submit(f) assert future.result() == 1 @@ -3768,11 +3865,11 @@ def test_lose_scattered_data(c, s, a, b): yield a._close() yield gen.sleep(0.1) - assert x.status == 'cancelled' + assert x.status == "cancelled" assert x.key not in s.tasks -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_partially_lose_scattered_data(e, s, a, b, c): [x] = yield e.scatter([1], workers=a.address) yield e.replicate(x, n=2) @@ -3780,8 +3877,8 @@ def test_partially_lose_scattered_data(e, s, a, b, c): yield a._close() yield gen.sleep(0.1) - assert x.status == 'finished' - assert s.get_task_status(keys=[x.key]) == {x.key: 'memory'} + assert x.status == "finished" + assert s.get_task_status(keys=[x.key]) == {x.key: "memory"} @gen_cluster(client=True) @@ -3797,9 +3894,9 @@ def test_scatter_compute_lose(c, s, a, b): with pytest.raises(CancelledError): yield wait(z) - assert x.status == 'cancelled' - assert y.status == 'finished' - assert z.status == 'cancelled' + assert x.status == "cancelled" + assert y.status == "finished" + assert z.status == "cancelled" @gen_cluster(client=True) @@ -3821,13 +3918,13 @@ def test_scatter_compute_store_lose(c, s, a, b): yield a._close() start = time() - while x.status == 'finished': + while x.status == "finished": yield gen.sleep(0.01) assert time() < start + 2 # assert xx.status == 'finished' - assert y.status == 'finished' - assert z.status == 'finished' + assert y.status == "finished" + assert z.status == "finished" zz = c.submit(inc, z) yield wait(zz) @@ -3836,7 +3933,7 @@ def test_scatter_compute_store_lose(c, s, a, b): del z start = time() - while s.get_task_status(keys=[zkey]) != {zkey: 'released'}: + while s.get_task_status(keys=[zkey]) != {zkey: "released"}: yield gen.sleep(0.01) assert time() < start + 2 @@ -3844,9 +3941,7 @@ def test_scatter_compute_store_lose(c, s, a, b): del xx start = time() - while (x.key in s.tasks and - zkey not in s.tasks and - xxkey not in s.tasks): + while x.key in s.tasks and zkey not in s.tasks and xxkey not in s.tasks: yield gen.sleep(0.01) assert time() < start + 2 @@ -3868,12 +3963,12 @@ def test_scatter_compute_store_lose_processing(c, s, a, b): yield a._close() start = time() - while x.status == 'finished': + while x.status == "finished": yield gen.sleep(0.01) assert time() < start + 2 - assert y.status == 'cancelled' - assert z.status == 'cancelled' + assert y.status == "cancelled" + assert z.status == 
"cancelled" @gen_cluster(client=False) @@ -3913,19 +4008,23 @@ def test_temp_client(s, a, b): @nodebug # test timing is fragile -@gen_cluster(ncores=[('127.0.0.1', 1)] * 3, client=True) +@gen_cluster(ncores=[("127.0.0.1", 1)] * 3, client=True) def test_persist_workers(e, s, a, b, c): L1 = [delayed(inc)(i) for i in range(4)] total = delayed(sum)(L1) L2 = [delayed(add)(i, total) for i in L1] total2 = delayed(sum)(L2) - out = e.persist(L1 + L2 + [total, total2], - workers={tuple(L1): a.address, - total: b.address, - tuple(L2): [c.address], - total2: b.address}, - allow_other_workers=L2 + [total2]) + out = e.persist( + L1 + L2 + [total, total2], + workers={ + tuple(L1): a.address, + total: b.address, + tuple(L2): [c.address], + total2: b.address, + }, + allow_other_workers=L2 + [total2], + ) yield wait(out) assert all(v.key in a.data for v in L1) @@ -3934,17 +4033,17 @@ def test_persist_workers(e, s, a, b, c): assert s.loose_restrictions == {total2.key} | {v.key for v in L2} -@gen_cluster(ncores=[('127.0.0.1', 1)] * 3, client=True) +@gen_cluster(ncores=[("127.0.0.1", 1)] * 3, client=True) def test_compute_workers(e, s, a, b, c): L1 = [delayed(inc)(i) for i in range(4)] total = delayed(sum)(L1) L2 = [delayed(add)(i, total) for i in L1] - out = e.compute(L1 + L2 + [total], - workers={tuple(L1): a.address, - total: b.address, - tuple(L2): [c.address]}, - allow_other_workers=L1 + [total]) + out = e.compute( + L1 + L2 + [total], + workers={tuple(L1): a.address, total: b.address, tuple(L2): [c.address]}, + allow_other_workers=L1 + [total], + ) yield wait(out) for v in L1: @@ -3958,16 +4057,16 @@ def test_compute_workers(e, s, a, b, c): @gen_cluster(client=True) def test_compute_nested_containers(c, s, a, b): - da = pytest.importorskip('dask.array') - np = pytest.importorskip('numpy') + da = pytest.importorskip("dask.array") + np = pytest.importorskip("numpy") x = da.ones(10, chunks=(5,)) + 1 - future = c.compute({'x': [x], 'y': 123}) + future = c.compute({"x": [x], "y": 123}) result = yield future assert isinstance(result, dict) - assert (result['x'][0] == np.ones(10) + 1).all() - assert result['y'] == 123 + assert (result["x"][0] == np.ones(10) + 1).all() + assert result["y"] == 123 def test_get_restrictions(): @@ -3975,20 +4074,20 @@ def test_get_restrictions(): total = delayed(sum)(L1) L2 = [delayed(add)(i, total) for i in L1] - r1, loose = Client.get_restrictions(L2, '127.0.0.1', False) - assert r1 == {d.key: ['127.0.0.1'] for d in L2} + r1, loose = Client.get_restrictions(L2, "127.0.0.1", False) + assert r1 == {d.key: ["127.0.0.1"] for d in L2} assert not loose - r1, loose = Client.get_restrictions(L2, ['127.0.0.1'], True) - assert r1 == {d.key: ['127.0.0.1'] for d in L2} + r1, loose = Client.get_restrictions(L2, ["127.0.0.1"], True) + assert r1 == {d.key: ["127.0.0.1"] for d in L2} assert set(loose) == {d.key for d in L2} - r1, loose = Client.get_restrictions(L2, {total: '127.0.0.1'}, True) - assert r1 == {total.key: ['127.0.0.1']} + r1, loose = Client.get_restrictions(L2, {total: "127.0.0.1"}, True) + assert r1 == {total.key: ["127.0.0.1"]} assert loose == [total.key] - r1, loose = Client.get_restrictions(L2, {(total,): '127.0.0.1'}, True) - assert r1 == {total.key: ['127.0.0.1']} + r1, loose = Client.get_restrictions(L2, {(total,): "127.0.0.1"}, True) + assert r1 == {total.key: ["127.0.0.1"]} assert loose == [total.key] @@ -3997,8 +4096,8 @@ def test_scatter_type(c, s, a, b): [future] = yield c.scatter([1]) assert future.type == int - d = yield c.scatter({'x': 1.0}) - assert d['x'].type == 
float + d = yield c.scatter({"x": 1.0}) + assert d["x"].type == float @gen_cluster(client=True) @@ -4013,7 +4112,7 @@ def test_retire_workers_2(c, s, a, b): assert a.address not in s.workers -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_retire_many_workers(c, s, *workers): futures = yield c.scatter(list(range(100))) @@ -4024,16 +4123,15 @@ def test_retire_many_workers(c, s, *workers): assert len(s.has_what) == len(s.ncores) == 3 assert all(future.done() for future in futures) - assert all(s.tasks[future.key].state == 'memory' for future in futures) + assert all(s.tasks[future.key].state == "memory" for future in futures) for w, keys in s.has_what.items(): assert 15 < len(keys) < 50 -@gen_cluster(client=True, - ncores=[('127.0.0.1', 3)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 3)] * 2) def test_weight_occupancy_against_data_movement(c, s, a, b): - s.extensions['stealing']._pc.callback_time = 1000000 - s.task_duration['f'] = 0.01 + s.extensions["stealing"]._pc.callback_time = 1000000 + s.task_duration["f"] = 0.01 def f(x, y=0, z=0): sleep(0.01) @@ -4050,11 +4148,10 @@ def f(x, y=0, z=0): assert sum(f.key in b.data for f in futures) >= 1 -@gen_cluster(client=True, - ncores=[('127.0.0.1', 1), ('127.0.0.1', 10)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1), ("127.0.0.1", 10)]) def test_distribute_tasks_by_ncores(c, s, a, b): - s.task_duration['f'] = 0.01 - s.extensions['stealing']._pc.callback_time = 1000000 + s.task_duration["f"] = 0.01 + s.extensions["stealing"]._pc.callback_time = 1000000 def f(x, y=0): sleep(0.01) @@ -4079,10 +4176,10 @@ def f(future): def g(future): S.add((future.key, future.status)) - u = c.submit(inc, 1, key='u') - v = c.submit(throws, "hello", key='v') - w = c.submit(slowinc, 2, delay=0.3, key='w') - x = c.submit(inc, 3, key='x') + u = c.submit(inc, 1, key="u") + v = c.submit(throws, "hello", key="v") + w = c.submit(slowinc, 2, delay=0.3, key="w") + x = c.submit(inc, 3, key="x") u.add_done_callback(f) v.add_done_callback(f) w.add_done_callback(f) @@ -4115,7 +4212,7 @@ def test_normalize_collection(c, s, a, b): @gen_cluster(client=True) def test_normalize_collection_dask_array(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.ones(10, chunks=(5,)) y = x + 1 @@ -4139,9 +4236,9 @@ def test_normalize_collection_dask_array(c, s, a, b): @slow def test_normalize_collection_with_released_futures(c): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") - x = da.arange(2**20, chunks=2**10) + x = da.arange(2 ** 20, chunks=2 ** 10) y = x.persist() wait(y) sol = y.sum().compute() @@ -4157,7 +4254,7 @@ def test_normalize_collection_with_released_futures(c): @gen_cluster(client=True) def test_auto_normalize_collection(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.ones(10, chunks=5) assert len(x.dask) == 2 @@ -4182,7 +4279,7 @@ def test_auto_normalize_collection(c, s, a, b): def test_auto_normalize_collection_sync(c): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.ones(10, chunks=5) y = x.map_blocks(slowinc, delay=1, dtype=x.dtype) @@ -4199,14 +4296,15 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): for key, start, finish, recommendations, _ in scheduler.transition_log: - if start == 'memory' and finish == 'released': + if start == "memory" and finish == "released": for k, v 
in recommendations.items(): - assert not (k == key and v == 'waiting') + assert not (k == key and v == "waiting") @gen_cluster(client=True, timeout=None) def test_interleave_computations(c, s, a, b): import distributed + distributed.g = s xs = [delayed(slowinc)(i, delay=0.02) for i in range(30)] ys = [delayed(slowdec)(x, delay=0.02) for x in xs] @@ -4216,7 +4314,7 @@ def test_interleave_computations(c, s, a, b): future = c.compute(total) - done = ('memory', 'released') + done = ("memory", "released") yield gen.sleep(0.1) @@ -4226,12 +4324,9 @@ def test_interleave_computations(c, s, a, b): while not s.tasks or any(w.processing for w in s.workers.values()): yield gen.sleep(0.05) - x_done = sum(state in done - for state in s.get_task_status(keys=x_keys).values()) - y_done = sum(state in done - for state in s.get_task_status(keys=y_keys).values()) - z_done = sum(state in done - for state in s.get_task_status(keys=z_keys).values()) + x_done = sum(state in done for state in s.get_task_status(keys=x_keys).values()) + y_done = sum(state in done for state in s.get_task_status(keys=y_keys).values()) + z_done = sum(state in done for state in s.get_task_status(keys=z_keys).values()) assert x_done >= y_done >= z_done assert x_done < y_done + 10 @@ -4247,7 +4342,7 @@ def test_interleave_computations_map(c, s, a, b): ys = c.map(slowdec, xs, delay=0.02) zs = c.map(slowadd, xs, ys, delay=0.02) - done = ('memory', 'released') + done = ("memory", "released") x_keys = [x.key for x in xs] y_keys = [y.key for y in ys] @@ -4255,12 +4350,9 @@ def test_interleave_computations_map(c, s, a, b): while not s.tasks or any(w.processing for w in s.workers.values()): yield gen.sleep(0.05) - x_done = sum(state in done - for state in s.get_task_status(keys=x_keys).values()) - y_done = sum(state in done - for state in s.get_task_status(keys=y_keys).values()) - z_done = sum(state in done - for state in s.get_task_status(keys=z_keys).values()) + x_done = sum(state in done for state in s.get_task_status(keys=x_keys).values()) + y_done = sum(state in done for state in s.get_task_status(keys=y_keys).values()) + z_done = sum(state in done for state in s.get_task_status(keys=z_keys).values()) assert x_done >= y_done >= z_done assert x_done < y_done + 10 @@ -4269,20 +4361,20 @@ def test_interleave_computations_map(c, s, a, b): @gen_cluster(client=True) def test_scatter_dict_workers(c, s, a, b): - yield c.scatter({'a': 10}, workers=[a.address, b.address]) - assert 'a' in a.data or 'a' in b.data + yield c.scatter({"a": 10}, workers=[a.address, b.address]) + assert "a" in a.data or "a" in b.data @slow @gen_test() def test_client_timeout(): loop = IOLoop.current() - c = Client('127.0.0.1:57484', asynchronous=True) + c = Client("127.0.0.1:57484", asynchronous=True) s = Scheduler(loop=loop) yield gen.sleep(4) try: - s.start(('127.0.0.1', 57484)) + s.start(("127.0.0.1", 57484)) except EnvironmentError: # port in use return @@ -4337,29 +4429,29 @@ def test_dont_clear_waiting_data(c, s, a, b): def test_get_future_error_simple(c, s, a, b): f = c.submit(div, 1, 0) yield wait(f) - assert f.status == 'error' + assert f.status == "error" function, args, kwargs, deps = yield c._get_futures_error(f) # args contains only solid values, not keys - assert function.__name__ == 'div' + assert function.__name__ == "div" with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @gen_cluster(client=True) def test_get_futures_error(c, s, a, b): - x0 = delayed(dec)(2, dask_key_name='x0') - y0 = delayed(dec)(1, dask_key_name='y0') - x = delayed(div)(1, 
x0, dask_key_name='x') - y = delayed(div)(1, y0, dask_key_name='y') - tot = delayed(sum)(x, y, dask_key_name='tot') + x0 = delayed(dec)(2, dask_key_name="x0") + y0 = delayed(dec)(1, dask_key_name="y0") + x = delayed(div)(1, x0, dask_key_name="x") + y = delayed(div)(1, y0, dask_key_name="y") + tot = delayed(sum)(x, y, dask_key_name="tot") f = c.compute(tot) yield wait(f) - assert f.status == 'error' + assert f.status == "error" function, args, kwargs, deps = yield c._get_futures_error(f) - assert function.__name__ == 'div' + assert function.__name__ == "div" assert args == (1, y0.key) @@ -4373,11 +4465,11 @@ def test_recreate_error_delayed(c, s, a, b): f = c.compute(tot) - assert f.status == 'pending' + assert f.status == "pending" function, args, kwargs = yield c._recreate_error_locally(f) - assert f.status == 'error' - assert function.__name__ == 'div' + assert f.status == "error" + assert function.__name__ == "div" assert args == (1, 0) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4392,11 +4484,11 @@ def test_recreate_error_futures(c, s, a, b): tot = c.submit(sum, x, y) f = c.compute(tot) - assert f.status == 'pending' + assert f.status == "pending" function, args, kwargs = yield c._recreate_error_locally(f) - assert f.status == 'error' - assert function.__name__ == 'div' + assert f.status == "error" + assert function.__name__ == "div" assert args == (1, 0) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4413,15 +4505,17 @@ def test_recreate_error_collection(c, s, a, b): with pytest.raises(ZeroDivisionError): function(*args, **kwargs) - dd = pytest.importorskip('dask.dataframe') + dd = pytest.importorskip("dask.dataframe") import pandas as pd - df = dd.from_pandas(pd.DataFrame({'a': [0, 1, 2, 3, 4]}), chunksize=2) + + df = dd.from_pandas(pd.DataFrame({"a": [0, 1, 2, 3, 4]}), chunksize=2) def make_err(x): # because pandas would happily work with NaN if x == 0: raise ValueError return x + df2 = df.a.map(make_err) f = c.compute(df2) function, args, kwargs = yield c._recreate_error_locally(f) @@ -4437,12 +4531,12 @@ def make_err(x): @gen_cluster(client=True) def test_recreate_error_array(c, s, a, b): - da = pytest.importorskip('dask.array') - pytest.importorskip('scipy') + da = pytest.importorskip("dask.array") + pytest.importorskip("scipy") z = (da.linalg.inv(da.zeros((10, 10), chunks=10)) + 1).sum() zz = z.persist() func, args, kwargs = yield c._recreate_error_locally(zz) - assert '0.,0.,0.' in str(args).replace(' ', '') # args contain actual arrays + assert "0.,0.,0." 
in str(args).replace(" ", "") # args contain actual arrays def test_recreate_error_sync(c): @@ -4455,7 +4549,7 @@ def test_recreate_error_sync(c): with pytest.raises(ZeroDivisionError) as e: c.recreate_error_locally(f) - assert f.status == 'error' + assert f.status == "error" def test_recreate_error_not_error(c): @@ -4472,7 +4566,7 @@ def test_retire_workers(c, s, a, b): assert set(s.workers) == {b.address} start = time() - while a.status != 'closed': + while a.status != "closed": yield gen.sleep(0.01) assert time() < start + 5 @@ -4504,7 +4598,7 @@ def __getstate__(self): return 1 def __setstate__(self, state): - raise MyException('hello') + raise MyException("hello") future = c.submit(identity, Foo()) with pytest.raises(MyException): @@ -4524,7 +4618,7 @@ def __getstate__(self): return 1 def __setstate__(self, state): - raise MyException('hello') + raise MyException("hello") def __call__(self, *args): return 1 @@ -4552,7 +4646,7 @@ def f(x): fire_and_forget(c.submit(f, future)) start = time() - while not hasattr(distributed, 'foo'): + while not hasattr(distributed, "foo"): yield gen.sleep(0.01) assert time() < start + 2 assert distributed.foo == 123 @@ -4581,21 +4675,21 @@ def test_fire_and_forget_err(c, s, a, b): def test_quiet_client_close(loop): - with captured_logger(logging.getLogger('distributed')) as logger: + with captured_logger(logging.getLogger("distributed")) as logger: with Client(loop=loop, processes=False, threads_per_worker=4) as c: futures = c.map(slowinc, range(1000), delay=0.01) sleep(0.200) # stop part-way - sleep(.1) # let things settle + sleep(0.1) # let things settle out = logger.getvalue() - lines = out.strip().split('\n') + lines = out.strip().split("\n") assert len(lines) <= 2 for line in lines: assert ( - not line or - 'Reconnecting' in line or - 'garbage' in line or - set(line) == {'-'} + not line + or "Reconnecting" in line + or "garbage" in line + or set(line) == {"-"} ), line @@ -4605,14 +4699,14 @@ def test_quiet_client_close_when_cluster_is_closed_before_client(loop): # fix in #2477 and with 5 attempts, this test passes by chance in about 10% # of the cases. 
for _ in range(n_attempts): - with captured_logger(logging.getLogger('tornado.application')) as logger: + with captured_logger(logging.getLogger("tornado.application")) as logger: cluster = LocalCluster(loop=loop) client = Client(cluster, loop=loop) cluster.close() client.close() out = logger.getvalue() - assert 'CancelledError' not in out + assert "CancelledError" not in out @gen_cluster() @@ -4641,6 +4735,7 @@ def f(_): return total.result() from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(20) as e: results = list(e.map(f, range(20))) assert results and all(results) @@ -4649,7 +4744,7 @@ def f(_): @slow def test_threadsafe_get(c): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.arange(100, chunks=(10,)) def f(_): @@ -4660,6 +4755,7 @@ def f(_): return total from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(30) results = list(e.map(f, range(30))) assert results and all(results) @@ -4667,7 +4763,7 @@ def f(_): @slow def test_threadsafe_compute(c): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.arange(100, chunks=(10,)) def f(_): @@ -4679,6 +4775,7 @@ def f(_): return total from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(30) results = list(e.map(f, range(30))) assert results and all(results) @@ -4686,13 +4783,13 @@ def f(_): @gen_cluster(client=True) def test_identity(c, s, a, b): - assert c.id.lower().startswith('client') - assert a.id.lower().startswith('worker') - assert b.id.lower().startswith('worker') - assert s.id.lower().startswith('scheduler') + assert c.id.lower().startswith("client") + assert a.id.lower().startswith("worker") + assert b.id.lower().startswith("worker") + assert s.id.lower().startswith("scheduler") -@gen_cluster(client=True, ncores=[('127.0.0.1', 4)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 4)] * 2) def test_get_client(c, s, a, b): assert get_client() is c assert c.asynchronous @@ -4701,11 +4798,13 @@ def f(x): client = get_client() future = client.submit(inc, x) import distributed + assert not client.asynchronous assert client is distributed.tmp_client return future.result() import distributed + distributed.tmp_client = c try: futures = c.map(f, range(5)) @@ -4719,16 +4818,17 @@ def test_get_client_no_cluster(): # Clean up any global workers added by other tests. This test requires that # there are no global workers. 
from distributed.worker import _global_workers + del _global_workers[:] - msg = 'No global client found and no address provided' - with pytest.raises(ValueError, match=r'^{}$'.format(msg)): + msg = "No global client found and no address provided" + with pytest.raises(ValueError, match=r"^{}$".format(msg)): get_client() @gen_cluster(client=True) def test_serialize_collections(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.arange(10, chunks=(5,)).persist() def f(x): @@ -4740,7 +4840,7 @@ def f(x): assert result == sum(range(10)) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 1, timeout=100) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 1, timeout=100) def test_secede_simple(c, s, a): def f(): client = get_client() @@ -4752,7 +4852,7 @@ def f(): @slow -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2, timeout=60) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2, timeout=60) def test_secede_balances(c, s, a, b): count = threading.active_count() @@ -4766,7 +4866,7 @@ def f(x): futures = c.map(f, range(100)) start = time() - while not all(f.status == 'finished' for f in futures): + while not all(f.status == "finished" for f in futures): yield gen.sleep(0.01) assert threading.active_count() < count + 50 @@ -4782,29 +4882,31 @@ def f(x): def test_sub_submit_priority(c, s, a, b): def f(): client = get_client() - client.submit(slowinc, 1, delay=0.2, key='slowinc') + client.submit(slowinc, 1, delay=0.2, key="slowinc") - future = c.submit(f, key='f') + future = c.submit(f, key="f") yield gen.sleep(0.1) if len(s.tasks) == 2: - assert s.priorities['f'] > s.priorities['slowinc'] # lower values schedule first + assert ( + s.priorities["f"] > s.priorities["slowinc"] + ) # lower values schedule first def test_get_client_sync(c, s, a, b): results = c.run(lambda: get_worker().scheduler.address) - assert results == {w['address']: s['address'] for w in [a, b]} + assert results == {w["address"]: s["address"] for w in [a, b]} results = c.run(lambda: get_client().scheduler.address) - assert results == {w['address']: s['address'] for w in [a, b]} + assert results == {w["address"]: s["address"] for w in [a, b]} @gen_cluster(client=True) def test_serialize_collections_of_futures(c, s, a, b): - pd = pytest.importorskip('pandas') - dd = pytest.importorskip('dask.dataframe') + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") from dask.dataframe.utils import assert_eq - df = pd.DataFrame({'x': [1, 2, 3]}) + df = pd.DataFrame({"x": [1, 2, 3]}) ddf = dd.from_pandas(df, npartitions=2).persist() future = yield c.scatter(ddf) @@ -4815,11 +4917,11 @@ def test_serialize_collections_of_futures(c, s, a, b): def test_serialize_collections_of_futures_sync(c): - pd = pytest.importorskip('pandas') - dd = pytest.importorskip('dask.dataframe') + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") from dask.dataframe.utils import assert_eq - df = pd.DataFrame({'x': [1, 2, 3]}) + df = pd.DataFrame({"x": [1, 2, 3]}) ddf = dd.from_pandas(df, npartitions=2).persist() future = c.scatter(ddf) @@ -4827,11 +4929,11 @@ def test_serialize_collections_of_futures_sync(c): assert_eq(result.compute(), df) assert future.type == dd.DataFrame - assert c.submit(lambda x, y: assert_eq(x.compute(), y), future, df).result() + assert c.submit(lambda x, y: assert_eq(x.compute(), y), future, df).result() def _dynamic_workload(x, delay=0.01): - if delay == 'random': + if delay == "random": sleep(random.random() / 2) 
else: sleep(delay) @@ -4839,8 +4941,9 @@ def _dynamic_workload(x, delay=0.01): return 4 secede() client = get_client() - futures = client.map(_dynamic_workload, [x + i + 1 for i in range(2)], - pure=False, delay=delay) + futures = client.map( + _dynamic_workload, [x + i + 1 for i in range(2)], pure=False, delay=delay + ) total = client.submit(sum, futures) return total.result() @@ -4856,12 +4959,12 @@ def test_dynamic_workloads_sync(c): @slow def test_dynamic_workloads_sync_random(c): - _test_dynamic_workloads_sync(c, delay='random') + _test_dynamic_workloads_sync(c, delay="random") @gen_cluster(client=True) def test_bytes_keys(c, s, a, b): - key = b'inc-123' + key = b"inc-123" future = c.submit(inc, 1, key=key) result = yield future assert type(future.key) is bytes @@ -4874,7 +4977,7 @@ def test_bytes_keys(c, s, a, b): def test_unicode_ascii_keys(c, s, a, b): # cross-version unicode type (py2: unicode, py3: str) uni_type = type(u"") - key = u'inc-123' + key = u"inc-123" future = c.submit(inc, 1, key=key) result = yield future assert type(future.key) is uni_type @@ -4887,7 +4990,7 @@ def test_unicode_ascii_keys(c, s, a, b): def test_unicode_keys(c, s, a, b): # cross-version unicode type (py2: unicode, py3: str) uni_type = type(u"") - key = u'inc-123\u03bc' + key = u"inc-123\u03bc" future = c.submit(inc, 1, key=key) result = yield future assert type(future.key) is uni_type @@ -4899,8 +5002,8 @@ def test_unicode_keys(c, s, a, b): result2 = yield future2 assert result2 == 3 - future3 = yield c.scatter({u'data-123': 123}) - result3 = yield future3[u'data-123'] + future3 = yield c.scatter({u"data-123": 123}) + result3 = yield future3[u"data-123"] assert result3 == 123 @@ -4918,9 +5021,10 @@ def f(): def test_quiet_quit_when_cluster_leaves(loop_in_thread): loop = loop_in_thread - with LocalCluster(loop=loop, scheduler_port=0, dashboard_address=None, - silence_logs=False) as cluster: - with captured_logger('distributed.comm') as sio: + with LocalCluster( + loop=loop, scheduler_port=0, dashboard_address=None, silence_logs=False + ) as cluster: + with captured_logger("distributed.comm") as sio: with Client(cluster, loop=loop) as client: futures = client.map(lambda x: x + 1, range(10)) sleep(0.05) @@ -4933,13 +5037,13 @@ def test_quiet_quit_when_cluster_leaves(loop_in_thread): def test_warn_executor(loop, s, a, b): with warnings.catch_warnings(record=True) as record: - with Executor(s['address'], loop=loop) as c: + with Executor(s["address"], loop=loop) as c: pass - assert any('Client' in str(r.message) for r in record) + assert any("Client" in str(r.message) for r in record) -@gen_cluster([('127.0.0.1', 4)] * 2, client=True) +@gen_cluster([("127.0.0.1", 4)] * 2, client=True) def test_call_stack_future(c, s, a, b): x = c.submit(slowdec, 1, delay=0.5) future = c.submit(slowinc, 1, delay=0.5) @@ -4951,11 +5055,11 @@ def test_call_stack_future(c, s, a, b): w = a if future.key in a.executing else b assert list(result) == [w.address] assert list(result[w.address]) == [future.key] - assert 'slowinc' in str(result) - assert 'slowdec' not in str(result) + assert "slowinc" in str(result) + assert "slowdec" not in str(result) -@gen_cluster([('127.0.0.1', 4)] * 2, client=True) +@gen_cluster([("127.0.0.1", 4)] * 2, client=True) def test_call_stack_all(c, s, a, b): future = c.submit(slowinc, 1, delay=0.5) yield gen.sleep(0.1) @@ -4963,12 +5067,12 @@ def test_call_stack_all(c, s, a, b): w = a if a.executing else b assert list(result) == [w.address] assert list(result[w.address]) == [future.key] - assert 'slowinc' 
in str(result) + assert "slowinc" in str(result) -@gen_cluster([('127.0.0.1', 4)] * 2, client=True) +@gen_cluster([("127.0.0.1", 4)] * 2, client=True) def test_call_stack_collections(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5).persist() while not a.executing and not b.executing: yield gen.sleep(0.001) @@ -4976,9 +5080,9 @@ def test_call_stack_collections(c, s, a, b): assert result -@gen_cluster([('127.0.0.1', 4)] * 2, client=True) +@gen_cluster([("127.0.0.1", 4)] * 2, client=True) def test_call_stack_collections_all(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5).persist() while not a.executing and not b.executing: yield gen.sleep(0.001) @@ -4986,39 +5090,42 @@ def test_call_stack_collections_all(c, s, a, b): assert result -@gen_cluster(client=True, worker_kwargs={'profile_cycle_interval': 100}) +@gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) def test_profile(c, s, a, b): futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) yield wait(futures) x = yield c.profile(start=time() + 10, stop=time() + 20) - assert not x['count'] + assert not x["count"] x = yield c.profile(start=0, stop=time()) - assert x['count'] == sum(p['count'] for _, p in a.profile_history) + a.profile_recent['count'] + assert ( + x["count"] + == sum(p["count"] for _, p in a.profile_history) + a.profile_recent["count"] + ) y = yield c.profile(start=time() - 0.300, stop=time()) - assert 0 < y['count'] < x['count'] + assert 0 < y["count"] < x["count"] - assert not any(p['count'] for _, p in b.profile_history) + assert not any(p["count"] for _, p in b.profile_history) result = yield c.profile(workers=b.address) - assert not result['count'] + assert not result["count"] -@gen_cluster(client=True, worker_kwargs={'profile_cycle_interval': 100}) +@gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) def test_profile_keys(c, s, a, b): x = c.map(slowinc, range(10), delay=0.05, workers=a.address) y = c.map(slowdec, range(10), delay=0.05, workers=a.address) yield wait(x + y) - xp = yield c.profile('slowinc') - yp = yield c.profile('slowdec') + xp = yield c.profile("slowinc") + yp = yield c.profile("slowdec") p = yield c.profile() - assert p['count'] == xp['count'] + yp['count'] + assert p["count"] == xp["count"] + yp["count"] - with captured_logger(logging.getLogger('distributed')) as logger: - prof = yield c.profile('does-not-exist') + with captured_logger(logging.getLogger("distributed")) as logger: + prof = yield c.profile("does-not-exist") assert prof == profile.create() out = logger.getvalue() assert not out @@ -5026,14 +5133,15 @@ def test_profile_keys(c, s, a, b): @gen_cluster() def test_client_with_name(s, a, b): - with captured_logger('distributed.scheduler') as sio: - client = yield Client(s.address, asynchronous=True, name='foo', - silence_logs=False) - assert 'foo' in client.id + with captured_logger("distributed.scheduler") as sio: + client = yield Client( + s.address, asynchronous=True, name="foo", silence_logs=False + ) + assert "foo" in client.id yield client.close() text = sio.getvalue() - assert 'foo' in text + assert "foo" in text @gen_cluster(client=True) @@ -5054,7 +5162,7 @@ def test_future_auto_inform(c, s, a, b): future = Future(x.key, client) start = time() - while future.status != 'finished': + while future.status != 
"finished": yield gen.sleep(0.01) assert time() < start + 1 @@ -5069,8 +5177,12 @@ def test_client_async_before_loop_starts(): @slow -@gen_cluster(client=True, Worker=Nanny if PY3 else Worker, timeout=60, - ncores=[('127.0.0.1', 3)] * 2) +@gen_cluster( + client=True, + Worker=Nanny if PY3 else Worker, + timeout=60, + ncores=[("127.0.0.1", 3)] * 2, +) def test_nested_compute(c, s, a, b): def fib(x): assert get_worker().get_current_task() @@ -5089,8 +5201,8 @@ def fib(x): @gen_cluster(client=True) def test_task_metadata(c, s, a, b): - yield c.set_metadata('x', 1) - result = yield c.get_metadata('x') + yield c.set_metadata("x", 1) + result = yield c.get_metadata("x") assert result == 1 future = c.submit(inc, 1) @@ -5111,18 +5223,18 @@ def test_task_metadata(c, s, a, b): result = yield c.get_metadata(key, None) assert result is None - yield c.set_metadata(['x', 'a'], 1) - result = yield c.get_metadata('x') - assert result == {'a': 1} - yield c.set_metadata(['x', 'b'], 2) - result = yield c.get_metadata('x') - assert result == {'a': 1, 'b': 2} - result = yield c.get_metadata(['x', 'a']) + yield c.set_metadata(["x", "a"], 1) + result = yield c.get_metadata("x") + assert result == {"a": 1} + yield c.set_metadata(["x", "b"], 2) + result = yield c.get_metadata("x") + assert result == {"a": 1, "b": 2} + result = yield c.get_metadata(["x", "a"]) assert result == 1 - yield c.set_metadata(['x', 'a', 'c', 'd'], 1) - result = yield c.get_metadata('x') - assert result == {'a': {'c': {'d': 1}}, 'b': 2} + yield c.set_metadata(["x", "a", "c", "d"], 1) + result = yield c.get_metadata("x") + assert result == {"a": {"c": {"d": 1}}, "b": 2} @gen_cluster(client=True) @@ -5132,13 +5244,13 @@ def test_logs(c, s, a, b): assert logs for _, msg in logs: - assert 'distributed.scheduler' in msg + assert "distributed.scheduler" in msg w_logs = yield c.get_worker_logs(n=5) assert set(w_logs.keys()) == {a.address, b.address} for log in w_logs.values(): for _, msg in log: - assert 'distributed.worker' in msg + assert "distributed.worker" in msg @gen_cluster(client=True) @@ -5152,8 +5264,8 @@ def test_avoid_delayed_finalize(c, s, a, b): @gen_cluster() def test_config_scheduler_address(s, a, b): - with dask.config.set({'scheduler-address': s.address}): - with captured_logger('distributed.client') as sio: + with dask.config.set({"scheduler-address": s.address}): + with captured_logger("distributed.client") as sio: c = yield Client(asynchronous=True) assert c.scheduler.address == s.address @@ -5166,18 +5278,18 @@ def test_config_scheduler_address(s, a, b): @gen_cluster(client=True) def test_warn_when_submitting_large_values(c, s, a, b): with warnings.catch_warnings(record=True) as record: - future = c.submit(lambda x: x + 1, b'0' * 2000000) + future = c.submit(lambda x: x + 1, b"0" * 2000000) text = str(record[0].message) - assert '2.00 MB' in text - assert 'large' in text - assert '...' in text + assert "2.00 MB" in text + assert "large" in text + assert "..." 
in text assert "'000" in text assert "000'" in text assert len(text) < 2000 with warnings.catch_warnings(record=True) as record: - data = b'0' * 2000000 + data = b"0" * 2000000 for i in range(10): future = c.submit(lambda x, y: x, data, i) @@ -5201,14 +5313,14 @@ def test_scatter_direct(s, a, b): @pytest.mark.skipif(sys.version_info[0] < 3, reason="cloudpickle Py27 issue") @gen_cluster(client=True) def test_unhashable_function(c, s, a, b): - d = {'a': 1} - result = yield c.submit(d.get, 'a') + d = {"a": 1} + result = yield c.submit(d.get, "a") assert result == 1 @gen_cluster() def test_client_name(s, a, b): - with dask.config.set({'client-name': 'hello-world'}): + with dask.config.set({"client-name": "hello-world"}): c = yield Client(s.address, asynchronous=True) assert any("hello-world" in name for name in list(s.clients)) @@ -5216,39 +5328,45 @@ def test_client_name(s, a, b): def test_client_doesnt_close_given_loop(loop, s, a, b): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: assert c.submit(inc, 1).result() == 2 - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: assert c.submit(inc, 2).result() == 3 @gen_cluster(client=True, ncores=[]) def test_quiet_scheduler_loss(c, s): - c._periodic_callbacks['scheduler-info'].interval = 10 - with captured_logger(logging.getLogger('distributed.client')) as logger: + c._periodic_callbacks["scheduler-info"].interval = 10 + with captured_logger(logging.getLogger("distributed.client")) as logger: yield s.close() yield c._update_scheduler_info() text = logger.getvalue() assert "BrokenPipeError" not in text -@pytest.mark.skipif('USER' not in os.environ, reason='no USER env variable') +@pytest.mark.skipif("USER" not in os.environ, reason="no USER env variable") def test_diagnostics_link_env_variable(loop): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") from distributed.bokeh.scheduler import BokehScheduler - with cluster(scheduler_kwargs={'services': {('bokeh', 12355): BokehScheduler}}) as (s, [a, b]): - with Client(s['address'], loop=loop) as c: - with dask.config.set({'distributed.dashboard.link': 'http://foo-{USER}:{port}/status'}): + + with cluster(scheduler_kwargs={"services": {("bokeh", 12355): BokehScheduler}}) as ( + s, + [a, b], + ): + with Client(s["address"], loop=loop) as c: + with dask.config.set( + {"distributed.dashboard.link": "http://foo-{USER}:{port}/status"} + ): text = c._repr_html_() - link = 'http://foo-' + os.environ['USER'] + ':12355/status' + link = "http://foo-" + os.environ["USER"] + ":12355/status" assert link in text @gen_test() def test_client_timeout_2(): - with dask.config.set({'distributed.comm.timeouts.connect': '10ms'}): + with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): start = time() - c = Client('127.0.0.1:3755', asynchronous=True) + c = Client("127.0.0.1:3755", asynchronous=True) with pytest.raises((TimeoutError, IOError)): yield c stop = time() @@ -5262,26 +5380,25 @@ def test_client_timeout_2(): def test_client_active_bad_port(): import tornado.web import tornado.httpserver - application = tornado.web.Application([ - (r"/", tornado.web.RequestHandler), - ]) + + application = tornado.web.Application([(r"/", tornado.web.RequestHandler)]) http_server = tornado.httpserver.HTTPServer(application) http_server.listen(8080) - with dask.config.set({'distributed.comm.timeouts.connect': '10ms'}): - c = Client('127.0.0.1:8080', asynchronous=True) + with dask.config.set({"distributed.comm.timeouts.connect": 
"10ms"}): + c = Client("127.0.0.1:8080", asynchronous=True) with pytest.raises((TimeoutError, IOError)): yield c yield c._close(fast=True) http_server.stop() -@pytest.mark.parametrize('direct', [True, False]) +@pytest.mark.parametrize("direct", [True, False]) def test_turn_off_pickle(direct): @gen_cluster() def test(s, a, b): import numpy as np - c = yield Client(s.address, asynchronous=True, - serializers=['dask', 'msgpack']) + + c = yield Client(s.address, asynchronous=True, serializers=["dask", "msgpack"]) try: assert (yield c.submit(inc, 1)) == 2 yield c.submit(np.ones, 5) @@ -5319,9 +5436,13 @@ def test(s, a, b): @gen_cluster() def test_de_serialization(s, a, b): import numpy as np - c = yield Client(s.address, asynchronous=True, - serializers=['msgpack', 'pickle'], - deserializers=['msgpack']) + + c = yield Client( + s.address, + asynchronous=True, + serializers=["msgpack", "pickle"], + deserializers=["msgpack"], + ) try: # Can send complex data future = yield c.scatter(np.ones(5)) @@ -5336,8 +5457,8 @@ def test_de_serialization(s, a, b): @gen_cluster() def test_de_serialization_none(s, a, b): import numpy as np - c = yield Client(s.address, asynchronous=True, - deserializers=['msgpack']) + + c = yield Client(s.address, asynchronous=True, deserializers=["msgpack"]) try: # Can send complex data future = yield c.scatter(np.ones(5)) @@ -5362,10 +5483,10 @@ def test_client_repr_closed_sync(loop): c._repr_html_() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_nested_prioritization(c, s, w): - x = delayed(inc)(1, dask_key_name=('a', 2)) - y = delayed(inc)(2, dask_key_name=('a', 10)) + x = delayed(inc)(1, dask_key_name=("a", 2)) + y = delayed(inc)(2, dask_key_name=("a", 10)) o = dask.order.order(merge(x.__dask_graph__(), y.__dask_graph__())) @@ -5373,24 +5494,25 @@ def test_nested_prioritization(c, s, w): yield wait([fx, fy]) - assert ((o[x.key] < o[y.key]) == - (s.tasks[tokey(fx.key)].priority < s.tasks[tokey(fy.key)].priority)) + assert (o[x.key] < o[y.key]) == ( + s.tasks[tokey(fx.key)].priority < s.tasks[tokey(fy.key)].priority + ) @gen_cluster(client=True) def test_scatter_error_cancel(c, s, a, b): # https://github.com/dask/distributed/issues/2038 def bad_fn(x): - raise Exception('lol') + raise Exception("lol") x = yield c.scatter(1) y = c.submit(bad_fn, x) del x yield wait(y) - assert y.status == 'error' + assert y.status == "error" yield gen.sleep(0.1) - assert y.status == 'error' # not cancelled + assert y.status == "error" # not cancelled def test_no_threads_lingering(): @@ -5432,28 +5554,28 @@ def test_mixing_clients(s, a, b): @gen_cluster(client=True) def test_tuple_keys(c, s, a, b): - x = dask.delayed(inc)(1, dask_key_name=('x', 1)) - y = dask.delayed(inc)(x, dask_key_name=('y', 1)) + x = dask.delayed(inc)(1, dask_key_name=("x", 1)) + y = dask.delayed(inc)(x, dask_key_name=("y", 1)) future = c.compute(y) assert (yield future) == 3 @gen_cluster(client=True) def test_map_large_kwargs_in_graph(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = np.random.random(100000) futures = c.map(lambda a, b: a + b, range(100), b=x) while not s.tasks: yield gen.sleep(0.01) assert len(s.tasks) == 101 - assert any(k.startswith('ndarray') for k in s.tasks) + assert any(k.startswith("ndarray") for k in s.tasks) @gen_cluster(client=True) def test_retry(c, s, a, b): def f(): - assert dask.config.get('foo') + assert dask.config.get("foo") with dask.config.set(foo=False): future = c.submit(f) @@ 
-5468,7 +5590,7 @@ def f(): @gen_cluster(client=True) def test_retry_dependencies(c, s, a, b): def f(): - return dask.config.get('foo') + return dask.config.get("foo") x = c.submit(f) y = c.submit(inc, x) @@ -5490,10 +5612,10 @@ def f(): @gen_cluster(client=True) def test_released_dependencies(c, s, a, b): def f(x): - return dask.config.get('foo') + 1 + return dask.config.get("foo") + 1 - x = c.submit(inc, 1, key='x') - y = c.submit(f, x, key='y') + x = c.submit(inc, 1, key="x") + y = c.submit(f, x, key="y") del x with pytest.raises(KeyError): @@ -5507,13 +5629,14 @@ def f(x): @gen_cluster(client=True, check_new_threads=False) def test_profile_bokeh(c, s, a, b): - pytest.importorskip('bokeh.plotting') + pytest.importorskip("bokeh.plotting") from bokeh.model import Model + yield c.map(slowinc, range(10), delay=0.2) state, figure = yield c.profile(plot=True) assert isinstance(figure, Model) - with tmpfile('html') as fn: + with tmpfile("html") as fn: yield c.profile(filename=fn) assert os.path.exists(fn) @@ -5522,42 +5645,45 @@ def test_profile_bokeh(c, s, a, b): def test_get_mix_futures_and_SubgraphCallable(c, s, a, b): future = c.submit(add, 1, 2) - subgraph = SubgraphCallable({'_2': (add, '_0', '_1'), - '_3': (add, future, '_2')}, - '_3', ('_0', '_1')) - dsk = {'a': 1, - 'b': 2, - 'c': (subgraph, 'a', 'b'), - 'd': (subgraph, 'c', 'b')} + subgraph = SubgraphCallable( + {"_2": (add, "_0", "_1"), "_3": (add, future, "_2")}, "_3", ("_0", "_1") + ) + dsk = {"a": 1, "b": 2, "c": (subgraph, "a", "b"), "d": (subgraph, "c", "b")} - future2 = c.get(dsk, 'd', sync=False) + future2 = c.get(dsk, "d", sync=False) result = yield future2 assert result == 11 # Nested subgraphs - subgraph2 = SubgraphCallable({'_2': (subgraph, '_0', '_1'), - '_3': (subgraph, '_2', '_1'), - '_4': (add, '_3', future2)}, - '_4', ('_0', '_1')) - - dsk2 = {'e': 1, 'f': 2, 'g': (subgraph2, 'e', 'f')} - - result = yield c.get(dsk2, 'g', sync=False) + subgraph2 = SubgraphCallable( + { + "_2": (subgraph, "_0", "_1"), + "_3": (subgraph, "_2", "_1"), + "_4": (add, "_3", future2), + }, + "_4", + ("_0", "_1"), + ) + + dsk2 = {"e": 1, "f": 2, "g": (subgraph2, "e", "f")} + + result = yield c.get(dsk2, "g", sync=False) assert result == 22 @gen_cluster(client=True) def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b): - dd = pytest.importorskip('dask.dataframe') + dd = pytest.importorskip("dask.dataframe") import pandas as pd - df = pd.DataFrame({'x': range(1, 11)}) + + df = pd.DataFrame({"x": range(1, 11)}) ddf = dd.from_pandas(df, npartitions=2).persist() ddf = ddf.map_partitions(lambda x: x) - ddf['x'] = ddf['x'].astype('f8') + ddf["x"] = ddf["x"].astype("f8") ddf = ddf.map_partitions(lambda x: x) - ddf['x'] = ddf['x'].astype('f8') + ddf["x"] = ddf["x"].astype("f8") result = yield c.compute(ddf) - assert result.equals(df.astype('f8')) + assert result.equals(df.astype("f8")) if sys.version_info >= (3, 5): diff --git a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index 117c9c31dc3..a7f10491efb 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -4,14 +4,19 @@ import time from concurrent.futures import ( - CancelledError, TimeoutError, Future, wait, as_completed, - FIRST_COMPLETED, FIRST_EXCEPTION) + CancelledError, + TimeoutError, + Future, + wait, + as_completed, + FIRST_COMPLETED, + FIRST_EXCEPTION, +) import pytest from toolz import take -from distributed.utils_test import (slowinc, slowadd, slowdec, - inc, throws, varying) 
+from distributed.utils_test import slowinc, slowadd, slowdec, inc, throws, varying from distributed.utils_test import client, cluster_fixture, loop, s, a, b # noqa: F401 @@ -174,19 +179,20 @@ def test_pure(client): def test_workers(client, s, a, b): N = 10 - with client.get_executor(workers=[b['address']]) as e: + with client.get_executor(workers=[b["address"]]) as e: fs = [e.submit(slowinc, i) for i in range(N)] wait(fs) has_what = client.has_what() - assert not has_what.get(a['address']) - assert len(has_what[b['address']]) == N + assert not has_what.get(a["address"]) + assert len(has_what[b["address"]]) == N def test_unsupported_arguments(client, s, a, b): with pytest.raises(TypeError) as excinfo: - client.get_executor(workers=[b['address']], foo=1, bar=2) - assert ("unsupported arguments to ClientExecutor: ['bar', 'foo']" - in str(excinfo.value)) + client.get_executor(workers=[b["address"]], foo=1, bar=2) + assert "unsupported arguments to ClientExecutor: ['bar', 'foo']" in str( + excinfo.value + ) def test_retries(client): diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index cb51d62c5fd..f640d2d21e0 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -2,22 +2,25 @@ import pytest -pytest.importorskip('numpy') -pytest.importorskip('pandas') + +pytest.importorskip("numpy") +pytest.importorskip("pandas") import dask.dataframe as dd import dask.bag as db from distributed.client import wait from distributed.utils_test import gen_cluster -from distributed.utils_test import client, cluster_fixture, loop # noqa F401 +from distributed.utils_test import client, cluster_fixture, loop # noqa F401 import numpy as np import pandas as pd import pandas.util.testing as tm -dfs = [pd.DataFrame({'x': [1, 2, 3]}, index=[0, 10, 20]), - pd.DataFrame({'x': [4, 5, 6]}, index=[30, 40, 50]), - pd.DataFrame({'x': [7, 8, 9]}, index=[60, 70, 80])] +dfs = [ + pd.DataFrame({"x": [1, 2, 3]}, index=[0, 10, 20]), + pd.DataFrame({"x": [4, 5, 6]}, index=[30, 40, 50]), + pd.DataFrame({"x": [7, 8, 9]}, index=[60, 70, 80]), +] def assert_equal(a, b): @@ -34,9 +37,10 @@ def assert_equal(a, b): @gen_cluster(timeout=240, client=True) def test_dataframes(c, s, a, b): - df = pd.DataFrame({'x': np.random.random(1000), - 'y': np.random.random(1000)}, - index=np.arange(1000)) + df = pd.DataFrame( + {"x": np.random.random(1000), "y": np.random.random(1000)}, + index=np.arange(1000), + ) ldf = dd.from_pandas(df, npartitions=10) rdf = c.persist(ldf) @@ -46,18 +50,20 @@ def test_dataframes(c, s, a, b): remote = c.compute(rdf) result = yield remote - tm.assert_frame_equal(result, ldf.compute(scheduler='sync')) - - exprs = [lambda df: df.x.mean(), - lambda df: df.y.std(), - lambda df: df.assign(z=df.x + df.y).drop_duplicates(), - lambda df: df.index, - lambda df: df.x, - lambda df: df.x.cumsum(), - lambda df: df.groupby(['x', 'y']).count(), - lambda df: df.loc[50:75]] + tm.assert_frame_equal(result, ldf.compute(scheduler="sync")) + + exprs = [ + lambda df: df.x.mean(), + lambda df: df.y.std(), + lambda df: df.assign(z=df.x + df.y).drop_duplicates(), + lambda df: df.index, + lambda df: df.x, + lambda df: df.x.cumsum(), + lambda df: df.groupby(["x", "y"]).count(), + lambda df: df.loc[50:75], + ] for f in exprs: - local = f(ldf).compute(scheduler='sync') + local = f(ldf).compute(scheduler="sync") remote = c.compute(f(rdf)) remote = yield remote assert_equal(local, remote) @@ -67,27 +73,27 @@ def test_dataframes(c, s, a, b): def 
test__dask_array_collections(c, s, a, b): import dask.array as da - x_dsk = {('x', i, j): np.random.random((3, 3)) for i in range(3) - for j in range(2)} - y_dsk = {('y', i, j): np.random.random((3, 3)) for i in range(2) - for j in range(3)} + x_dsk = {("x", i, j): np.random.random((3, 3)) for i in range(3) for j in range(2)} + y_dsk = {("y", i, j): np.random.random((3, 3)) for i in range(2) for j in range(3)} x_futures = yield c._scatter(x_dsk) y_futures = yield c._scatter(y_dsk) dt = np.random.random(0).dtype - x_local = da.Array(x_dsk, 'x', ((3, 3, 3), (3, 3)), dt) - y_local = da.Array(y_dsk, 'y', ((3, 3), (3, 3, 3)), dt) + x_local = da.Array(x_dsk, "x", ((3, 3, 3), (3, 3)), dt) + y_local = da.Array(y_dsk, "y", ((3, 3), (3, 3, 3)), dt) - x_remote = da.Array(x_futures, 'x', ((3, 3, 3), (3, 3)), dt) - y_remote = da.Array(y_futures, 'y', ((3, 3), (3, 3, 3)), dt) + x_remote = da.Array(x_futures, "x", ((3, 3, 3), (3, 3)), dt) + y_remote = da.Array(y_futures, "y", ((3, 3), (3, 3, 3)), dt) - exprs = [lambda x, y: x.T + y, - lambda x, y: x.mean() + y.mean(), - lambda x, y: x.dot(y).std(axis=0), - lambda x, y: x - x.mean(axis=1)[:, None]] + exprs = [ + lambda x, y: x.T + y, + lambda x, y: x.mean() + y.mean(), + lambda x, y: x.dot(y).std(axis=0), + lambda x, y: x - x.mean(axis=1)[:, None], + ] for expr in exprs: - local = expr(x_local, y_local).compute(scheduler='sync') + local = expr(x_local, y_local).compute(scheduler="sync") remote = c.compute(expr(x_remote, y_remote)) remote = yield remote @@ -99,18 +105,23 @@ def test__dask_array_collections(c, s, a, b): def test_bag_groupby_tasks_default(c, s, a, b): b = db.range(100, npartitions=10) b2 = b.groupby(lambda x: x % 13) - assert not any('partd' in k[0] for k in b2.dask) + assert not any("partd" in k[0] for k in b2.dask) -@pytest.mark.parametrize('wait', [wait, lambda x: None]) +@pytest.mark.parametrize("wait", [wait, lambda x: None]) def test_dataframe_set_index_sync(wait, client): - df = dd.demo.make_timeseries('2000', '2001', - {'value': float, 'name': str, 'id': int}, - freq='2H', partition_freq='1M', seed=1) + df = dd.demo.make_timeseries( + "2000", + "2001", + {"value": float, "name": str, "id": int}, + freq="2H", + partition_freq="1M", + seed=1, + ) df = client.persist(df) wait(df) - df2 = df.set_index('name', shuffle='tasks') + df2 = df.set_index("name", shuffle="tasks") df2 = client.persist(df2) assert len(df2) @@ -119,7 +130,7 @@ def test_dataframe_set_index_sync(wait, client): def test_loc_sync(client): df = pd.util.testing.makeTimeDataFrame() ddf = dd.from_pandas(df, npartitions=10) - ddf.loc['2000-01-17':'2000-01-24'].compute() + ddf.loc["2000-01-17":"2000-01-24"].compute() def test_rolling_sync(client): @@ -132,40 +143,40 @@ def test_rolling_sync(client): def test_loc(c, s, a, b): df = pd.util.testing.makeTimeDataFrame() ddf = dd.from_pandas(df, npartitions=10) - future = c.compute(ddf.loc['2000-01-17':'2000-01-24']) + future = c.compute(ddf.loc["2000-01-17":"2000-01-24"]) yield future def test_dataframe_groupby_tasks(client): df = pd.util.testing.makeTimeDataFrame() - df['A'] = df.A // 0.1 - df['B'] = df.B // 0.1 + df["A"] = df.A // 0.1 + df["B"] = df.B // 0.1 ddf = dd.from_pandas(df, npartitions=10) - for ind in [lambda x: 'A', lambda x: x.A]: + for ind in [lambda x: "A", lambda x: x.A]: a = df.groupby(ind(df)).apply(len) b = ddf.groupby(ind(ddf)).apply(len, meta=int) - assert_equal(a, b.compute(scheduler='sync').sort_index()) - assert not any('partd' in k[0] for k in b.dask) + assert_equal(a, 
b.compute(scheduler="sync").sort_index()) + assert not any("partd" in k[0] for k in b.dask) a = df.groupby(ind(df)).B.apply(len) - b = ddf.groupby(ind(ddf)).B.apply(len, meta=('B', int)) - assert_equal(a, b.compute(scheduler='sync').sort_index()) - assert not any('partd' in k[0] for k in b.dask) + b = ddf.groupby(ind(ddf)).B.apply(len, meta=("B", int)) + assert_equal(a, b.compute(scheduler="sync").sort_index()) + assert not any("partd" in k[0] for k in b.dask) with pytest.raises((NotImplementedError, ValueError)): - ddf.groupby(ddf[['A', 'B']]).apply(len, meta=int) + ddf.groupby(ddf[["A", "B"]]).apply(len, meta=int) - a = df.groupby(['A', 'B']).apply(len) - b = ddf.groupby(['A', 'B']).apply(len, meta=int) + a = df.groupby(["A", "B"]).apply(len) + b = ddf.groupby(["A", "B"]).apply(len, meta=int) - assert_equal(a, b.compute(scheduler='sync').sort_index()) + assert_equal(a, b.compute(scheduler="sync").sort_index()) @gen_cluster(client=True) def test_sparse_arrays(c, s, a, b): - sparse = pytest.importorskip('sparse') - da = pytest.importorskip('dask.array') + sparse = pytest.importorskip("sparse") + da = pytest.importorskip("dask.array") x = da.random.random((100, 10), chunks=(10, 10)) x[x < 0.95] = 0 diff --git a/distributed/tests/test_compatibility.py b/distributed/tests/test_compatibility.py index 54a39ad4a34..42eae448aa1 100644 --- a/distributed/tests/test_compatibility.py +++ b/distributed/tests/test_compatibility.py @@ -1,11 +1,10 @@ from __future__ import print_function, division, absolute_import -from distributed.compatibility import ( - gzip_compress, gzip_decompress, finalize) +from distributed.compatibility import gzip_compress, gzip_decompress, finalize def test_gzip(): - b = b'Hello, world!' + b = b"Hello, world!" c = gzip_compress(b) d = gzip_decompress(c) assert b == d diff --git a/distributed/tests/test_config.py b/distributed/tests/test_config.py index f14f07308c2..cdd4070f7bb 100644 --- a/distributed/tests/test_config.py +++ b/distributed/tests/test_config.py @@ -8,8 +8,12 @@ import pytest -from distributed.utils_test import (captured_handler, captured_logger, - new_config, new_config_file) +from distributed.utils_test import ( + captured_handler, + captured_logger, + new_config, + new_config_file, +) from distributed.config import initialize_logging @@ -20,12 +24,15 @@ def dump_logger_list(): print("== Loggers (name, level, effective level, propagate) ==") def logger_info(name, logger): - return (name, logging.getLevelName(logger.level), - logging.getLevelName(logger.getEffectiveLevel()), - logger.propagate) + return ( + name, + logging.getLevelName(logger.level), + logging.getLevelName(logger.getEffectiveLevel()), + logger.propagate, + ) infos = [] - infos.append(logger_info('', root)) + infos.append(logger_info("", root)) for name, logger in sorted(loggers.items()): if not isinstance(logger, logging.Logger): @@ -44,28 +51,28 @@ def test_logging_default(): """ Test default logging configuration. 
""" - d = logging.getLogger('distributed') + d = logging.getLogger("distributed") assert len(d.handlers) == 1 assert isinstance(d.handlers[0], logging.StreamHandler) # Work around Bokeh messing with the root logger level # https://github.com/bokeh/bokeh/issues/5793 - root = logging.getLogger('') + root = logging.getLogger("") old_root_level = root.level - root.setLevel('WARN') + root.setLevel("WARN") for handler in d.handlers: - handler.setLevel('INFO') + handler.setLevel("INFO") try: - dfb = logging.getLogger('distributed.foo.bar') - f = logging.getLogger('foo') - fb = logging.getLogger('foo.bar') + dfb = logging.getLogger("distributed.foo.bar") + f = logging.getLogger("foo") + fb = logging.getLogger("foo.bar") with captured_handler(d.handlers[0]) as distributed_log: with captured_logger(root, level=logging.ERROR) as foreign_log: h = logging.StreamHandler(foreign_log) - fmt = '[%(levelname)s in %(name)s] - %(message)s' + fmt = "[%(levelname)s in %(name)s] - %(message)s" h.setFormatter(logging.Formatter(fmt)) fb.addHandler(h) fb.propagate = False @@ -92,10 +99,7 @@ def test_logging_default(): # foreign logs should be unaffected by distributed's logging # configuration. They get the default ERROR level from logging. - assert foreign_log == [ - "[ERROR in foo.bar] - 5: error", - "7: error", - ] + assert foreign_log == ["[ERROR in foo.bar] - 5: error", "7: error"] finally: root.setLevel(old_root_level) @@ -110,12 +114,7 @@ def test_logging_simple(): """ Test simple ("old-style") logging configuration. """ - c = { - 'logging': { - 'distributed.foo': 'info', - 'distributed.foo.bar': 'error', - } - } + c = {"logging": {"distributed.foo": "info", "distributed.foo.bar": "error"}} # Must test using a subprocess to avoid wrecking pre-existing configuration with new_config_file(c): code = """if 1: @@ -151,35 +150,30 @@ def test_logging_extended(): Test extended ("new-style") logging configuration. """ c = { - 'logging': { - 'version': '1', - 'formatters': { - 'simple': { - 'format': '%(levelname)s: %(name)s: %(message)s', - }, + "logging": { + "version": "1", + "formatters": { + "simple": {"format": "%(levelname)s: %(name)s: %(message)s"} }, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - 'stream': 'ext://sys.stderr', - 'formatter': 'simple', - }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + "formatter": "simple", + } }, - 'loggers': { - 'distributed.foo': { - 'level': 'INFO', + "loggers": { + "distributed.foo": { + "level": "INFO", #'handlers': ['console'], }, - 'distributed.foo.bar': { - 'level': 'ERROR', + "distributed.foo.bar": { + "level": "ERROR", #'handlers': ['console'], }, }, - 'root': { - 'level': 'WARNING', - 'handlers': ['console'], - }, - }, + "root": {"level": "WARNING", "handlers": ["console"]}, + } } # Must test using a subprocess to avoid wrecking pre-existing configuration with new_config_file(c): @@ -217,7 +211,7 @@ def test_logging_mutual_exclusive(): """ Ensure that 'logging-file-config' and 'logging' have to be mutual exclusive. 
""" - config = {'logging': {'dask': 'warning'}, 'logging-file-config': '/path/to/config'} + config = {"logging": {"dask": "warning"}, "logging-file-config": "/path/to/config"} with pytest.raises(RuntimeError): initialize_logging(config) @@ -259,9 +253,9 @@ def test_logging_file_config(): handlers=console qualname=foo.bar """ - with tempfile.NamedTemporaryFile(mode='w', delete=False) as logging_config: + with tempfile.NamedTemporaryFile(mode="w", delete=False) as logging_config: logging_config.write(logging_config_contents) - dask_config = {'logging-file-config': logging_config.name} + dask_config = {"logging-file-config": logging_config.name} with new_config_file(dask_config): code = """if 1: import logging diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 07cb4e2214a..2c8f63de6ff 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -10,20 +10,37 @@ import dask from distributed.compatibility import finalize, get_thread_identity -from distributed.core import (pingpong, Server, rpc, connect, send_recv, - coerce_to_address, ConnectionPool) +from distributed.core import ( + pingpong, + Server, + rpc, + connect, + send_recv, + coerce_to_address, + ConnectionPool, +) from distributed.protocol.compression import compressions from distributed.metrics import time from distributed.protocol import to_serialize from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import ( - slow, gen_test, gen_cluster, has_ipv6, - assert_can_connect, assert_cannot_connect, + slow, + gen_test, + gen_cluster, + has_ipv6, + assert_can_connect, + assert_cannot_connect, assert_can_connect_from_everywhere_4, - assert_can_connect_from_everywhere_4_6, assert_can_connect_from_everywhere_6, - assert_can_connect_locally_4, assert_can_connect_locally_6, - tls_security, captured_logger, inc, throws) + assert_can_connect_from_everywhere_4_6, + assert_can_connect_from_everywhere_6, + assert_can_connect_locally_4, + assert_can_connect_locally_6, + tls_security, + captured_logger, + inc, + throws, +) from distributed.utils_test import loop # noqa F401 @@ -40,6 +57,7 @@ class CountedObject(object): """ A class which counts the number of live instances. """ + n_instances = 0 # Use __new__, as __init__ can be bypassed by pickle. @@ -55,39 +73,40 @@ def _finalize(cls, *args): def echo_serialize(comm, x): - return {'result': to_serialize(x)} + return {"result": to_serialize(x)} def echo_no_serialize(comm, x): - return {'result': x} + return {"result": x} def test_server(loop): """ Simple Server test. 
""" + @gen.coroutine def f(): - server = Server({'ping': pingpong}) + server = Server({"ping": pingpong}) with pytest.raises(ValueError): server.port server.listen(8881) assert server.port == 8881 - assert server.address == ('tcp://%s:8881' % get_ip()) + assert server.address == ("tcp://%s:8881" % get_ip()) - for addr in ('127.0.0.1:8881', 'tcp://127.0.0.1:8881', server.address): + for addr in ("127.0.0.1:8881", "tcp://127.0.0.1:8881", server.address): comm = yield connect(addr) - n = yield comm.write({'op': 'ping'}) + n = yield comm.write({"op": "ping"}) assert isinstance(n, int) assert 4 <= n <= 1000 response = yield comm.read() - assert response == b'pong' + assert response == b"pong" - yield comm.write({'op': 'ping', 'close': True}) + yield comm.write({"op": "ping", "close": True}) response = yield comm.read() - assert response == b'pong' + assert response == b"pong" yield comm.close() @@ -99,16 +118,16 @@ def f(): def test_server_raises_on_blocked_handlers(loop): @gen.coroutine def f(): - server = Server({'ping': pingpong}, blocked_handlers=['ping']) + server = Server({"ping": pingpong}, blocked_handlers=["ping"]) server.listen(8881) comm = yield connect(server.address) - yield comm.write({'op': 'ping'}) + yield comm.write({"op": "ping"}) msg = yield comm.read() - assert 'exception' in msg - assert isinstance(msg['exception'], ValueError) - assert "'ping' handler has been explicitly disallowed" in repr(msg['exception']) + assert "exception" in msg + assert isinstance(msg["exception"], ValueError) + assert "'ping' handler has been explicitly disallowed" in repr(msg["exception"]) comm.close() server.stop() @@ -139,94 +158,96 @@ def listen_on(cls, *args, **kwargs): with listen_on(Server, 7800) as server: assert server.port == 7800 - assert server.address == 'tcp://%s:%d' % (EXTERNAL_IP4, server.port) + assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) yield assert_can_connect(server.address) yield assert_can_connect_from_everywhere_4_6(server.port) with listen_on(Server) as server: assert server.port > 0 - assert server.address == 'tcp://%s:%d' % (EXTERNAL_IP4, server.port) + assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) yield assert_can_connect(server.address) yield assert_can_connect_from_everywhere_4_6(server.port) with listen_on(MyServer) as server: assert server.port == MyServer.default_port - assert server.address == 'tcp://%s:%d' % (EXTERNAL_IP4, server.port) + assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) yield assert_can_connect(server.address) yield assert_can_connect_from_everywhere_4_6(server.port) - with listen_on(Server, ('', 7801)) as server: + with listen_on(Server, ("", 7801)) as server: assert server.port == 7801 - assert server.address == 'tcp://%s:%d' % (EXTERNAL_IP4, server.port) + assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) yield assert_can_connect(server.address) yield assert_can_connect_from_everywhere_4_6(server.port) - with listen_on(Server, 'tcp://:7802') as server: + with listen_on(Server, "tcp://:7802") as server: assert server.port == 7802 - assert server.address == 'tcp://%s:%d' % (EXTERNAL_IP4, server.port) + assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) yield assert_can_connect(server.address) yield assert_can_connect_from_everywhere_4_6(server.port) # Only IPv4 - with listen_on(Server, ('0.0.0.0', 7810)) as server: + with listen_on(Server, ("0.0.0.0", 7810)) as server: assert server.port == 7810 - assert server.address == 'tcp://%s:%d' % 
(EXTERNAL_IP4, server.port) + assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) yield assert_can_connect(server.address) yield assert_can_connect_from_everywhere_4(server.port) - with listen_on(Server, ('127.0.0.1', 7811)) as server: + with listen_on(Server, ("127.0.0.1", 7811)) as server: assert server.port == 7811 - assert server.address == 'tcp://127.0.0.1:%d' % server.port + assert server.address == "tcp://127.0.0.1:%d" % server.port yield assert_can_connect(server.address) yield assert_can_connect_locally_4(server.port) - with listen_on(Server, 'tcp://127.0.0.1:7812') as server: + with listen_on(Server, "tcp://127.0.0.1:7812") as server: assert server.port == 7812 - assert server.address == 'tcp://127.0.0.1:%d' % server.port + assert server.address == "tcp://127.0.0.1:%d" % server.port yield assert_can_connect(server.address) yield assert_can_connect_locally_4(server.port) # Only IPv6 if has_ipv6(): - with listen_on(Server, ('::', 7813)) as server: + with listen_on(Server, ("::", 7813)) as server: assert server.port == 7813 - assert server.address == 'tcp://[%s]:%d' % (EXTERNAL_IP6, server.port) + assert server.address == "tcp://[%s]:%d" % (EXTERNAL_IP6, server.port) yield assert_can_connect(server.address) yield assert_can_connect_from_everywhere_6(server.port) - with listen_on(Server, ('::1', 7814)) as server: + with listen_on(Server, ("::1", 7814)) as server: assert server.port == 7814 - assert server.address == 'tcp://[::1]:%d' % server.port + assert server.address == "tcp://[::1]:%d" % server.port yield assert_can_connect(server.address) yield assert_can_connect_locally_6(server.port) - with listen_on(Server, 'tcp://[::1]:7815') as server: + with listen_on(Server, "tcp://[::1]:7815") as server: assert server.port == 7815 - assert server.address == 'tcp://[::1]:%d' % server.port + assert server.address == "tcp://[::1]:%d" % server.port yield assert_can_connect(server.address) yield assert_can_connect_locally_6(server.port) # TLS sec = tls_security() - with listen_on(Server, 'tls://', - listen_args=sec.get_listen_args('scheduler')) as server: - assert server.address.startswith('tls://') - yield assert_can_connect(server.address, - connection_args=sec.get_connection_args('client')) + with listen_on( + Server, "tls://", listen_args=sec.get_listen_args("scheduler") + ) as server: + assert server.address.startswith("tls://") + yield assert_can_connect( + server.address, connection_args=sec.get_connection_args("client") + ) # InProc - with listen_on(Server, 'inproc://') as server: + with listen_on(Server, "inproc://") as server: inproc_addr1 = server.address - assert inproc_addr1.startswith('inproc://%s/%d/' % (get_ip(), os.getpid())) + assert inproc_addr1.startswith("inproc://%s/%d/" % (get_ip(), os.getpid())) yield assert_can_connect(inproc_addr1) - with listen_on(Server, 'inproc://') as server2: + with listen_on(Server, "inproc://") as server2: inproc_addr2 = server2.address - assert inproc_addr2.startswith('inproc://%s/%d/' % (get_ip(), os.getpid())) + assert inproc_addr2.startswith("inproc://%s/%d/" % (get_ip(), os.getpid())) yield assert_can_connect(inproc_addr2) yield assert_can_connect(inproc_addr1) @@ -235,58 +256,59 @@ def listen_on(cls, *args, **kwargs): @gen.coroutine def check_rpc(listen_addr, rpc_addr=None, listen_args=None, connection_args=None): - server = Server({'ping': pingpong}) + server = Server({"ping": pingpong}) server.listen(listen_addr, listen_args=listen_args) if rpc_addr is None: rpc_addr = server.address with rpc(rpc_addr, 
connection_args=connection_args) as remote: response = yield remote.ping() - assert response == b'pong' + assert response == b"pong" assert remote.comms response = yield remote.ping(close=True) - assert response == b'pong' + assert response == b"pong" response = yield remote.ping() - assert response == b'pong' + assert response == b"pong" assert not remote.comms - assert remote.status == 'closed' + assert remote.status == "closed" server.stop() @gen_test() def test_rpc_default(): - yield check_rpc(8883, '127.0.0.1:8883') + yield check_rpc(8883, "127.0.0.1:8883") yield check_rpc(8883) @gen_test() def test_rpc_tcp(): - yield check_rpc('tcp://:8883', 'tcp://127.0.0.1:8883') - yield check_rpc('tcp://') + yield check_rpc("tcp://:8883", "tcp://127.0.0.1:8883") + yield check_rpc("tcp://") @gen_test() def test_rpc_tls(): sec = tls_security() - yield check_rpc('tcp://', None, sec.get_listen_args('scheduler'), - sec.get_connection_args('worker')) + yield check_rpc( + "tcp://", + None, + sec.get_listen_args("scheduler"), + sec.get_connection_args("worker"), + ) @gen_test() def test_rpc_inproc(): - yield check_rpc('inproc://', None) + yield check_rpc("inproc://", None) def test_rpc_inputs(): - L = [rpc('127.0.0.1:8884'), - rpc(('127.0.0.1', 8884)), - rpc('tcp://127.0.0.1:8884'), - ] + L = [rpc("127.0.0.1:8884"), rpc(("127.0.0.1", 8884)), rpc("tcp://127.0.0.1:8884")] - assert all(r.address == 'tcp://127.0.0.1:8884' for r in L), L + assert all(r.address == "tcp://127.0.0.1:8884" for r in L), L for r in L: r.close_rpc() @@ -296,7 +318,7 @@ def test_rpc_inputs(): def check_rpc_message_lifetime(*listen_args): # Issue #956: rpc arguments and result shouldn't be kept alive longer # than necessary - server = Server({'echo': echo_serialize}) + server = Server({"echo": echo_serialize}) server.listen(*listen_args) # Sanity check @@ -308,12 +330,12 @@ def check_rpc_message_lifetime(*listen_args): with rpc(server.address) as remote: obj = CountedObject() res = yield remote.echo(x=to_serialize(obj)) - assert isinstance(res['result'], CountedObject) + assert isinstance(res["result"], CountedObject) # Make sure resource cleanup code in coroutines runs yield gen.sleep(0.05) w1 = weakref.ref(obj) - w2 = weakref.ref(res['result']) + w2 = weakref.ref(res["result"]) del obj, res assert w1() is None @@ -331,12 +353,12 @@ def test_rpc_message_lifetime_default(): @gen_test() def test_rpc_message_lifetime_tcp(): - yield check_rpc_message_lifetime('tcp://') + yield check_rpc_message_lifetime("tcp://") @gen_test() def test_rpc_message_lifetime_inproc(): - yield check_rpc_message_lifetime('inproc://') + yield check_rpc_message_lifetime("inproc://") @gen.coroutine @@ -346,7 +368,7 @@ def g(): for i in range(10): yield remote.ping() - server = Server({'ping': pingpong}) + server = Server({"ping": pingpong}) server.listen(listen_arg) remote = rpc(server.address) @@ -360,26 +382,26 @@ def g(): @gen_test() def test_rpc_with_many_connections_tcp(): - yield check_rpc_with_many_connections('tcp://') + yield check_rpc_with_many_connections("tcp://") @gen_test() def test_rpc_with_many_connections_inproc(): - yield check_rpc_with_many_connections('inproc://') + yield check_rpc_with_many_connections("inproc://") @gen.coroutine def check_large_packets(listen_arg): """ tornado has a 100MB cap by default """ - server = Server({'echo': echo}) + server = Server({"echo": echo}) server.listen(listen_arg) - data = b'0' * int(200e6) # slightly more than 100MB + data = b"0" * int(200e6) # slightly more than 100MB conn = rpc(server.address) result = 
yield conn.echo(x=data) assert result == data - d = {'x': data} + d = {"x": data} result = yield conn.echo(x=d) assert result == d @@ -390,12 +412,12 @@ def check_large_packets(listen_arg): @slow @gen_test() def test_large_packets_tcp(): - yield check_large_packets('tcp://') + yield check_large_packets("tcp://") @gen_test() def test_large_packets_inproc(): - yield check_large_packets('inproc://') + yield check_large_packets("inproc://") @gen.coroutine @@ -406,20 +428,20 @@ def check_identity(listen_arg): with rpc(server.address) as remote: a = yield remote.identity() b = yield remote.identity() - assert a['type'] == 'Server' - assert a['id'] == b['id'] + assert a["type"] == "Server" + assert a["id"] == b["id"] server.stop() @gen_test() def test_identity_tcp(): - yield check_identity('tcp://') + yield check_identity("tcp://") @gen_test() def test_identity_inproc(): - yield check_identity('inproc://') + yield check_identity("inproc://") def test_ports(loop): @@ -450,10 +472,10 @@ def stream_div(stream=None, x=None, y=None): @gen_test() def test_errors(): - server = Server({'div': stream_div}) + server = Server({"div": stream_div}) server.listen(0) - with rpc(('127.0.0.1', server.port)) as r: + with rpc(("127.0.0.1", server.port)) as r: with pytest.raises(ZeroDivisionError): yield r.div(x=1, y=0) @@ -461,67 +483,64 @@ def test_errors(): @gen_test() def test_connect_raises(): with pytest.raises((gen.TimeoutError, IOError)): - yield connect('127.0.0.1:58259', timeout=0.01) + yield connect("127.0.0.1:58259", timeout=0.01) @gen_test() def test_send_recv_args(): - server = Server({'echo': echo}) + server = Server({"echo": echo}) server.listen(0) comm = yield connect(server.address) - result = yield send_recv(comm, op='echo', x=b'1') - assert result == b'1' + result = yield send_recv(comm, op="echo", x=b"1") + assert result == b"1" assert not comm.closed() - result = yield send_recv(comm, op='echo', x=b'2', reply=False) + result = yield send_recv(comm, op="echo", x=b"2", reply=False) assert result is None assert not comm.closed() - result = yield send_recv(comm, op='echo', x=b'3', close=True) - assert result == b'3' + result = yield send_recv(comm, op="echo", x=b"3", close=True) + assert result == b"3" assert comm.closed() server.stop() def test_coerce_to_address(): - for arg in ['127.0.0.1:8786', - ('127.0.0.1', 8786), - ('127.0.0.1', '8786')]: - assert coerce_to_address(arg) == 'tcp://127.0.0.1:8786' + for arg in ["127.0.0.1:8786", ("127.0.0.1", 8786), ("127.0.0.1", "8786")]: + assert coerce_to_address(arg) == "tcp://127.0.0.1:8786" @gen_test() def test_connection_pool(): - @gen.coroutine def ping(comm, delay=0.1): yield gen.sleep(delay) - raise gen.Return('pong') + raise gen.Return("pong") - servers = [Server({'ping': ping}) for i in range(10)] + servers = [Server({"ping": ping}) for i in range(10)] for server in servers: server.listen(0) rpc = ConnectionPool(limit=5) # Reuse connections - yield [rpc(ip='127.0.0.1', port=s.port).ping() for s in servers[:5]] + yield [rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5]] yield [rpc(s.address).ping() for s in servers[:5]] - yield [rpc('127.0.0.1:%d' % s.port).ping() for s in servers[:5]] - yield [rpc(ip='127.0.0.1', port=s.port).ping() for s in servers[:5]] + yield [rpc("127.0.0.1:%d" % s.port).ping() for s in servers[:5]] + yield [rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5]] assert sum(map(len, rpc.available.values())) == 5 assert sum(map(len, rpc.occupied.values())) == 0 assert rpc.active == 0 assert rpc.open == 5 # Clear 
out connections to make room for more - yield [rpc(ip='127.0.0.1', port=s.port).ping() for s in servers[5:]] + yield [rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[5:]] assert rpc.active == 0 assert rpc.open == 5 s = servers[0] - yield [rpc(ip='127.0.0.1', port=s.port).ping(delay=0.1) for i in range(3)] - assert len(rpc.available['tcp://127.0.0.1:%d' % s.port]) == 3 + yield [rpc(ip="127.0.0.1", port=s.port).ping(delay=0.1) for i in range(3)] + assert len(rpc.available["tcp://127.0.0.1:%d" % s.port]) == 3 # Explicitly clear out connections rpc.collect() @@ -539,17 +558,17 @@ def test_connection_pool_tls(): Make sure connection args are supported. """ sec = tls_security() - connection_args = sec.get_connection_args('client') - listen_args = sec.get_listen_args('scheduler') + connection_args = sec.get_connection_args("client") + listen_args = sec.get_listen_args("scheduler") @gen.coroutine def ping(comm, delay=0.01): yield gen.sleep(delay) - raise gen.Return('pong') + raise gen.Return("pong") - servers = [Server({'ping': ping}) for i in range(10)] + servers = [Server({"ping": ping}) for i in range(10)] for server in servers: - server.listen('tls://', listen_args=listen_args) + server.listen("tls://", listen_args=listen_args) rpc = ConnectionPool(limit=5, connection_args=connection_args) @@ -563,13 +582,12 @@ def ping(comm, delay=0.01): @gen_test() def test_connection_pool_remove(): - @gen.coroutine def ping(comm, delay=0.01): yield gen.sleep(delay) - raise gen.Return('pong') + raise gen.Return("pong") - servers = [Server({'ping': ping}) for i in range(5)] + servers = [Server({"ping": ping}) for i in range(5)] for server in servers: server.listen(0) @@ -601,8 +619,8 @@ def ping(comm, delay=0.01): @gen_test() def test_counters(): - server = Server({'div': stream_div}) - server.listen('tcp://') + server = Server({"div": stream_div}) + server.listen("tcp://") with rpc(server.address) as r: for i in range(2): @@ -611,49 +629,50 @@ def test_counters(): yield r.div(x=1, y=0) c = server.counters - assert c['op'].components[0] == {'identity': 2, 'div': 1} + assert c["op"].components[0] == {"identity": 2, "div": 1} @gen_cluster() def test_ticks(s, a, b): - pytest.importorskip('crick') + pytest.importorskip("crick") yield gen.sleep(0.1) - c = s.digests['tick-duration'] + c = s.digests["tick-duration"] assert c.size() assert 0.01 < c.components[0].quantile(0.5) < 0.5 @gen_cluster() def test_tick_logging(s, a, b): - pytest.importorskip('crick') + pytest.importorskip("crick") from distributed import core + old = core.tick_maximum_delay core.tick_maximum_delay = 0.001 try: - with captured_logger('distributed.core') as sio: + with captured_logger("distributed.core") as sio: yield gen.sleep(0.1) text = sio.getvalue() assert "unresponsive" in text - assert 'Scheduler' in text or 'Worker' in text + assert "Scheduler" in text or "Worker" in text finally: core.tick_maximum_delay = old -@pytest.mark.parametrize('compression', list(compressions)) -@pytest.mark.parametrize('serialize', [echo_serialize, echo_no_serialize]) +@pytest.mark.parametrize("compression", list(compressions)) +@pytest.mark.parametrize("serialize", [echo_serialize, echo_no_serialize]) def test_compression(compression, serialize, loop): with dask.config.set(compression=compression): @gen.coroutine def f(): - server = Server({'echo': serialize}) - server.listen('tcp://') + server = Server({"echo": serialize}) + server.listen("tcp://") with rpc(server.address) as r: - data = b'1' * 1000000 + data = b"1" * 1000000 result = yield 
r.echo(x=to_serialize(data)) - assert result == {'result': data} + assert result == {"result": data} server.stop() @@ -663,16 +682,16 @@ def f(): def test_rpc_serialization(loop): @gen.coroutine def f(): - server = Server({'echo': echo_serialize}) - server.listen('tcp://') + server = Server({"echo": echo_serialize}) + server.listen("tcp://") - with rpc(server.address, serializers=['msgpack']) as r: + with rpc(server.address, serializers=["msgpack"]) as r: with pytest.raises(TypeError): yield r.echo(x=to_serialize(inc)) - with rpc(server.address, serializers=['msgpack', 'pickle']) as r: + with rpc(server.address, serializers=["msgpack", "pickle"]) as r: result = yield r.echo(x=to_serialize(inc)) - assert result == {'result': inc} + assert result == {"result": inc} server.stop() @@ -686,12 +705,12 @@ def test_thread_id(s, a, b): @gen_test() def test_deserialize_error(): - server = Server({'throws': throws}) + server = Server({"throws": throws}) server.listen(0) comm = yield connect(server.address, deserialize=False) with pytest.raises(Exception) as info: - yield send_recv(comm, op='throws') + yield send_recv(comm, op="throws") assert type(info.value) == Exception for c in str(info.value): diff --git a/distributed/tests/test_counter.py b/distributed/tests/test_counter.py index 43b5e4d022c..956a682920c 100644 --- a/distributed/tests/test_counter.py +++ b/distributed/tests/test_counter.py @@ -11,11 +11,17 @@ Digest = None -@pytest.mark.parametrize('CD,size', [ - (Counter, lambda d: sum(d.values())), - pytest.param(Digest, lambda x: x.size(), - marks=pytest.mark.skipif(not Digest, reason="no crick library")) -]) +@pytest.mark.parametrize( + "CD,size", + [ + (Counter, lambda d: sum(d.values())), + pytest.param( + Digest, + lambda x: x.size(), + marks=pytest.mark.skipif(not Digest, reason="no crick library"), + ), + ], +) def test_digest(loop, CD, size): c = CD(loop=loop) c.add(1) diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index d7079ca2039..8bf4000178e 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -21,9 +21,11 @@ def assert_directory_contents(dir_path, expected): expected = [os.path.join(dir_path, p) for p in expected] - actual = [os.path.join(dir_path, p) - for p in os.listdir(dir_path) - if p not in ('global.lock', 'purge.lock')] + actual = [ + os.path.join(dir_path, p) + for p in os.listdir(dir_path) + if p not in ("global.lock", "purge.lock") + ] assert sorted(actual) == sorted(expected) @@ -34,29 +36,29 @@ def test_workdir_simple(tmpdir): ws = WorkSpace(base_dir) assert_contents([]) - a = ws.new_work_dir(name='aa') - assert_contents(['aa', 'aa.dirlock']) - b = ws.new_work_dir(name='bb') - assert_contents(['aa', 'aa.dirlock', 'bb', 'bb.dirlock']) + a = ws.new_work_dir(name="aa") + assert_contents(["aa", "aa.dirlock"]) + b = ws.new_work_dir(name="bb") + assert_contents(["aa", "aa.dirlock", "bb", "bb.dirlock"]) ws._purge_leftovers() - assert_contents(['aa', 'aa.dirlock', 'bb', 'bb.dirlock']) + assert_contents(["aa", "aa.dirlock", "bb", "bb.dirlock"]) a.release() - assert_contents(['bb', 'bb.dirlock']) + assert_contents(["bb", "bb.dirlock"]) del b gc.collect() assert_contents([]) # Generated temporary name with a prefix - a = ws.new_work_dir(prefix='foo-') - b = ws.new_work_dir(prefix='bar-') - c = ws.new_work_dir(prefix='bar-') - assert_contents({a.dir_path, a._lock_path, - b.dir_path, b._lock_path, - c.dir_path, c._lock_path}) - assert os.path.basename(a.dir_path).startswith('foo-') - assert 
os.path.basename(b.dir_path).startswith('bar-') - assert os.path.basename(c.dir_path).startswith('bar-') + a = ws.new_work_dir(prefix="foo-") + b = ws.new_work_dir(prefix="bar-") + c = ws.new_work_dir(prefix="bar-") + assert_contents( + {a.dir_path, a._lock_path, b.dir_path, b._lock_path, c.dir_path, c._lock_path} + ) + assert os.path.basename(a.dir_path).startswith("foo-") + assert os.path.basename(b.dir_path).startswith("bar-") + assert os.path.basename(c.dir_path).startswith("bar-") assert b.dir_path != c.dir_path @@ -68,19 +70,19 @@ def test_two_workspaces_in_same_directory(tmpdir): ws = WorkSpace(base_dir) assert_contents([]) - a = ws.new_work_dir(name='aa') - assert_contents(['aa', 'aa.dirlock']) + a = ws.new_work_dir(name="aa") + assert_contents(["aa", "aa.dirlock"]) ws2 = WorkSpace(base_dir) ws2._purge_leftovers() - assert_contents(['aa', 'aa.dirlock']) - b = ws.new_work_dir(name='bb') - assert_contents(['aa', 'aa.dirlock', 'bb', 'bb.dirlock']) + assert_contents(["aa", "aa.dirlock"]) + b = ws.new_work_dir(name="bb") + assert_contents(["aa", "aa.dirlock", "bb", "bb.dirlock"]) del ws del b gc.collect() - assert_contents(['aa', 'aa.dirlock']) + assert_contents(["aa", "aa.dirlock"]) del a gc.collect() assert_contents([]) @@ -108,26 +110,31 @@ def test_workspace_process_crash(tmpdir): sys.stdout.flush() time.sleep(100) - """ % dict(base_dir=base_dir) - - p = subprocess.Popen([sys.executable, '-c', code], - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - universal_newlines=True) + """ % dict( + base_dir=base_dir + ) + + p = subprocess.Popen( + [sys.executable, "-c", code], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + universal_newlines=True, + ) line = p.stdout.readline() assert p.poll() is None a_path, b_path = eval(line) - assert_contents([a_path, a_path + '.dirlock', b_path, b_path + '.dirlock']) + assert_contents([a_path, a_path + ".dirlock", b_path, b_path + ".dirlock"]) # The child process holds a lock so the work dirs shouldn't be removed ws._purge_leftovers() - assert_contents([a_path, a_path + '.dirlock', b_path, b_path + '.dirlock']) + assert_contents([a_path, a_path + ".dirlock", b_path, b_path + ".dirlock"]) # Kill the process so it's unable to clear the work dirs itself p.kill() assert p.wait() # process returned with non-zero code - assert_contents([a_path, a_path + '.dirlock', b_path, b_path + '.dirlock']) + assert_contents([a_path, a_path + ".dirlock", b_path, b_path + ".dirlock"]) - with captured_logger('distributed.diskutils', 'INFO', propagate=False) as sio: + with captured_logger("distributed.diskutils", "INFO", propagate=False) as sio: ws._purge_leftovers() assert_contents([]) # One log line per purged directory @@ -141,9 +148,9 @@ def test_workspace_rmtree_failure(tmpdir): base_dir = str(tmpdir) ws = WorkSpace(base_dir) - a = ws.new_work_dir(name='aa') + a = ws.new_work_dir(name="aa") shutil.rmtree(a.dir_path) - with captured_logger('distributed.diskutils', 'ERROR', propagate=False) as sio: + with captured_logger("distributed.diskutils", "ERROR", propagate=False) as sio: a.release() lines = sio.getvalue().splitlines() # shutil.rmtree() may call its onerror callback several times @@ -155,21 +162,21 @@ def test_workspace_rmtree_failure(tmpdir): def test_locking_disabled(tmpdir): base_dir = str(tmpdir) - with dask.config.set({'distributed.worker.use-file-locking': False}): - with mock.patch('distributed.diskutils.locket.lock_file') as lock_file: + with dask.config.set({"distributed.worker.use-file-locking": False}): + with 
mock.patch("distributed.diskutils.locket.lock_file") as lock_file: assert_contents = functools.partial(assert_directory_contents, base_dir) ws = WorkSpace(base_dir) assert_contents([]) - a = ws.new_work_dir(name='aa') - assert_contents(['aa']) - b = ws.new_work_dir(name='bb') - assert_contents(['aa', 'bb']) + a = ws.new_work_dir(name="aa") + assert_contents(["aa"]) + b = ws.new_work_dir(name="bb") + assert_contents(["aa", "bb"]) ws._purge_leftovers() - assert_contents(['aa', 'bb']) + assert_contents(["aa", "bb"]) a.release() - assert_contents(['bb']) + assert_contents(["bb"]) del b gc.collect() assert_contents([]) @@ -180,7 +187,7 @@ def test_locking_disabled(tmpdir): def _workspace_concurrency(base_dir, purged_q, err_q, stop_evt): ws = WorkSpace(base_dir) n_purged = 0 - with captured_logger('distributed.diskutils', 'ERROR') as sio: + with captured_logger("distributed.diskutils", "ERROR") as sio: while not stop_evt.is_set(): # Add a bunch of locks, and simulate forgetting them try: @@ -193,7 +200,7 @@ def _workspace_concurrency(base_dir, purged_q, err_q, stop_evt): lines = sio.getvalue().splitlines() if lines: try: - raise AssertionError("got %d logs, see stderr" % (len(lines,))) + raise AssertionError("got %d logs, see stderr" % (len(lines))) except Exception as e: err_q.put(e) @@ -215,10 +222,13 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): ws._purge_leftovers = lambda: None # Run a bunch of child processes that will try to purge concurrently - NPROCS = 2 if sys.platform == 'win32' else max_procs - processes = [mp_context.Process(target=_workspace_concurrency, - args=(base_dir, purged_q, err_q, stop_evt)) - for i in range(NPROCS)] + NPROCS = 2 if sys.platform == "win32" else max_procs + processes = [ + mp_context.Process( + target=_workspace_concurrency, args=(base_dir, purged_q, err_q, stop_evt) + ) + for i in range(NPROCS) + ] for p in processes: p.start() @@ -230,7 +240,7 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): # Add a bunch of locks, and simulate forgetting them. # The concurrent processes should try to purge them. 
for i in range(50): - d = ws.new_work_dir(prefix='workspace-concurrency-') + d = ws.new_work_dir(prefix="workspace-concurrency-") d._finalizer.detach() n_created += 1 sleep(1e-2) @@ -259,7 +269,7 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): def test_workspace_concurrency(tmpdir): if WINDOWS: - raise pytest.xfail.Exception('TODO: unknown failure on windows') + raise pytest.xfail.Exception("TODO: unknown failure on windows") _test_workspace_concurrency(tmpdir, 2.0, 6) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 72811c6fc52..5bb1c61fb5b 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -15,17 +15,25 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils import sync, ignoring -from distributed.utils_test import (gen_cluster, cluster, inc, slow, div, - slowinc, slowadd, captured_logger) -from distributed.utils_test import loop # noqa: F401 +from distributed.utils_test import ( + gen_cluster, + cluster, + inc, + slow, + div, + slowinc, + slowadd, + captured_logger, +) +from distributed.utils_test import loop # noqa: F401 def test_submit_after_failed_worker_sync(loop): with cluster(active_rpc_timeout=10) as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: L = c.map(inc, range(10)) wait(L) - a['proc']().terminate() + a["proc"]().terminate() total = c.submit(sum, L) assert total.result() == sum(map(inc, range(10))) @@ -61,16 +69,20 @@ def test_submit_after_failed_worker(c, s, a, b): def test_gather_after_failed_worker(loop): with cluster(active_rpc_timeout=10) as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: L = c.map(inc, range(10)) wait(L) - a['proc']().terminate() + a["proc"]().terminate() result = c.gather(L) assert result == list(map(inc, range(10))) -@gen_cluster(client=True, Worker=Nanny, ncores=[('127.0.0.1', 1)] * 4, - config={'distributed.comm.timeouts.connect': '1s'}) +@gen_cluster( + client=True, + Worker=Nanny, + ncores=[("127.0.0.1", 1)] * 4, + config={"distributed.comm.timeouts.connect": "1s"}, +) def test_gather_then_submit_after_failed_workers(c, s, w, x, y, z): L = c.map(inc, range(20)) yield wait(L) @@ -168,7 +180,7 @@ def test_restart_cleared(c, s, a, b): def test_restart_sync_no_center(loop): with cluster(nanny=True) as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: x = c.submit(inc, 1) c.restart() assert x.cancelled() @@ -179,7 +191,7 @@ def test_restart_sync_no_center(loop): def test_restart_sync(loop): with cluster(nanny=True) as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: x = c.submit(div, 1, 2) x.result() @@ -205,7 +217,7 @@ def test_restart_fast(c, s, a, b): assert time() - start < 10 assert len(s.ncores) == 2 - assert all(x.status == 'cancelled' for x in L) + assert all(x.status == "cancelled" for x in L) x = c.submit(inc, 1) result = yield x @@ -214,7 +226,7 @@ def test_restart_fast(c, s, a, b): def test_worker_doesnt_await_task_completion(loop): with cluster(nanny=True, nworkers=1) as (s, [w]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: future = c.submit(sleep, 100) sleep(0.1) start = time() @@ -225,7 +237,7 @@ def test_worker_doesnt_await_task_completion(loop): def test_restart_fast_sync(loop): with cluster(nanny=True) as (s, 
[a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: L = c.map(sleep, range(10)) start = time() @@ -233,7 +245,7 @@ def test_restart_fast_sync(loop): assert time() - start < 10 assert len(c.ncores()) == 2 - assert all(x.status == 'cancelled' for x in L) + assert all(x.status == "cancelled" for x in L) x = c.submit(inc, 1) assert x.result() == 2 @@ -247,7 +259,7 @@ def test_fast_kill(c, s, a, b): yield c._restart() assert time() - start < 10 - assert all(x.status == 'cancelled' for x in L) + assert all(x.status == "cancelled" for x in L) x = c.submit(inc, 1) result = yield x @@ -278,6 +290,7 @@ def test_multiple_clients_restart(s, a, b): @gen_cluster(Worker=Nanny, timeout=60) def test_restart_scheduler(s, a, b): import gc + gc.collect() addrs = (a.worker_address, b.worker_address) yield s.restart() @@ -294,6 +307,7 @@ def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): y = c.submit(inc, 1) del x import gc + gc.collect() yield gen.sleep(0.1) yield y @@ -313,12 +327,14 @@ def test_broken_worker_during_computation(c, s, a, b): N = 256 expected_result = N * (N + 1) // 2 i = 0 - L = c.map(inc, range(N), - key=['inc-%d-%d' % (i, j) for j in range(N)]) + L = c.map(inc, range(N), key=["inc-%d-%d" % (i, j) for j in range(N)]) while len(L) > 1: i += 1 - L = c.map(slowadd, *zip(*partition_all(2, L)), - key=['add-%d-%d' % (i, j) for j in range(len(L) // 2)]) + L = c.map( + slowadd, + *zip(*partition_all(2, L)), + key=["add-%d-%d" % (i, j) for j in range(len(L) // 2)] + ) yield gen.sleep(random.random() / 20) with ignoring(CommClosedError): # comm will be closed abrupty @@ -328,7 +344,9 @@ def test_broken_worker_during_computation(c, s, a, b): while len(s.workers) < 3: yield gen.sleep(0.01) - with ignoring(CommClosedError, EnvironmentError): # perhaps new worker can't be contacted yet + with ignoring( + CommClosedError, EnvironmentError + ): # perhaps new worker can't be contacted yet yield c._run(os._exit, 1, workers=[n.worker_address]) [result] = yield c.gather(L) @@ -365,8 +383,7 @@ def test_worker_who_has_clears_after_failed_connection(c, s, a, b): yield gen.sleep(0.01) assert time() < start + 5 - futures = c.map(slowinc, range(20), delay=0.01, - key=['f%d' % i for i in range(20)]) + futures = c.map(slowinc, range(20), delay=0.01, key=["f%d" % i for i in range(20)]) yield wait(futures) result = yield c.submit(sum, futures, workers=a.address) @@ -390,20 +407,20 @@ def test_worker_who_has_clears_after_failed_connection(c, s, a, b): @slow -@gen_cluster(client=True, timeout=60, Worker=Nanny, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, timeout=60, Worker=Nanny, ncores=[("127.0.0.1", 1)]) def test_restart_timeout_on_long_running_task(c, s, a): - with captured_logger('distributed.scheduler') as sio: + with captured_logger("distributed.scheduler") as sio: future = c.submit(sleep, 3600) yield gen.sleep(0.1) yield c.restart(timeout=20) text = sio.getvalue() - assert 'timeout' not in text.lower() + assert "timeout" not in text.lower() -@gen_cluster(client=True, scheduler_kwargs={'worker_ttl': '100ms'}) +@gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "100ms"}) def test_worker_time_to_live(c, s, a, b): - a.periodic_callbacks['heartbeat'].stop() + a.periodic_callbacks["heartbeat"].stop() yield gen.sleep(0.010) assert set(s.workers) == {a.address, b.address} @@ -415,8 +432,8 @@ def test_worker_time_to_live(c, s, a, b): set(s.workers) == {b.address} start = time() - while b.status == 'running': + while b.status == "running": yield 
gen.sleep(0.050) assert time() < start + 1 - assert b.status in ('closed', 'closing') + assert b.status in ("closed", "closing") diff --git a/distributed/tests/test_ipython.py b/distributed/tests/test_ipython.py index 4730e0c9515..8bb64bb4e0b 100644 --- a/distributed/tests/test_ipython.py +++ b/distributed/tests/test_ipython.py @@ -30,18 +30,18 @@ def test_start_ipython_workers(loop, zmq_ctx): from jupyter_client import BlockingKernelClient with cluster(1) as (s, [a]): - with Client(s['address'], loop=loop) as e: + with Client(s["address"], loop=loop) as e: info_dict = e.start_ipython_workers() info = first(info_dict.values()) - key = info.pop('key') + key = info.pop("key") kc = BlockingKernelClient(**info) kc.session.key = key kc.start_channels() kc.wait_for_ready(timeout=10) msg_id = kc.execute("worker") reply = kc.get_shell_msg(timeout=10) - assert reply['parent_header']['msg_id'] == msg_id - assert reply['content']['status'] == 'ok' + assert reply["parent_header"]["msg_id"] == msg_id + assert reply["content"]["status"] == "ok" kc.stop_channels() @@ -51,9 +51,9 @@ def test_start_ipython_scheduler(loop, zmq_ctx): from jupyter_client import BlockingKernelClient with cluster(1) as (s, [a]): - with Client(s['address'], loop=loop) as e: + with Client(s["address"], loop=loop) as e: info = e.start_ipython_scheduler() - key = info.pop('key') + key = info.pop("key") kc = BlockingKernelClient(**info) kc.session.key = key kc.start_channels() @@ -66,15 +66,17 @@ def test_start_ipython_scheduler(loop, zmq_ctx): @need_functional_ipython def test_start_ipython_scheduler_magic(loop, zmq_ctx): with cluster(1) as (s, [a]): - with Client(s['address'], loop=loop) as e, mock_ipython() as ip: + with Client(s["address"], loop=loop) as e, mock_ipython() as ip: info = e.start_ipython_scheduler() expected = [ - {'magic_kind': 'line', 'magic_name': 'scheduler'}, - {'magic_kind': 'cell', 'magic_name': 'scheduler'}, + {"magic_kind": "line", "magic_name": "scheduler"}, + {"magic_kind": "cell", "magic_name": "scheduler"}, ] - call_kwargs_list = [kwargs for (args, kwargs) in ip.register_magic_function.call_args_list] + call_kwargs_list = [ + kwargs for (args, kwargs) in ip.register_magic_function.call_args_list + ] assert call_kwargs_list == expected magic = ip.register_magic_function.call_args_list[0][0][0] magic(line="", cell="scheduler") @@ -85,20 +87,22 @@ def test_start_ipython_scheduler_magic(loop, zmq_ctx): def test_start_ipython_workers_magic(loop, zmq_ctx): with cluster(2) as (s, [a, b]): - with Client(s['address'], loop=loop) as e, mock_ipython() as ip: + with Client(s["address"], loop=loop) as e, mock_ipython() as ip: workers = list(e.ncores())[:2] - names = ['magic%i' % i for i in range(len(workers))] + names = ["magic%i" % i for i in range(len(workers))] info_dict = e.start_ipython_workers(workers, magic_names=names) expected = [ - {'magic_kind': 'line', 'magic_name': 'remote'}, - {'magic_kind': 'cell', 'magic_name': 'remote'}, - {'magic_kind': 'line', 'magic_name': 'magic0'}, - {'magic_kind': 'cell', 'magic_name': 'magic0'}, - {'magic_kind': 'line', 'magic_name': 'magic1'}, - {'magic_kind': 'cell', 'magic_name': 'magic1'}, + {"magic_kind": "line", "magic_name": "remote"}, + {"magic_kind": "cell", "magic_name": "remote"}, + {"magic_kind": "line", "magic_name": "magic0"}, + {"magic_kind": "cell", "magic_name": "magic0"}, + {"magic_kind": "line", "magic_name": "magic1"}, + {"magic_kind": "cell", "magic_name": "magic1"}, + ] + call_kwargs_list = [ + kwargs for (args, kwargs) in 
ip.register_magic_function.call_args_list ] - call_kwargs_list = [kwargs for (args, kwargs) in ip.register_magic_function.call_args_list] assert call_kwargs_list == expected assert ip.register_magic_function.call_count == 6 magics = [args[0][0] for args in ip.register_magic_function.call_args_list[2:]] @@ -111,19 +115,21 @@ def test_start_ipython_workers_magic(loop, zmq_ctx): def test_start_ipython_workers_magic_asterix(loop, zmq_ctx): with cluster(2) as (s, [a, b]): - with Client(s['address'], loop=loop) as e, mock_ipython() as ip: + with Client(s["address"], loop=loop) as e, mock_ipython() as ip: workers = list(e.ncores())[:2] - info_dict = e.start_ipython_workers(workers, magic_names='magic_*') + info_dict = e.start_ipython_workers(workers, magic_names="magic_*") expected = [ - {'magic_kind': 'line', 'magic_name': 'remote'}, - {'magic_kind': 'cell', 'magic_name': 'remote'}, - {'magic_kind': 'line', 'magic_name': 'magic_0'}, - {'magic_kind': 'cell', 'magic_name': 'magic_0'}, - {'magic_kind': 'line', 'magic_name': 'magic_1'}, - {'magic_kind': 'cell', 'magic_name': 'magic_1'}, + {"magic_kind": "line", "magic_name": "remote"}, + {"magic_kind": "cell", "magic_name": "remote"}, + {"magic_kind": "line", "magic_name": "magic_0"}, + {"magic_kind": "cell", "magic_name": "magic_0"}, + {"magic_kind": "line", "magic_name": "magic_1"}, + {"magic_kind": "cell", "magic_name": "magic_1"}, + ] + call_kwargs_list = [ + kwargs for (args, kwargs) in ip.register_magic_function.call_args_list ] - call_kwargs_list = [kwargs for (args, kwargs) in ip.register_magic_function.call_args_list] assert call_kwargs_list == expected assert ip.register_magic_function.call_count == 6 magics = [args[0][0] for args in ip.register_magic_function.call_args_list[2:]] @@ -135,16 +141,17 @@ def test_start_ipython_workers_magic_asterix(loop, zmq_ctx): @need_functional_ipython def test_start_ipython_remote(loop, zmq_ctx): from distributed._ipython_utils import remote_magic + with cluster(1) as (s, [a]): - with Client(s['address'], loop=loop) as e, mock_ipython() as ip: + with Client(s["address"], loop=loop) as e, mock_ipython() as ip: worker = first(e.ncores()) - ip.user_ns['info'] = e.start_ipython_workers(worker)[worker] - remote_magic('info 1') # line magic - remote_magic('info', 'worker') # cell magic + ip.user_ns["info"] = e.start_ipython_workers(worker)[worker] + remote_magic("info 1") # line magic + remote_magic("info", "worker") # cell magic expected = [ - ((remote_magic,), {'magic_kind': 'line', 'magic_name': 'remote'}), - ((remote_magic,), {'magic_kind': 'cell', 'magic_name': 'remote'}), + ((remote_magic,), {"magic_kind": "line", "magic_name": "remote"}), + ((remote_magic,), {"magic_kind": "cell", "magic_name": "remote"}), ] assert ip.register_magic_function.call_args_list == expected assert ip.register_magic_function.call_count == 2 @@ -155,12 +162,14 @@ def test_start_ipython_remote(loop, zmq_ctx): def test_start_ipython_qtconsole(loop): Popen = mock.Mock() with cluster() as (s, [a, b]): - with mock.patch('distributed._ipython_utils.Popen', Popen), Client(s['address'], loop=loop) as e: + with mock.patch("distributed._ipython_utils.Popen", Popen), Client( + s["address"], loop=loop + ) as e: worker = first(e.ncores()) e.start_ipython_workers(worker, qtconsole=True) - e.start_ipython_workers(worker, qtconsole=True, qtconsole_args=['--debug']) + e.start_ipython_workers(worker, qtconsole=True, qtconsole_args=["--debug"]) assert Popen.call_count == 2 (cmd,), kwargs = Popen.call_args_list[0] - assert cmd[:3] == ['jupyter', 
'qtconsole', '--existing'] + assert cmd[:3] == ["jupyter", "qtconsole", "--existing"] (cmd,), kwargs = Popen.call_args_list[1] - assert cmd[-1:] == ['--debug'] + assert cmd[-1:] == ["--debug"] diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index b35d9d6268f..952d43ceb9b 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -11,34 +11,34 @@ from distributed.utils_test import client, cluster_fixture, loop # noqa F401 -@gen_cluster(client=True, ncores=[('127.0.0.1', 8)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 8)] * 2) def test_lock(c, s, a, b): - c.set_metadata('locked', False) + c.set_metadata("locked", False) def f(x): client = get_client() - with Lock('x') as lock: - assert client.get_metadata('locked') is False - client.set_metadata('locked', True) + with Lock("x") as lock: + assert client.get_metadata("locked") is False + client.set_metadata("locked", True) sleep(0.05) - assert client.get_metadata('locked') is True - client.set_metadata('locked', False) + assert client.get_metadata("locked") is True + client.set_metadata("locked", False) futures = c.map(f, range(20)) results = yield futures - assert not s.extensions['locks'].events - assert not s.extensions['locks'].ids + assert not s.extensions["locks"].events + assert not s.extensions["locks"].ids @gen_cluster(client=True) def test_timeout(c, s, a, b): - locks = s.extensions['locks'] - lock = Lock('x') + locks = s.extensions["locks"] + lock = Lock("x") result = yield lock.acquire() assert result is True - assert locks.ids['x'] == lock.id + assert locks.ids["x"] == lock.id - lock2 = Lock('x') + lock2 = Lock("x") assert lock.id != lock2.id start = time() @@ -46,15 +46,15 @@ def test_timeout(c, s, a, b): stop = time() assert stop - start < 0.3 assert result is False - assert locks.ids['x'] == lock.id - assert not locks.events['x'] + assert locks.ids["x"] == lock.id + assert not locks.events["x"] yield lock.release() @gen_cluster(client=True) def test_acquires_with_zero_timeout(c, s, a, b): - lock = Lock('x') + lock = Lock("x") yield lock.acquire(timeout=0) assert lock.locked() yield lock.release() @@ -67,7 +67,7 @@ def test_acquires_with_zero_timeout(c, s, a, b): @gen_cluster(client=True) def test_acquires_blocking(c, s, a, b): - lock = Lock('x') + lock = Lock("x") yield lock.acquire(blocking=False) assert lock.locked() yield lock.release() @@ -78,52 +78,52 @@ def test_acquires_blocking(c, s, a, b): def test_timeout_sync(client): - with Lock('x') as lock: - assert Lock('x').acquire(timeout=0.1) is False + with Lock("x") as lock: + assert Lock("x").acquire(timeout=0.1) is False @gen_cluster(client=True) def test_errors(c, s, a, b): - lock = Lock('x') + lock = Lock("x") with pytest.raises(ValueError): yield lock.release() def test_lock_sync(client): def f(x): - with Lock('x') as lock: + with Lock("x") as lock: client = get_client() - assert client.get_metadata('locked') is False - client.set_metadata('locked', True) + assert client.get_metadata("locked") is False + client.set_metadata("locked", True) sleep(0.05) - assert client.get_metadata('locked') is True - client.set_metadata('locked', False) + assert client.get_metadata("locked") is True + client.set_metadata("locked", False) - client.set_metadata('locked', False) + client.set_metadata("locked", False) futures = client.map(f, range(10)) client.gather(futures) @gen_cluster(client=True) def test_lock_types(c, s, a, b): - for name in [1, ('a', 1), ['a', 1], b'123', '123']: + for name in [1, ("a", 1), ["a", 1], 
b"123", "123"]: lock = Lock(name) assert lock.name == name yield lock.acquire() yield lock.release() - assert not s.extensions['locks'].events + assert not s.extensions["locks"].events @gen_cluster(client=True) def test_serializable(c, s, a, b): def f(x, lock=None): with lock: - assert lock.name == 'x' + assert lock.name == "x" return x + 1 - lock = Lock('x') + lock = Lock("x") futures = c.map(f, range(10), lock=lock) yield c.gather(futures) diff --git a/distributed/tests/test_metrics.py b/distributed/tests/test_metrics.py index 84b7c180993..d1eb4a1dad0 100644 --- a/distributed/tests/test_metrics.py +++ b/distributed/tests/test_metrics.py @@ -58,7 +58,7 @@ def test_thread_time(): dt = metrics.thread_time() - start assert dt <= 0.05 - if sys.platform == 'linux': + if sys.platform == "linux": # Always per-thread on Linux t = threading.Thread(target=run_for, args=(0.1,)) start = metrics.thread_time() diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 23e40c4c8ef..932419015f3 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -18,8 +18,7 @@ from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.utils import ignoring, tmpfile -from distributed.utils_test import (gen_cluster, gen_test, slow, inc, - captured_logger) +from distributed.utils_test import gen_cluster, gen_test, slow, inc, captured_logger @gen_cluster(ncores=[]) @@ -29,7 +28,7 @@ def test_nanny(s): with rpc(n.address) as nn: assert n.is_alive() assert s.ncores[n.worker_address] == 2 - assert s.workers[n.worker_address].services['nanny'] > 1024 + assert s.workers[n.worker_address].services["nanny"] > 1024 yield nn.kill() assert not n.is_alive() @@ -44,7 +43,7 @@ def test_nanny(s): yield nn.instantiate() assert n.is_alive() assert s.ncores[n.worker_address] == 2 - assert s.workers[n.worker_address].services['nanny'] > 1024 + assert s.workers[n.worker_address].services["nanny"] > 1024 yield nn.terminate() assert not n.is_alive() @@ -78,7 +77,7 @@ def test_nanny_process_failure(c, s): original_address = n.worker_address ww = rpc(n.worker_address) - yield ww.update_data(data=valmap(dumps, {'x': 1, 'y': 2})) + yield ww.update_data(data=valmap(dumps, {"x": 1, "y": 2})) pid = n.pid assert pid is not None with ignoring(CommClosedError): @@ -112,31 +111,31 @@ def test_nanny_process_failure(c, s): def test_nanny_no_port(): - _ = str(Nanny('127.0.0.1', 8786)) + _ = str(Nanny("127.0.0.1", 8786)) @gen_cluster(ncores=[]) def test_run(s): - pytest.importorskip('psutil') + pytest.importorskip("psutil") n = yield Nanny(s.ip, s.port, ncores=2, loop=s.loop) with rpc(n.address) as nn: response = yield nn.run(function=dumps(lambda: 1)) - assert response['status'] == 'OK' - assert response['result'] == 1 + assert response["status"] == "OK" + assert response["result"] == 1 yield n._close() @slow -@gen_cluster(Worker=Nanny, - ncores=[('127.0.0.1', 1)], - worker_kwargs={'reconnect': False}) +@gen_cluster( + Worker=Nanny, ncores=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False} +) def test_close_on_disconnect(s, w): yield s.close() start = time() - while w.status != 'closed': + while w.status != "closed": yield gen.sleep(0.05) assert time() < start + 9 @@ -149,15 +148,14 @@ class Something(Worker): @gen_cluster(client=True, Worker=Nanny) def test_nanny_worker_class(c, s, w1, w2): out = yield c._run(lambda dask_worker=None: str(dask_worker.__class__)) - assert 'Worker' in list(out.values())[0] + assert "Worker" in list(out.values())[0] assert 
w1.Worker is Worker -@gen_cluster(client=True, Worker=Nanny, - worker_kwargs={'worker_class': Something}) +@gen_cluster(client=True, Worker=Nanny, worker_kwargs={"worker_class": Something}) def test_nanny_alt_worker_class(c, s, w1, w2): out = yield c._run(lambda dask_worker=None: str(dask_worker.__class__)) - assert 'Something' in list(out.values())[0] + assert "Something" in list(out.values())[0] assert w1.Worker is Something @@ -168,15 +166,15 @@ def test_nanny_death_timeout(s): w = yield Nanny(s.address, death_timeout=1) yield gen.sleep(3) - assert w.status == 'closed' + assert w.status == "closed" @gen_cluster(client=True, Worker=Nanny) def test_random_seed(c, s, a, b): @gen.coroutine def check_func(func): - x = c.submit(func, 0, 2**31, pure=False, workers=a.worker_address) - y = c.submit(func, 0, 2**31, pure=False, workers=b.worker_address) + x = c.submit(func, 0, 2 ** 31, pure=False, workers=a.worker_address) + y = c.submit(func, 0, 2 ** 31, pure=False, workers=b.worker_address) assert x.key != y.key x = yield x y = yield y @@ -186,11 +184,12 @@ def check_func(func): yield check_func(lambda a, b: np.random.randint(a, b)) -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="num_fds not supported on windows") +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="num_fds not supported on windows" +) @gen_cluster(client=False, ncores=[]) def test_num_fds(s): - psutil = pytest.importorskip('psutil') + psutil = pytest.importorskip("psutil") proc = psutil.Process() # Warm up @@ -213,11 +212,12 @@ def test_num_fds(s): assert time() < start + 10 -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) @gen_cluster(client=True, ncores=[]) def test_worker_uses_same_host_as_nanny(c, s): - for host in ['tcp://0.0.0.0', 'tcp://127.0.0.2']: + for host in ["tcp://0.0.0.0", "tcp://127.0.0.2"]: n = Nanny(s.address) yield n._start(host) @@ -240,44 +240,50 @@ def test_scheduler_file(): s.stop() -@gen_cluster(client=True, Worker=Nanny, ncores=[('127.0.0.1', 2)]) +@gen_cluster(client=True, Worker=Nanny, ncores=[("127.0.0.1", 2)]) def test_nanny_timeout(c, s, a): x = yield c.scatter(123) - with captured_logger(logging.getLogger('distributed.nanny'), - level=logging.ERROR) as logger: + with captured_logger( + logging.getLogger("distributed.nanny"), level=logging.ERROR + ) as logger: response = yield a.restart(timeout=0.1) out = logger.getvalue() - assert 'timed out' in out.lower() + assert "timed out" in out.lower() start = time() - while x.status != 'cancelled': + while x.status != "cancelled": yield gen.sleep(0.1) assert time() < start + 7 -@gen_cluster(ncores=[('127.0.0.1', 1)], client=True, Worker=Nanny, - worker_kwargs={'memory_limit': 1e8}, timeout=20, - check_new_threads=False) +@gen_cluster( + ncores=[("127.0.0.1", 1)], + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": 1e8}, + timeout=20, + check_new_threads=False, +) def test_nanny_terminate(c, s, a): from time import sleep def leak(): L = [] while True: - L.append(b'0' * 5000000) + L.append(b"0" * 5000000) sleep(0.01) proc = a.process.pid - with captured_logger(logging.getLogger('distributed.nanny')) as logger: + with captured_logger(logging.getLogger("distributed.nanny")) as logger: future = c.submit(leak) start = time() while a.process.pid == proc: yield gen.sleep(0.1) assert time() < start + 10 out = logger.getvalue() - assert 'restart' in out.lower() - assert 
'memory' in out.lower() + assert "restart" in out.lower() + assert "memory" in out.lower() @gen_cluster(ncores=[], client=True) @@ -286,8 +292,8 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): typ = yield c.run(lambda dask_worker: type(dask_worker.data)) assert typ == {nanny.worker_address: dict} pcs = yield c.run(lambda dask_worker: list(dask_worker.periodic_callbacks)) - assert 'memory' not in pcs - assert 'memory' not in nanny.periodic_callbacks + assert "memory" not in pcs + assert "memory" not in nanny.periodic_callbacks future = c.submit(inc, 1) assert (yield future) == 2 @@ -300,7 +306,7 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): @gen_cluster(ncores=[], client=True) def test_scheduler_address_config(c, s): - with dask.config.set({'scheduler-address': s.address}): + with dask.config.set({"scheduler-address": s.address}): nanny = yield Nanny(loop=s.loop) assert nanny.scheduler.address == s.address @@ -315,14 +321,14 @@ def test_scheduler_address_config(c, s): @slow @gen_test() def test_wait_for_scheduler(): - with captured_logger('distributed') as log: - w = Nanny('127.0.0.1:44737') + with captured_logger("distributed") as log: + w = Nanny("127.0.0.1:44737") w._start() yield gen.sleep(6) log = log.getvalue() - assert 'error' not in log.lower(), log - assert 'restart' not in log.lower(), log + assert "error" not in log.lower(), log + assert "restart" not in log.lower(), log @gen_cluster(ncores=[], client=True) @@ -330,7 +336,7 @@ def test_environment_variable(c, s): a = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "123"}) b = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "456"}) yield [a, b] - results = yield c.run(lambda: os.environ['FOO']) + results = yield c.run(lambda: os.environ["FOO"]) assert results == {a.worker_address: "123", b.worker_address: "456"} yield [a._close(), b._close()] diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index 6239fdd8c0b..07ee56d85a6 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -20,47 +20,48 @@ def get_worker_address(): def test_worker_preload_file(loop): - def check_worker(): import worker_info + return worker_info.get_worker_address() tmpdir = tempfile.mkdtemp() try: - path = os.path.join(tmpdir, 'worker_info.py') - with open(path, 'w') as f: + path = os.path.join(tmpdir, "worker_info.py") + with open(path, "w") as f: f.write(PRELOAD_TEXT) - with cluster(worker_kwargs={'preload': [path]}) as (s, workers), \ - Client(s['address'], loop=loop) as c: + with cluster(worker_kwargs={"preload": [path]}) as (s, workers), Client( + s["address"], loop=loop + ) as c: assert c.run(check_worker) == { - worker['address']: worker['address'] - for worker in workers + worker["address"]: worker["address"] for worker in workers } finally: shutil.rmtree(tmpdir) def test_worker_preload_module(loop): - def check_worker(): import worker_info + return worker_info.get_worker_address() tmpdir = tempfile.mkdtemp() sys.path.insert(0, tmpdir) try: - path = os.path.join(tmpdir, 'worker_info.py') - with open(path, 'w') as f: + path = os.path.join(tmpdir, "worker_info.py") + with open(path, "w") as f: f.write(PRELOAD_TEXT) - with cluster(worker_kwargs={'preload': ['worker_info']}) \ - as (s, workers), Client(s['address'], loop=loop) as c: + with cluster(worker_kwargs={"preload": ["worker_info"]}) as ( + s, + workers, + ), Client(s["address"], loop=loop) as c: assert c.run(check_worker) == { - worker['address']: worker['address'] - for worker in workers + 
worker["address"]: worker["address"] for worker in workers } finally: sys.path.remove(tmpdir) diff --git a/distributed/tests/test_priorities.py b/distributed/tests/test_priorities.py index 0b18b1ba729..421bf7e3028 100644 --- a/distributed/tests/test_priorities.py +++ b/distributed/tests/test_priorities.py @@ -17,7 +17,7 @@ def test_submit(c, s, a, b): high = c.submit(inc, 2, priority=1) yield wait(high) assert all(s.processing.values()) - assert s.tasks[low.key].state == 'processing' + assert s.tasks[low.key].state == "processing" @gen_cluster(client=True) @@ -27,12 +27,12 @@ def test_map(c, s, a, b): high = c.map(inc, [4, 5, 6], priority=1) yield wait(high) assert all(s.processing.values()) - assert s.tasks[low[0].key].state == 'processing' + assert s.tasks[low[0].key].state == "processing" @gen_cluster(client=True) def test_compute(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random((10, 10), chunks=(5, 5)) y = da.random.random((10, 10), chunks=(5, 5)) @@ -41,12 +41,12 @@ def test_compute(c, s, a, b): high = c.compute(y, priority=1) yield wait(high) assert all(s.processing.values()) - assert s.tasks[tokey(low.key)].state in ('processing', 'waiting') + assert s.tasks[tokey(low.key)].state in ("processing", "waiting") @gen_cluster(client=True) def test_persist(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random((10, 10), chunks=(5, 5)) y = da.random.random((10, 10), chunks=(5, 5)) @@ -55,8 +55,10 @@ def test_persist(c, s, a, b): high = y.persist(priority=1) yield wait(high) assert all(s.processing.values()) - assert all(s.tasks[tokey(k)].state in ('processing', 'waiting') - for k in flatten(low.__dask_keys__())) + assert all( + s.tasks[tokey(k)].state in ("processing", "waiting") + for k in flatten(low.__dask_keys__()) + ) @gen_cluster(client=True) @@ -67,43 +69,51 @@ def test_expand_compute(c, s, a, b): low, many, high = c.compute([low, many, high], priority={low: -1, high: 1}) yield wait(high) - assert s.tasks[low.key].state == 'processing' + assert s.tasks[low.key].state == "processing" @gen_cluster(client=True) def test_expand_persist(c, s, a, b): - low = delayed(inc)(1, dask_key_name='low') + low = delayed(inc)(1, dask_key_name="low") many = [delayed(slowinc)(i, delay=0.1) for i in range(4)] - high = delayed(inc)(2, dask_key_name='high') + high = delayed(inc)(2, dask_key_name="high") low, high, x, y, z, w = persist(low, high, *many, priority={low: -1, high: 1}) yield wait(high) - assert s.tasks[low.key].state == 'processing' + assert s.tasks[low.key].state == "processing" -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_repeated_persists_same_priority(c, s, w): - xs = [delayed(slowinc)(i, delay=0.05, dask_key_name='x-%d' % i) for i in range(10)] - ys = [delayed(slowinc)(x, delay=0.05, dask_key_name='y-%d' % i) for i, x in enumerate(xs)] - zs = [delayed(slowdec)(x, delay=0.05, dask_key_name='z-%d' % i) for i, x in enumerate(xs)] + xs = [delayed(slowinc)(i, delay=0.05, dask_key_name="x-%d" % i) for i in range(10)] + ys = [ + delayed(slowinc)(x, delay=0.05, dask_key_name="y-%d" % i) + for i, x in enumerate(xs) + ] + zs = [ + delayed(slowdec)(x, delay=0.05, dask_key_name="z-%d" % i) + for i, x in enumerate(xs) + ] ys = dask.persist(*ys) zs = dask.persist(*zs) - while sum(t.state == 'memory' for t in s.tasks.values()) < 5: # TODO: reduce this number + while ( + sum(t.state == "memory" for t in 
s.tasks.values()) < 5 + ): # TODO: reduce this number yield gen.sleep(0.01) - assert any(s.tasks[y.key].state == 'memory' for y in ys) - assert any(s.tasks[z.key].state == 'memory' for z in zs) + assert any(s.tasks[y.key].state == "memory" for y in ys) + assert any(s.tasks[z.key].state == "memory" for z in zs) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_last_in_first_out(c, s, w): xs = [c.submit(slowinc, i, delay=0.05) for i in range(5)] ys = [c.submit(slowinc, x, delay=0.05) for x in xs] zs = [c.submit(slowinc, y, delay=0.05) for y in ys] - while len(s.tasks) < 15 or not any(s.tasks[z.key].state == 'memory' for z in zs): + while len(s.tasks) < 15 or not any(s.tasks[z.key].state == "memory" for z in zs): yield gen.sleep(0.01) - assert not all(s.tasks[x.key].state == 'memory' for x in xs) + assert not all(s.tasks[x.key].state == "memory" for x in xs) diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index b7c717e1b61..57a7ca657e4 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -5,8 +5,7 @@ from distributed.compatibility import get_thread_identity from distributed import metrics -from distributed.profile import (process, merge, create, call_stack, - identifier, watch) +from distributed.profile import process, merge, create, call_stack, identifier, watch def test_basic(): @@ -32,66 +31,86 @@ def test_f(): frame = sys._current_frames()[thread.ident] process(frame, None, state) - assert state['count'] == 100 + assert state["count"] == 100 d = state - while len(d['children']) == 1: - d = first(d['children'].values()) + while len(d["children"]) == 1: + d = first(d["children"].values()) - assert d['count'] == 100 - assert 'test_f' in str(d['description']) - g = [c for c in d['children'].values() if 'test_g' in str(c['description'])][0] - h = [c for c in d['children'].values() if 'test_h' in str(c['description'])][0] + assert d["count"] == 100 + assert "test_f" in str(d["description"]) + g = [c for c in d["children"].values() if "test_g" in str(c["description"])][0] + h = [c for c in d["children"].values() if "test_h" in str(c["description"])][0] - assert g['count'] < h['count'] - assert 95 < g['count'] + h['count'] <= 100 + assert g["count"] < h["count"] + assert 95 < g["count"] + h["count"] <= 100 def test_merge(): a1 = { - 'count': 5, - 'identifier': 'root', - 'description': 'a', - 'children': { - 'b': {'count': 3, - 'description': 'b-func', - 'identifier': 'b', - 'children': {}}, - 'c': {'count': 2, - 'description': 'c-func', - 'identifier': 'c', - 'children': {}}}} + "count": 5, + "identifier": "root", + "description": "a", + "children": { + "b": { + "count": 3, + "description": "b-func", + "identifier": "b", + "children": {}, + }, + "c": { + "count": 2, + "description": "c-func", + "identifier": "c", + "children": {}, + }, + }, + } a2 = { - 'count': 4, - 'description': 'a', - 'identifier': 'root', - 'children': { - 'd': {'count': 2, - 'description': 'd-func', - 'children': {}, - 'identifier': 'd'}, - 'c': {'count': 2, - 'description': 'c-func', - 'children': {}, - 'identifier': 'c'}}} + "count": 4, + "description": "a", + "identifier": "root", + "children": { + "d": { + "count": 2, + "description": "d-func", + "children": {}, + "identifier": "d", + }, + "c": { + "count": 2, + "description": "c-func", + "children": {}, + "identifier": "c", + }, + }, + } expected = { - 'count': 9, - 'identifier': 'root', - 'description': 'a', - 'children': { - 'b': 
{'count': 3, - 'description': 'b-func', - 'identifier': 'b', - 'children': {}}, - 'd': {'count': 2, - 'description': 'd-func', - 'identifier': 'd', - 'children': {}}, - 'c': {'count': 4, - 'description': 'c-func', - 'identifier': 'c', - 'children': {}}}} + "count": 9, + "identifier": "root", + "description": "a", + "children": { + "b": { + "count": 3, + "description": "b-func", + "identifier": "b", + "children": {}, + }, + "d": { + "count": 2, + "description": "d-func", + "identifier": "d", + "children": {}, + }, + "c": { + "count": 4, + "description": "c-func", + "identifier": "c", + "children": {}, + }, + }, + } assert merge(a1, a2) == expected @@ -107,7 +126,7 @@ def test_call_stack(): L = call_stack(frame) assert isinstance(L, list) assert all(isinstance(s, str) for s in L) - assert 'test_call_stack' in str(L[-1]) + assert "test_call_stack" in str(L[-1]) def test_identifier(): @@ -124,7 +143,7 @@ def stop(): start_threads = threading.active_count() - log = watch(interval='10ms', cycle='50ms', stop=stop) + log = watch(interval="10ms", cycle="50ms", stop=stop) start = metrics.time() # wait until thread starts up while threading.active_count() <= start_threads: diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index c6a899374c8..e4789589c48 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -1,4 +1,3 @@ - import pytest from dask import delayed @@ -18,8 +17,8 @@ def test_publish_simple(s, a, b): data = yield c.scatter(range(3)) out = yield c.publish_dataset(data=data) - assert 'data' in s.extensions['publish'].datasets - assert isinstance(s.extensions['publish'].datasets['data']['data'], Serialized) + assert "data" in s.extensions["publish"].datasets + assert isinstance(s.extensions["publish"].datasets["data"]["data"], Serialized) with pytest.raises(KeyError) as exc_info: out = yield c.publish_dataset(data=data) @@ -28,10 +27,10 @@ def test_publish_simple(s, a, b): assert "data" in str(exc_info.value) result = yield c.scheduler.publish_list() - assert result == ('data',) + assert result == ("data",) result = yield f.scheduler.publish_list() - assert result == ('data',) + assert result == ("data",) yield c.close() yield f.close() @@ -43,11 +42,13 @@ def test_publish_non_string_key(s, a, b): f = yield Client((s.ip, s.port), asynchronous=True) try: - for name in [('a', 'b'), 9.0, 8]: + for name in [("a", "b"), 9.0, 8]: data = yield c.scatter(range(3)) out = yield c.publish_dataset(data, name=name) - assert name in s.extensions['publish'].datasets - assert isinstance(s.extensions['publish'].datasets[name]['data'], Serialized) + assert name in s.extensions["publish"].datasets + assert isinstance( + s.extensions["publish"].datasets[name]["data"], Serialized + ) datasets = yield c.scheduler.publish_list() assert name in datasets @@ -65,15 +66,15 @@ def test_publish_roundtrip(s, a, b): data = yield c.scatter([0, 1, 2]) yield c.publish_dataset(data=data) - assert 'published-data' in s.who_wants[data[0].key] - result = yield f.get_dataset(name='data') + assert "published-data" in s.who_wants[data[0].key] + result = yield f.get_dataset(name="data") assert len(result) == len(data) out = yield f.gather(result) assert out == [0, 1, 2] with pytest.raises(KeyError) as exc_info: - result = yield f.get_dataset(name='nonexistent') + result = yield f.get_dataset(name="nonexistent") assert "not found" in str(exc_info.value) assert "nonexistent" in str(exc_info.value) @@ -90,9 +91,9 @@ def test_unpublish(c, s, a, b): key = data[0].key del 
data - yield c.scheduler.publish_delete(name='data') + yield c.scheduler.publish_delete(name="data") - assert 'data' not in s.extensions['publish'].datasets + assert "data" not in s.extensions["publish"].datasets start = time() while key in s.who_wants: @@ -100,7 +101,7 @@ def test_unpublish(c, s, a, b): assert time() < start + 5 with pytest.raises(KeyError) as exc_info: - result = yield c.get_dataset(name='data') + result = yield c.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @@ -109,10 +110,10 @@ def test_unpublish(c, s, a, b): def test_unpublish_sync(client): data = client.scatter([0, 1, 2]) client.publish_dataset(data=data) - client.unpublish_dataset(name='data') + client.unpublish_dataset(name="data") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name='data') + result = client.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @@ -125,28 +126,28 @@ def test_publish_multiple_datasets(c, s, a, b): yield c.publish_dataset(x=x, y=y) datasets = yield c.scheduler.publish_list() - assert set(datasets) == {'x', 'y'} + assert set(datasets) == {"x", "y"} def test_unpublish_multiple_datasets_sync(client): x = delayed(inc)(1) y = delayed(inc)(2) client.publish_dataset(x=x, y=y) - client.unpublish_dataset(name='x') + client.unpublish_dataset(name="x") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name='x') + result = client.get_dataset(name="x") datasets = client.list_datasets() - assert set(datasets) == {'y'} + assert set(datasets) == {"y"} assert "not found" in str(exc_info.value) assert "x" in str(exc_info.value) - client.unpublish_dataset(name='y') + client.unpublish_dataset(name="y") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name='y') + result = client.get_dataset(name="y") assert "not found" in str(exc_info.value) assert "y" in str(exc_info.value) @@ -154,7 +155,7 @@ def test_unpublish_multiple_datasets_sync(client): @gen_cluster(client=False) def test_publish_bag(s, a, b): - db = pytest.importorskip('dask.bag') + db = pytest.importorskip("dask.bag") c = yield Client((s.ip, s.port), asynchronous=True) f = yield Client((s.ip, s.port), asynchronous=True) @@ -170,7 +171,7 @@ def test_publish_bag(s, a, b): # check that serialization didn't affect original bag's dask assert len(futures_of(bagp)) == 3 - result = yield f.get_dataset('data') + result = yield f.get_dataset("data") assert set(result.dask.keys()) == set(bagp.dask.keys()) assert {f.key for f in result.dask.values()} == {f.key for f in bagp.dask.values()} @@ -181,22 +182,22 @@ def test_publish_bag(s, a, b): def test_datasets_setitem(client): - for key in ['key', ('key', 'key'), 1]: - value = 'value' + for key in ["key", ("key", "key"), 1]: + value = "value" client.datasets[key] = value assert client.get_dataset(key) == value def test_datasets_getitem(client): - for key in ['key', ('key', 'key'), 1]: - value = 'value' + for key in ["key", ("key", "key"), 1]: + value = "value" client.publish_dataset(value, name=key) assert client.datasets[key] == value def test_datasets_delitem(client): - for key in ['key', ('key', 'key'), 1]: - value = 'value' + for key in ["key", ("key", "key"), 1]: + value = "value" client.publish_dataset(value, name=key) del client.datasets[key] assert key not in client.list_datasets() @@ -209,7 +210,7 @@ def test_datasets_keys(client): def test_datasets_contains(client): - key, value = 'key', 'value' + key, value = "key", 
"value" client.publish_dataset(key=value) assert key in client.datasets @@ -223,11 +224,10 @@ def test_datasets_iter(client): @gen_cluster(client=True) def test_pickle_safe(c, s, a, b): - c2 = yield Client(s.address, asynchronous=True, - serializers=['msgpack']) + c2 = yield Client(s.address, asynchronous=True, serializers=["msgpack"]) try: yield c2.publish_dataset(x=[1, 2, 3]) - result = yield c2.get_dataset('x') + result = yield c2.get_dataset("x") assert result == (1, 2, 3) with pytest.raises(TypeError): @@ -236,6 +236,6 @@ def test_pickle_safe(c, s, a, b): yield c.publish_dataset(z=lambda x: x) # this can use pickle with pytest.raises(TypeError): - yield c2.get_dataset('z') + yield c2.get_dataset("z") finally: yield c2.close() diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index ff299400e6e..c44637cf9fd 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -18,6 +18,7 @@ def test_speed(c, s, a, b): Interestingly this runs 10x slower on Python 2 """ + def pingpong(a, b, start=False, n=1000, msg=1): sub = Sub(a) pub = Pub(b) @@ -36,10 +37,11 @@ def pingpong(a, b, start=False, n=1000, msg=1): return n import numpy as np + x = np.random.random(1000) - x = c.submit(pingpong, 'a', 'b', start=True, msg=x, n=100) - y = c.submit(pingpong, 'b', 'a', n=100) + x = c.submit(pingpong, "a", "b", start=True, msg=x, n=100) + y = c.submit(pingpong, "b", "a", n=100) start = time() yield c.gather([x, y]) @@ -51,14 +53,14 @@ def pingpong(a, b, start=False, n=1000, msg=1): def test_client(c, s): with pytest.raises(Exception): get_worker() - sub = Sub('a') - pub = Pub('a') + sub = Sub("a") + pub = Pub("a") - sps = s.extensions['pubsub'] - cps = c.extensions['pubsub'] + sps = s.extensions["pubsub"] + cps = c.extensions["pubsub"] start = time() - while not set(sps.client_subscribers['a']) == {c.id}: + while not set(sps.client_subscribers["a"]) == {c.id}: yield gen.sleep(0.01) assert time() < start + 3 @@ -70,10 +72,10 @@ def test_client(c, s): @gen_cluster(client=True) def test_client_worker(c, s, a, b): - sub = Sub('a', client=c, worker=None) + sub = Sub("a", client=c, worker=None) def f(x): - pub = Pub('a') + pub = Pub("a") pub.put(x) futures = c.map(f, range(10)) @@ -86,32 +88,36 @@ def f(x): assert set(L) == set(range(10)) - sps = s.extensions['pubsub'] - aps = a.extensions['pubsub'] - bps = b.extensions['pubsub'] + sps = s.extensions["pubsub"] + aps = a.extensions["pubsub"] + bps = b.extensions["pubsub"] start = time() - while (sps.publishers['a'] or - sps.subscribers['a'] or - aps.publishers['a'] or - bps.publishers['a'] or - len(sps.client_subscribers['a']) != 1): + while ( + sps.publishers["a"] + or sps.subscribers["a"] + or aps.publishers["a"] + or bps.publishers["a"] + or len(sps.client_subscribers["a"]) != 1 + ): yield gen.sleep(0.01) assert time() < start + 3 del sub start = time() - while (sps.client_subscribers or - any(aps.publish_to_scheduler.values()) or - any(bps.publish_to_scheduler.values())): + while ( + sps.client_subscribers + or any(aps.publish_to_scheduler.values()) + or any(bps.publish_to_scheduler.values()) + ): yield gen.sleep(0.01) assert time() < start + 3 @gen_cluster(client=True) def test_timeouts(c, s, a, b): - sub = Sub('a', client=c, worker=None) + sub = Sub("a", client=c, worker=None) start = time() with pytest.raises(TimeoutError): yield sub.get(timeout=0.1) diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 6d6306ca6f9..e82b893989b 100644 --- 
a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -9,15 +9,15 @@ from distributed import Client, Queue, Nanny, worker_client, wait from distributed.metrics import time -from distributed.utils_test import (gen_cluster, inc, slow, div) -from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 +from distributed.utils_test import gen_cluster, inc, slow, div +from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @gen_cluster(client=True) def test_queue(c, s, a, b): - x = yield Queue('x') - y = yield Queue('y') - xx = yield Queue('x') + x = yield Queue("x") + y = yield Queue("y") + xx = yield Queue("x") assert x.client is c future = c.submit(inc, 1) @@ -44,14 +44,14 @@ def test_queue(c, s, a, b): @gen_cluster(client=True) def test_queue_with_data(c, s, a, b): - x = yield Queue('x') - xx = yield Queue('x') + x = yield Queue("x") + xx = yield Queue("x") assert x.client is c - yield x.put((1, 'hello')) + yield x.put((1, "hello")) data = yield xx.get() - assert data == (1, 'hello') + assert data == (1, "hello") with pytest.raises(gen.TimeoutError): yield x.get(timeout=0.1) @@ -59,8 +59,8 @@ def test_queue_with_data(c, s, a, b): def test_sync(client): future = client.submit(lambda x: x + 1, 10) - x = Queue('x') - xx = Queue('x') + x = Queue("x") + xx = Queue("x") x.put(future) assert x.qsize() == 1 assert xx.qsize() == 1 @@ -73,7 +73,7 @@ def test_sync(client): def test_hold_futures(s, a, b): c1 = yield Client(s.address, asynchronous=True) future = c1.submit(lambda x: x + 1, 10) - q1 = yield Queue('q') + q1 = yield Queue("q") yield q1.put(future) del q1 yield c1.close() @@ -81,7 +81,7 @@ def test_hold_futures(s, a, b): yield gen.sleep(0.1) c2 = yield Client(s.address, asynchronous=True) - q2 = yield Queue('q') + q2 = yield Queue("q") future2 = yield q2.get() result = yield future2 @@ -89,7 +89,7 @@ def test_hold_futures(s, a, b): yield c2.close() -@pytest.mark.skip(reason='getting same client from main thread') +@pytest.mark.skip(reason="getting same client from main thread") @gen_cluster(client=True) def test_picklability(c, s, a, b): q = Queue() @@ -113,14 +113,13 @@ def f(x): assert q.get() == 11 -@pytest.mark.skipif(sys.version_info[0] == 2, reason='Multi-client issues') +@pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") @slow -@gen_cluster(client=True, ncores=[('127.0.0.1', 2)] * 5, Worker=Nanny, - timeout=None) +@gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): def f(i): with worker_client() as c: - q = Queue('x', client=c) + q = Queue("x", client=c) for _ in range(100): future = q.get() x = future.result() @@ -130,7 +129,7 @@ def f(i): result = q.get().result() return result - q = Queue('x', client=c) + q = Queue("x", client=c) L = yield c.scatter(range(5)) for future in L: yield q.put(future) @@ -145,32 +144,32 @@ def f(i): @gen_cluster(client=True) def test_same_futures(c, s, a, b): - q = Queue('x') + q = Queue("x") future = yield c.scatter(123) for i in range(5): yield q.put(future) - assert s.wants_what['queue-x'] == {future.key} + assert s.wants_what["queue-x"] == {future.key} for i in range(4): future2 = yield q.get() - assert s.wants_what['queue-x'] == {future.key} + assert s.wants_what["queue-x"] == {future.key} yield gen.sleep(0.05) - assert s.wants_what['queue-x'] == {future.key} + assert s.wants_what["queue-x"] == {future.key} yield q.get() start = time() - while s.wants_what['queue-x']: + while s.wants_what["queue-x"]: 
yield gen.sleep(0.01) assert time() - start < 2 @gen_cluster(client=True) def test_get_many(c, s, a, b): - x = yield Queue('x') - xx = yield Queue('x') + x = yield Queue("x") + xx = yield Queue("x") yield x.put(1) yield x.put(2) @@ -187,27 +186,26 @@ def test_get_many(c, s, a, b): assert data == [1, 2] with pytest.raises(gen.TimeoutError): - data = yield gen.with_timeout(timedelta(seconds=0.100), - xx.get(batch=2)) + data = yield gen.with_timeout(timedelta(seconds=0.100), xx.get(batch=2)) @gen_cluster(client=True) def test_Future_knows_status_immediately(c, s, a, b): x = yield c.scatter(123) - q = yield Queue('q') + q = yield Queue("q") yield q.put(x) c2 = yield Client(s.address, asynchronous=True) - q2 = yield Queue('q', client=c2) + q2 = yield Queue("q", client=c2) future = yield q2.get() - assert future.status == 'finished' + assert future.status == "finished" x = c.submit(div, 1, 0) yield wait(x) yield q.put(x) future2 = yield q2.get() - assert future2.status == 'error' + assert future2.status == "error" with pytest.raises(Exception): yield future2 @@ -242,19 +240,19 @@ def test_erred_future(c, s, a, b): def test_close(c, s, a, b): q = Queue() - while q.name not in s.extensions['queues'].queues: + while q.name not in s.extensions["queues"].queues: yield gen.sleep(0.01) q.close() q.close() - while q.name in s.extensions['queues'].queues: + while q.name in s.extensions["queues"].queues: yield gen.sleep(0.01) @gen_cluster(client=True) def test_timeout(c, s, a, b): - q = Queue('v', maxsize=1) + q = Queue("v", maxsize=1) start = time() with pytest.raises(gen.TimeoutError): diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 4b1e9e2a80c..35f5e160969 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -10,7 +10,7 @@ from distributed.client import wait from distributed.compatibility import WINDOWS from distributed.utils import tokey -from distributed.utils_test import (inc, gen_cluster, slowinc, slowadd) +from distributed.utils_test import inc, gen_cluster, slowinc, slowadd from distributed.utils_test import client, cluster_fixture, loop, s, a, b # noqa: F401 @@ -19,30 +19,33 @@ def test_resources(c, s): assert not s.worker_resources assert not s.resources - a = Worker(s.ip, s.port, loop=s.loop, resources={'GPU': 2}) - b = Worker(s.ip, s.port, loop=s.loop, resources={'GPU': 1, 'DB': 1}) + a = Worker(s.ip, s.port, loop=s.loop, resources={"GPU": 2}) + b = Worker(s.ip, s.port, loop=s.loop, resources={"GPU": 1, "DB": 1}) yield [a, b] - assert s.resources == {'GPU': {a.address: 2, b.address: 1}, - 'DB': {b.address: 1}} - assert s.worker_resources == {a.address: {'GPU': 2}, - b.address: {'GPU': 1, 'DB': 1}} + assert s.resources == {"GPU": {a.address: 2, b.address: 1}, "DB": {b.address: 1}} + assert s.worker_resources == {a.address: {"GPU": 2}, b.address: {"GPU": 1, "DB": 1}} yield b._close() - assert s.resources == {'GPU': {a.address: 2}, 'DB': {}} - assert s.worker_resources == {a.address: {'GPU': 2}} + assert s.resources == {"GPU": {a.address: 2}, "DB": {}} + assert s.worker_resources == {a.address: {"GPU": 2}} yield a._close() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 5}}), - ('127.0.0.1', 1, {'resources': {'A': 1, 'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 5}}), + ("127.0.0.1", 1, {"resources": {"A": 1, "B": 1}}), + ], +) def test_resource_submit(c, s, a, b): - x = c.submit(inc, 1, resources={'A': 3}) - y = c.submit(inc, 2, 
resources={'B': 1}) - z = c.submit(inc, 3, resources={'C': 2}) + x = c.submit(inc, 1, resources={"A": 3}) + y = c.submit(inc, 2, resources={"B": 1}) + z = c.submit(inc, 3, resources={"C": 2}) yield wait(x) assert x.key in a.data @@ -50,9 +53,9 @@ def test_resource_submit(c, s, a, b): yield wait(y) assert y.key in b.data - assert s.get_task_status(keys=[z.key]) == {z.key: 'no-worker'} + assert s.get_task_status(keys=[z.key]) == {z.key: "no-worker"} - d = yield Worker(s.ip, s.port, loop=s.loop, resources={'C': 10}) + d = yield Worker(s.ip, s.port, loop=s.loop, resources={"C": 10}) yield wait(z) assert z.key in d.data @@ -60,55 +63,81 @@ def test_resource_submit(c, s, a, b): yield d._close() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_submit_many_non_overlapping(c, s, a, b): - futures = [c.submit(inc, i, resources={'A': 1}) for i in range(5)] + futures = [c.submit(inc, i, resources={"A": 1}) for i in range(5)] yield wait(futures) assert len(a.data) == 5 assert len(b.data) == 0 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_move(c, s, a, b): [x] = yield c._scatter([1], workers=b.address) - future = c.submit(inc, x, resources={'A': 1}) + future = c.submit(inc, x, resources={"A": 1}) yield wait(future) assert a.data[future.key] == 2 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_dont_work_steal(c, s, a, b): [x] = yield c._scatter([1], workers=a.address) - futures = [c.submit(slowadd, x, i, resources={'A': 1}, delay=0.05) - for i in range(10)] + futures = [ + c.submit(slowadd, x, i, resources={"A": 1}, delay=0.05) for i in range(10) + ] yield wait(futures) assert all(f.key in a.data for f in futures) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_map(c, s, a, b): - futures = c.map(inc, range(10), resources={'B': 1}) + futures = c.map(inc, range(10), resources={"B": 1}) yield wait(futures) assert set(b.data) == {f.key for f in futures} assert not a.data -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_persist(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) - xx, yy = c.persist([x, y], resources={x: {'A': 1}, y: {'B': 1}}) + xx, yy = c.persist([x, y], resources={x: {"A": 1}, y: {"B": 1}}) yield wait([xx, yy]) @@ -116,40 +145,55 @@ def test_persist(c, s, a, b): assert y.key in b.data -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 11}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), 
+ ("127.0.0.1", 1, {"resources": {"B": 11}}), + ], +) def test_compute(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) - yy = c.compute(y, resources={x: {'A': 1}, y: {'B': 1}}) + yy = c.compute(y, resources={x: {"A": 1}, y: {"B": 1}}) yield wait(yy) assert b.data xs = [delayed(inc)(i) for i in range(10, 20)] - xxs = c.compute(xs, resources={'B': 1}) + xxs = c.compute(xs, resources={"B": 1}) yield wait(xxs) assert len(b.data) > 10 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_get(c, s, a, b): - dsk = {'x': (inc, 1), 'y': (inc, 'x')} + dsk = {"x": (inc, 1), "y": (inc, "x")} - result = yield c.get(dsk, 'y', resources={'y': {'A': 1}}, sync=False) + result = yield c.get(dsk, "y", resources={"y": {"A": 1}}, sync=False) assert result == 3 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_persist_tuple(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) - xx, yy = c.persist([x, y], resources={(x, y): {'A': 1}}) + xx, yy = c.persist([x, y], resources={(x, y): {"A": 1}}) yield wait([xx, yy]) @@ -158,10 +202,15 @@ def test_persist_tuple(c, s, a, b): assert not b.data -@gen_cluster(client=True, ncores=[('127.0.0.1', 4, {'resources': {'A': 2}}), - ('127.0.0.1', 4, {'resources': {'A': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 4, {"resources": {"A": 2}}), + ("127.0.0.1", 4, {"resources": {"A": 1}}), + ], +) def test_submit_many_non_overlapping(c, s, a, b): - futures = c.map(slowinc, range(100), resources={'A': 1}, delay=0.02) + futures = c.map(slowinc, range(100), resources={"A": 1}, delay=0.02) while len(a.data) + len(b.data) < 100: yield gen.sleep(0.01) @@ -173,9 +222,9 @@ def test_submit_many_non_overlapping(c, s, a, b): assert b.total_resources == b.available_resources -@gen_cluster(client=True, ncores=[('127.0.0.1', 4, {'resources': {'A': 2, 'B': 1}})]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 4, {"resources": {"A": 2, "B": 1}})]) def test_minimum_resource(c, s, a): - futures = c.map(slowinc, range(30), resources={'A': 1, 'B': 1}, delay=0.02) + futures = c.map(slowinc, range(30), resources={"A": 1, "B": 1}, delay=0.02) while len(a.data) < 30: yield gen.sleep(0.01) @@ -185,10 +234,10 @@ def test_minimum_resource(c, s, a): assert a.total_resources == a.available_resources -@gen_cluster(client=True, ncores=[('127.0.0.1', 2, {'resources': {'A': 1}})]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 2, {"resources": {"A": 1}})]) def test_prefer_constrained(c, s, a): futures = c.map(slowinc, range(1000), delay=0.1) - constrained = c.map(inc, range(10), resources={'A': 1}) + constrained = c.map(inc, range(10), resources={"A": 1}) start = time() yield wait(constrained) @@ -201,44 +250,54 @@ def test_prefer_constrained(c, s, a): @pytest.mark.skip(reason="") -@gen_cluster(client=True, ncores=[('127.0.0.1', 2, {'resources': {'A': 1}}), - ('127.0.0.1', 2, {'resources': {'A': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 2, {"resources": {"A": 1}}), + ("127.0.0.1", 2, {"resources": {"A": 1}}), + ], +) def test_balance_resources(c, s, a, b): futures = c.map(slowinc, range(100), delay=0.1, workers=a.address) - 
constrained = c.map(inc, range(2), resources={'A': 1}) + constrained = c.map(inc, range(2), resources={"A": 1}) yield wait(constrained) assert any(f.key in a.data for f in constrained) # share assert any(f.key in b.data for f in constrained) -@gen_cluster(client=True, ncores=[('127.0.0.1', 2)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 2)]) def test_set_resources(c, s, a): yield a.set_resources(A=2) - assert a.total_resources['A'] == 2 - assert a.available_resources['A'] == 2 - assert s.worker_resources[a.address] == {'A': 2} + assert a.total_resources["A"] == 2 + assert a.available_resources["A"] == 2 + assert s.worker_resources[a.address] == {"A": 2} - future = c.submit(slowinc, 1, delay=1, resources={'A': 1}) - while a.available_resources['A'] == 2: + future = c.submit(slowinc, 1, delay=1, resources={"A": 1}) + while a.available_resources["A"] == 2: yield gen.sleep(0.01) yield a.set_resources(A=3) - assert a.total_resources['A'] == 3 - assert a.available_resources['A'] == 2 - assert s.worker_resources[a.address] == {'A': 3} - - -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) + assert a.total_resources["A"] == 3 + assert a.available_resources["A"] == 2 + assert s.worker_resources[a.address] == {"A": 3} + + +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_persist_collections(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.arange(10, chunks=(5,)) y = x.map_blocks(lambda x: x + 1) z = y.map_blocks(lambda x: 2 * x) w = z.sum() - ww, yy = c.persist([w, y], resources={tuple(y.__dask_keys__()): {'A': 1}}) + ww, yy = c.persist([w, y], resources={tuple(y.__dask_keys__()): {"A": 1}}) yield wait([ww, yy]) @@ -246,57 +305,75 @@ def test_persist_collections(c, s, a, b): @pytest.mark.skip(reason="Should protect resource keys from optimization") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_dont_optimize_out(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.arange(10, chunks=(5,)) y = x.map_blocks(lambda x: x + 1) z = y.map_blocks(lambda x: 2 * x) w = z.sum() - yield c.compute(w, resources={tuple(y.__dask_keys__()): {'A': 1}},) + yield c.compute(w, resources={tuple(y.__dask_keys__()): {"A": 1}}) for key in map(tokey, y.__dask_keys__()): - assert 'executing' in str(a.story(key)) + assert "executing" in str(a.story(key)) @pytest.mark.xfail(reason="atop fusion seemed to break this") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 1}}), - ('127.0.0.1', 1, {'resources': {'B': 1}})]) +@gen_cluster( + client=True, + ncores=[ + ("127.0.0.1", 1, {"resources": {"A": 1}}), + ("127.0.0.1", 1, {"resources": {"B": 1}}), + ], +) def test_full_collections(c, s, a, b): - dd = pytest.importorskip('dask.dataframe') - df = dd.demo.make_timeseries(freq='60s', partition_freq='1d', - start='2000-01-01', end='2000-01-31') + dd = pytest.importorskip("dask.dataframe") + df = dd.demo.make_timeseries( + freq="60s", partition_freq="1d", start="2000-01-01", end="2000-01-31" + ) z = df.x + df.y # some extra nodes in the graph - yield c.compute(z, resources={tuple(z.dask): {'A': 1}}) + yield c.compute(z, 
resources={tuple(z.dask): {"A": 1}}) assert a.log assert not b.log -@pytest.mark.parametrize('optimize_graph', [ - pytest.param(True, - marks=pytest.mark.xfail(reason="don't track resources through optimization")), - pytest.param(False, - marks=pytest.mark.skipif(WINDOWS, reason="intermittent failure")) -]) +@pytest.mark.parametrize( + "optimize_graph", + [ + pytest.param( + True, + marks=pytest.mark.xfail( + reason="don't track resources through optimization" + ), + ), + pytest.param( + False, marks=pytest.mark.skipif(WINDOWS, reason="intermittent failure") + ), + ], +) def test_collections_get(client, optimize_graph, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") def f(dask_worker): - dask_worker.set_resources(**{'A': 1}) + dask_worker.set_resources(**{"A": 1}) - client.run(f, workers=[a['address']]) + client.run(f, workers=[a["address"]]) x = da.random.random(100, chunks=(10,)) + 1 - x.compute(resources={tuple(x.dask): {'A': 1}}, - optimize_graph=optimize_graph) + x.compute(resources={tuple(x.dask): {"A": 1}}, optimize_graph=optimize_graph) def g(dask_worker): return len(dask_worker.log) logs = client.run(g) - assert logs[a['address']] - assert not logs[b['address']] + assert logs[a["address"]] + assert not logs[b["address"]] diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index edf4a4eaece..02f15e1e1a2 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -23,15 +23,25 @@ from distributed.protocol.pickle import dumps from distributed.worker import dumps_function, dumps_task from distributed.utils import tmpfile -from distributed.utils_test import (inc, dec, gen_cluster, gen_test, - slowinc, slowadd, slowdec, cluster, div, - varying, slow) +from distributed.utils_test import ( + inc, + dec, + gen_cluster, + gen_test, + slowinc, + slowadd, + slowdec, + cluster, + div, + varying, + slow, +) from distributed.utils_test import loop, nodebug # noqa: F401 from dask.compatibility import apply -alice = 'alice:1234' -bob = 'bob:1234' +alice = "alice:1234" +bob = "bob:1234" occupancy = defaultdict(lambda: 0) @@ -44,7 +54,7 @@ def test_administration(s, a, b): assert str(len(s.ncores)) in repr(s) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_respect_data_in_memory(c, s, a): x = delayed(inc)(1) y = delayed(inc)(x) @@ -79,28 +89,31 @@ def test_recompute_released_results(c, s, a, b): @gen_cluster(client=True) def test_decide_worker_with_many_independent_leaves(c, s, a, b): - xs = yield [c.scatter(list(range(0, 100, 2)), workers=a.address), - c.scatter(list(range(1, 100, 2)), workers=b.address)] + xs = yield [ + c.scatter(list(range(0, 100, 2)), workers=a.address), + c.scatter(list(range(1, 100, 2)), workers=b.address), + ] xs = list(concat(zip(*xs))) ys = [delayed(inc)(x) for x in xs] y2s = c.persist(ys) yield wait(y2s) - nhits = (sum(y.key in a.data for y in y2s[::2]) + - sum(y.key in b.data for y in y2s[1::2])) + nhits = sum(y.key in a.data for y in y2s[::2]) + sum( + y.key in b.data for y in y2s[1::2] + ) assert nhits > 80 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_decide_worker_with_restrictions(client, s, a, b, c): x = client.submit(inc, 1, workers=[a.address, b.address]) yield wait(x) assert x.key in a.data or x.key in b.data -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, 
ncores=[("127.0.0.1", 1)] * 3) def test_move_data_over_break_restrictions(client, s, a, b, c): [x] = yield client.scatter([1], workers=b.address) y = client.submit(inc, x, workers=[a.address, b.address]) @@ -108,19 +121,21 @@ def test_move_data_over_break_restrictions(client, s, a, b, c): assert y.key in a.data or y.key in b.data -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_balance_with_restrictions(client, s, a, b, c): - [x], [y] = yield [client.scatter([[1, 2, 3]], workers=a.address), - client.scatter([1], workers=c.address)] + [x], [y] = yield [ + client.scatter([[1, 2, 3]], workers=a.address), + client.scatter([1], workers=c.address), + ] z = client.submit(inc, 1, workers=[a.address, c.address]) yield wait(z) assert s.tasks[z.key].who_has == {s.workers[c.address]} -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_no_valid_workers(client, s, a, b, c): - x = client.submit(inc, 1, workers='127.0.0.5:9999') + x = client.submit(inc, 1, workers="127.0.0.5:9999") while not s.tasks: yield gen.sleep(0.01) @@ -130,10 +145,9 @@ def test_no_valid_workers(client, s, a, b, c): yield gen.with_timeout(timedelta(milliseconds=50), x) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_no_valid_workers_loose_restrictions(client, s, a, b, c): - x = client.submit(inc, 1, workers='127.0.0.5:9999', - allow_other_workers=True) + x = client.submit(inc, 1, workers="127.0.0.5:9999", allow_other_workers=True) result = yield x assert result == 2 @@ -158,16 +172,17 @@ def test_retire_workers_empty(s): @gen_cluster() def test_remove_client(s, a, b): - s.update_graph(tasks={'x': dumps_task((inc, 1)), - 'y': dumps_task((inc, 'x'))}, - dependencies={'x': [], 'y': ['x']}, - keys=['y'], - client='ident') + s.update_graph( + tasks={"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, + dependencies={"x": [], "y": ["x"]}, + keys=["y"], + client="ident", + ) assert s.tasks assert s.dependencies - s.remove_client(client='ident') + s.remove_client(client="ident") assert not s.tasks assert not s.dependencies @@ -177,15 +192,18 @@ def test_remove_client(s, a, b): def test_server_listens_to_other_ops(s, a, b): with rpc(s.address) as r: ident = yield r.identity() - assert ident['type'] == 'Scheduler' - assert ident['id'].lower().startswith('scheduler') + assert ident["type"] == "Scheduler" + assert ident["id"].lower().startswith("scheduler") @gen_cluster() def test_remove_worker_from_scheduler(s, a, b): - dsk = {('x-%d' % i): (inc, i) for i in range(20)} - s.update_graph(tasks=valmap(dumps_task, dsk), keys=list(dsk), - dependencies={k: set() for k in dsk}) + dsk = {("x-%d" % i): (inc, i) for i in range(20)} + s.update_graph( + tasks=valmap(dumps_task, dsk), + keys=list(dsk), + dependencies={k: set() for k in dsk}, + ) assert a.address in s.stream_comms s.remove_worker(address=a.address) @@ -214,7 +232,9 @@ def test_clear_events_worker_removal(s, a, b): assert b.address in s.events -@gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "10 ms"}, client=True) +@gen_cluster( + config={"distributed.scheduler.events-cleanup-delay": "10 ms"}, client=True +) def test_clear_events_client_removal(c, s, a, b): assert c.id in s.events s.remove_client(c.id) @@ -234,47 +254,52 @@ def test_clear_events_client_removal(c, s, a, b): @gen_cluster() def test_add_worker(s, a, b): w = Worker(s.ip, s.port, 
ncores=3) - w.data['x-5'] = 6 - w.data['y'] = 1 + w.data["x-5"] = 6 + w.data["y"] = 1 yield w - dsk = {('x-%d' % i): (inc, i) for i in range(10)} - s.update_graph(tasks=valmap(dumps_task, dsk), keys=list(dsk), client='client', - dependencies={k: set() for k in dsk}) + dsk = {("x-%d" % i): (inc, i) for i in range(10)} + s.update_graph( + tasks=valmap(dumps_task, dsk), + keys=list(dsk), + client="client", + dependencies={k: set() for k in dsk}, + ) - s.add_worker(address=w.address, keys=list(w.data), - ncores=w.ncores, services=s.services) + s.add_worker( + address=w.address, keys=list(w.data), ncores=w.ncores, services=s.services + ) s.validate_state() assert w.ip in s.host_info - assert s.host_info[w.ip]['addresses'] == {a.address, b.address, w.address} + assert s.host_info[w.ip]["addresses"] == {a.address, b.address, w.address} yield w._close() -@gen_cluster(scheduler_kwargs={'blocked_handlers': ['feed']}) +@gen_cluster(scheduler_kwargs={"blocked_handlers": ["feed"]}) def test_blocked_handlers_are_respected(s, a, b): def func(scheduler): return dumps(dict(scheduler.worker_info)) comm = yield connect(s.address) - yield comm.write({'op': 'feed', - 'function': dumps(func), - 'interval': 0.01}) + yield comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) response = yield comm.read() - assert 'exception' in response - assert isinstance(response['exception'], ValueError) - assert "'feed' handler has been explicitly disallowed" in repr(response['exception']) + assert "exception" in response + assert isinstance(response["exception"], ValueError) + assert "'feed' handler has been explicitly disallowed" in repr( + response["exception"] + ) yield comm.close() def test_scheduler_init_pulls_blocked_handlers_from_config(): - with dask.config.set({'distributed.scheduler.blocked-handlers': ['test-handler']}): + with dask.config.set({"distributed.scheduler.blocked-handlers": ["test-handler"]}): s = Scheduler() - assert s.blocked_handlers == ['test-handler'] + assert s.blocked_handlers == ["test-handler"] @gen_cluster() @@ -283,9 +308,7 @@ def func(scheduler): return dumps(dict(scheduler.worker_info)) comm = yield connect(s.address) - yield comm.write({'op': 'feed', - 'function': dumps(func), - 'interval': 0.01}) + yield comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) for i in range(5): response = yield comm.read() @@ -302,32 +325,36 @@ def setup(scheduler): def func(scheduler, state): assert state == 1 - return 'OK' + return "OK" def teardown(scheduler, state): - scheduler.flag = 'done' + scheduler.flag = "done" comm = yield connect(s.address) - yield comm.write({'op': 'feed', - 'function': dumps(func), - 'setup': dumps(setup), - 'teardown': dumps(teardown), - 'interval': 0.01}) + yield comm.write( + { + "op": "feed", + "function": dumps(func), + "setup": dumps(setup), + "teardown": dumps(teardown), + "interval": 0.01, + } + ) for i in range(5): response = yield comm.read() - assert response == 'OK' + assert response == "OK" yield comm.close() start = time() - while not hasattr(s, 'flag'): + while not hasattr(s, "flag"): yield gen.sleep(0.01) assert time() - start < 5 @gen_cluster() def test_feed_large_bytestring(s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = np.ones(10000000) @@ -336,9 +363,7 @@ def func(scheduler): return True comm = yield connect(s.address) - yield comm.write({'op': 'feed', - 'function': dumps(func), - 'interval': 0.05}) + yield comm.write({"op": "feed", "function": dumps(func), "interval": 0.05}) for i in 
range(5): response = yield comm.read() @@ -349,22 +374,22 @@ def func(scheduler): @gen_cluster(client=True) def test_delete_data(c, s, a, b): - d = yield c.scatter({'x': 1, 'y': 2, 'z': 3}) + d = yield c.scatter({"x": 1, "y": 2, "z": 3}) - assert {ts.key for ts in s.tasks.values() if ts.who_has} == {'x', 'y', 'z'} - assert set(a.data) | set(b.data) == {'x', 'y', 'z'} - assert merge(a.data, b.data) == {'x': 1, 'y': 2, 'z': 3} + assert {ts.key for ts in s.tasks.values() if ts.who_has} == {"x", "y", "z"} + assert set(a.data) | set(b.data) == {"x", "y", "z"} + assert merge(a.data, b.data) == {"x": 1, "y": 2, "z": 3} - del d['x'] - del d['y'] + del d["x"] + del d["y"] start = time() - while set(a.data) | set(b.data) != {'z'}: + while set(a.data) | set(b.data) != {"z"}: yield gen.sleep(0.01) assert time() < start + 5 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_delete(c, s, a): x = c.submit(inc, 1) yield x @@ -382,33 +407,39 @@ def test_delete(c, s, a): def test_filtered_communication(s, a, b): c = yield connect(s.address) f = yield connect(s.address) - yield c.write({'op': 'register-client', 'client': 'c'}) - yield f.write({'op': 'register-client', 'client': 'f'}) + yield c.write({"op": "register-client", "client": "c"}) + yield f.write({"op": "register-client", "client": "f"}) yield c.read() yield f.read() - assert set(s.client_comms) == {'c', 'f'} - - yield c.write({'op': 'update-graph', - 'tasks': {'x': dumps_task((inc, 1)), - 'y': dumps_task((inc, 'x'))}, - 'dependencies': {'x': [], 'y': ['x']}, - 'client': 'c', - 'keys': ['y']}) - - yield f.write({'op': 'update-graph', - 'tasks': {'x': dumps_task((inc, 1)), - 'z': dumps_task((add, 'x', 10))}, - 'dependencies': {'x': [], 'z': ['x']}, - 'client': 'f', - 'keys': ['z']}) + assert set(s.client_comms) == {"c", "f"} + + yield c.write( + { + "op": "update-graph", + "tasks": {"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, + "dependencies": {"x": [], "y": ["x"]}, + "client": "c", + "keys": ["y"], + } + ) + + yield f.write( + { + "op": "update-graph", + "tasks": {"x": dumps_task((inc, 1)), "z": dumps_task((add, "x", 10))}, + "dependencies": {"x": [], "z": ["x"]}, + "client": "f", + "keys": ["z"], + } + ) msg, = yield c.read() - assert msg['op'] == 'key-in-memory' - assert msg['key'] == 'y' + assert msg["op"] == "key-in-memory" + assert msg["key"] == "y" msg, = yield f.read() - assert msg['op'] == 'key-in-memory' - assert msg['key'] == 'z' + assert msg["op"] == "key-in-memory" + assert msg["key"] == "z" def test_dumps_function(): @@ -424,26 +455,28 @@ def test_dumps_function(): def test_dumps_task(): d = dumps_task((inc, 1)) - assert set(d) == {'function', 'args'} + assert set(d) == {"function", "args"} f = lambda x, y=2: x + y - d = dumps_task((apply, f, (1,), {'y': 10})) - assert cloudpickle.loads(d['function'])(1, 2) == 3 - assert cloudpickle.loads(d['args']) == (1,) - assert cloudpickle.loads(d['kwargs']) == {'y': 10} + d = dumps_task((apply, f, (1,), {"y": 10})) + assert cloudpickle.loads(d["function"])(1, 2) == 3 + assert cloudpickle.loads(d["args"]) == (1,) + assert cloudpickle.loads(d["kwargs"]) == {"y": 10} d = dumps_task((apply, f, (1,))) - assert cloudpickle.loads(d['function'])(1, 2) == 3 - assert cloudpickle.loads(d['args']) == (1,) - assert set(d) == {'function', 'args'} + assert cloudpickle.loads(d["function"])(1, 2) == 3 + assert cloudpickle.loads(d["args"]) == (1,) + assert set(d) == {"function", "args"} @gen_cluster() def test_ready_remove_worker(s, a, 
b): - s.update_graph(tasks={'x-%d' % i: dumps_task((inc, i)) for i in range(20)}, - keys=['x-%d' % i for i in range(20)], - client='client', - dependencies={'x-%d' % i: [] for i in range(20)}) + s.update_graph( + tasks={"x-%d" % i: dumps_task((inc, i)) for i in range(20)}, + keys=["x-%d" % i for i in range(20)], + client="client", + dependencies={"x-%d" % i: [] for i in range(20)}, + ) assert all(len(w.processing) > w.ncores for w in s.workers.values()) @@ -472,29 +505,28 @@ def test_restart(c, s, a, b): @gen_cluster() def test_broadcast(s, a, b): - result = yield s.broadcast(msg={'op': 'ping'}) - assert result == {a.address: b'pong', b.address: b'pong'} + result = yield s.broadcast(msg={"op": "ping"}) + assert result == {a.address: b"pong", b.address: b"pong"} - result = yield s.broadcast(msg={'op': 'ping'}, workers=[a.address]) - assert result == {a.address: b'pong'} + result = yield s.broadcast(msg={"op": "ping"}, workers=[a.address]) + assert result == {a.address: b"pong"} - result = yield s.broadcast(msg={'op': 'ping'}, hosts=[a.ip]) - assert result == {a.address: b'pong', b.address: b'pong'} + result = yield s.broadcast(msg={"op": "ping"}, hosts=[a.ip]) + assert result == {a.address: b"pong", b.address: b"pong"} @gen_cluster(Worker=Nanny) def test_broadcast_nanny(s, a, b): - result1 = yield s.broadcast(msg={'op': 'identity'}, nanny=True) - assert all(d['type'] == 'Nanny' for d in result1.values()) + result1 = yield s.broadcast(msg={"op": "identity"}, nanny=True) + assert all(d["type"] == "Nanny" for d in result1.values()) - result2 = yield s.broadcast(msg={'op': 'identity'}, - workers=[a.worker_address], - nanny=True) + result2 = yield s.broadcast( + msg={"op": "identity"}, workers=[a.worker_address], nanny=True + ) assert len(result2) == 1 - assert first(result2.values())['id'] == a.id + assert first(result2.values())["id"] == a.id - result3 = yield s.broadcast(msg={'op': 'identity'}, hosts=[a.ip], - nanny=True) + result3 = yield s.broadcast(msg={"op": "identity"}, hosts=[a.ip], nanny=True) assert result1 == result3 @@ -502,12 +534,12 @@ def test_broadcast_nanny(s, a, b): def test_worker_name(): s = Scheduler(validate=True) s.start(0) - w = yield Worker(s.ip, s.port, name='alice') - assert s.workers[w.address].name == 'alice' - assert s.aliases['alice'] == w.address + w = yield Worker(s.ip, s.port, name="alice") + assert s.workers[w.address].name == "alice" + assert s.aliases["alice"] == w.address with pytest.raises(ValueError): - w2 = yield Worker(s.ip, s.port, name='alice') + w2 = yield Worker(s.ip, s.port, name="alice") yield w2._close() yield s.close() @@ -516,44 +548,51 @@ def test_worker_name(): @gen_test() def test_coerce_address(): - with dask.config.set({'distributed.comm.timeouts.connect': '100ms'}): + with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): s = Scheduler(validate=True) s.start(0) print("scheduler:", s.address, s.listen_address) - a = Worker(s.ip, s.port, name='alice') + a = Worker(s.ip, s.port, name="alice") b = Worker(s.ip, s.port, name=123) - c = Worker('127.0.0.1', s.port, name='charlie') + c = Worker("127.0.0.1", s.port, name="charlie") yield [a, b, c] - assert s.coerce_address('127.0.0.1:8000') == 'tcp://127.0.0.1:8000' - assert s.coerce_address('[::1]:8000') == 'tcp://[::1]:8000' - assert s.coerce_address('tcp://127.0.0.1:8000') == 'tcp://127.0.0.1:8000' - assert s.coerce_address('tcp://[::1]:8000') == 'tcp://[::1]:8000' - assert s.coerce_address('localhost:8000') in ('tcp://127.0.0.1:8000', 'tcp://[::1]:8000') - assert 
s.coerce_address(u'localhost:8000') in ('tcp://127.0.0.1:8000', 'tcp://[::1]:8000') + assert s.coerce_address("127.0.0.1:8000") == "tcp://127.0.0.1:8000" + assert s.coerce_address("[::1]:8000") == "tcp://[::1]:8000" + assert s.coerce_address("tcp://127.0.0.1:8000") == "tcp://127.0.0.1:8000" + assert s.coerce_address("tcp://[::1]:8000") == "tcp://[::1]:8000" + assert s.coerce_address("localhost:8000") in ( + "tcp://127.0.0.1:8000", + "tcp://[::1]:8000", + ) + assert s.coerce_address(u"localhost:8000") in ( + "tcp://127.0.0.1:8000", + "tcp://[::1]:8000", + ) assert s.coerce_address(a.address) == a.address # Aliases - assert s.coerce_address('alice') == a.address + assert s.coerce_address("alice") == a.address assert s.coerce_address(123) == b.address - assert s.coerce_address('charlie') == c.address + assert s.coerce_address("charlie") == c.address - assert s.coerce_hostname('127.0.0.1') == '127.0.0.1' - assert s.coerce_hostname('alice') == a.ip + assert s.coerce_hostname("127.0.0.1") == "127.0.0.1" + assert s.coerce_hostname("alice") == a.ip assert s.coerce_hostname(123) == b.ip - assert s.coerce_hostname('charlie') == c.ip - assert s.coerce_hostname('jimmy') == 'jimmy' + assert s.coerce_hostname("charlie") == c.ip + assert s.coerce_hostname("jimmy") == "jimmy" - assert s.coerce_address('zzzt:8000', resolve=False) == 'tcp://zzzt:8000' + assert s.coerce_address("zzzt:8000", resolve=False) == "tcp://zzzt:8000" yield s.close() yield [w._close() for w in [a, b, c]] -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="file descriptors not really a thing") +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="file descriptors not really a thing" +) @gen_cluster(ncores=[]) def test_file_descriptors_dont_leak(s): - psutil = pytest.importorskip('psutil') + psutil = pytest.importorskip("psutil") proc = psutil.Process() before = proc.num_fds() @@ -570,14 +609,18 @@ def test_file_descriptors_dont_leak(s): @gen_cluster() def test_update_graph_culls(s, a, b): - s.update_graph(tasks={'x': dumps_task((inc, 1)), - 'y': dumps_task((inc, 'x')), - 'z': dumps_task((inc, 2))}, - keys=['y'], - dependencies={'y': 'x', 'x': [], 'z': []}, - client='client') - assert 'z' not in s.tasks - assert 'z' not in s.dependencies + s.update_graph( + tasks={ + "x": dumps_task((inc, 1)), + "y": dumps_task((inc, "x")), + "z": dumps_task((inc, 2)), + }, + keys=["y"], + dependencies={"y": "x", "x": [], "z": []}, + client="client", + ) + assert "z" not in s.tasks + assert "z" not in s.dependencies @gen_cluster(ncores=[]) @@ -613,7 +656,7 @@ def test_story(c, s, a, b): @gen_cluster(ncores=[], client=True) def test_scatter_no_workers(c, s): with pytest.raises(gen.TimeoutError): - yield s.scatter(data={'x': 1}, client='alice', timeout=0.1) + yield s.scatter(data={"x": 1}, client="alice", timeout=0.1) start = time() with pytest.raises(gen.TimeoutError): @@ -621,10 +664,9 @@ def test_scatter_no_workers(c, s): assert time() < start + 1.5 w = Worker(s.ip, s.port, ncores=3) - yield [c.scatter(data={'y': 2}, timeout=5), - w._start()] + yield [c.scatter(data={"y": 2}, timeout=5), w._start()] - assert w.data['y'] == 2 + assert w.data["y"] == 2 yield w._close() @@ -645,13 +687,12 @@ def test_retire_workers(c, s, a, b): workers = yield s.retire_workers() assert list(workers) == [a.address] - assert workers[a.address]['ncores'] == a.ncores + assert workers[a.address]["ncores"] == a.ncores assert list(s.ncores) == [b.address] assert s.workers_to_close() == [] - assert s.workers[b.address].has_what == {s.tasks[x.key], - 
s.tasks[y.key]} + assert s.workers[b.address].has_what == {s.tasks[x.key], s.tasks[y.key]} workers = yield s.retire_workers() assert not workers @@ -671,17 +712,17 @@ def test_retire_workers_n(c, s, a, b): yield s.retire_workers(n=0, close_workers=True) assert len(s.workers) == 0 - while not (a.status.startswith('clos') and b.status.startswith('clos')): + while not (a.status.startswith("clos") and b.status.startswith("clos")): yield gen.sleep(0.01) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_workers_to_close(cl, s, *workers): - s.task_duration['a'] = 4 - s.task_duration['b'] = 4 - s.task_duration['c'] = 1 + s.task_duration["a"] = 4 + s.task_duration["b"] = 4 + s.task_duration["c"] = 1 - futures = cl.map(slowinc, [1, 1, 1], key=['a-4','b-4','c-1']) + futures = cl.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) while sum(len(w.processing) for w in s.workers.values()) < 3: yield gen.sleep(0.001) @@ -690,24 +731,26 @@ def test_workers_to_close(cl, s, *workers): assert len(wtc) == 1 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 4) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) def test_workers_to_close_grouped(c, s, *workers): - groups = {workers[0].address: 'a', workers[1].address: 'a', - workers[2].address: 'b', workers[3].address: 'b'} + groups = { + workers[0].address: "a", + workers[1].address: "a", + workers[2].address: "b", + workers[3].address: "b", + } def key(ws): return groups[ws.address] - assert (set(s.workers_to_close(key=key)) - == set(w.address for w in workers)) + assert set(s.workers_to_close(key=key)) == set(w.address for w in workers) # Assert that job in one worker blocks closure of group future = c.submit(slowinc, 1, delay=0.2, workers=workers[0].address) while len(s.rprocessing) < 1: yield gen.sleep(0.001) - assert (set(s.workers_to_close(key=key)) - == {workers[2].address, workers[3].address}) + assert set(s.workers_to_close(key=key)) == {workers[2].address, workers[3].address} del future @@ -719,13 +762,14 @@ def key(ws): bv = yield c.scatter("b" * 75, workers=workers[2].address) bv2 = yield c.scatter("b" * 75, workers=workers[3].address) - assert (set(s.workers_to_close(key=key)) - == {workers[0].address, workers[1].address}) + assert set(s.workers_to_close(key=key)) == {workers[0].address, workers[1].address} @gen_cluster(client=True) def test_retire_workers_no_suspicious_tasks(c, s, a, b): - future = c.submit(slowinc, 100, delay=0.5, workers=a.address, allow_other_workers=True) + future = c.submit( + slowinc, 100, delay=0.5, workers=a.address, allow_other_workers=True + ) yield gen.sleep(0.2) yield s.retire_workers(workers=[a.address]) @@ -733,15 +777,15 @@ def test_retire_workers_no_suspicious_tasks(c, s, a, b): @slow -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="file descriptors not really a thing") -@pytest.mark.skipif(sys.version_info < (3, 6), - reason="intermittent failure") +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="file descriptors not really a thing" +) +@pytest.mark.skipif(sys.version_info < (3, 6), reason="intermittent failure") @gen_cluster(client=True, ncores=[], timeout=240) def test_file_descriptors(c, s): yield gen.sleep(0.1) - psutil = pytest.importorskip('psutil') - da = pytest.importorskip('dask.array') + psutil = pytest.importorskip("psutil") + da = pytest.importorskip("dask.array") proc = psutil.Process() num_fds_1 = proc.num_fds() @@ -828,7 +872,7 @@ def test_occupancy_cleardown(c, s, a, b): @nodebug 
-@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 30) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 30) def test_balance_many_workers(c, s, *workers): futures = c.map(slowinc, range(20), delay=0.2) yield wait(futures) @@ -836,9 +880,9 @@ def test_balance_many_workers(c, s, *workers): @nodebug -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 30) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 30) def test_balance_many_workers_2(c, s, *workers): - s.extensions['stealing']._pc.callback_time = 100000000 + s.extensions["stealing"]._pc.callback_time = 100000000 futures = c.map(slowinc, range(90), delay=0.2) yield wait(futures) assert {len(w.has_what) for w in s.workers.values()} == {3} @@ -852,16 +896,15 @@ def test_learn_occupancy_multiple_workers(c, s, a, b): yield wait(x) - assert not any(v == 0.5 for w in s.workers.values() - for v in w.processing.values()) + assert not any(v == 0.5 for w in s.workers.values() for v in w.processing.values()) s.validate_state() @gen_cluster(client=True) def test_include_communication_in_occupancy(c, s, a, b): - s.task_duration['slowadd'] = 0.001 - x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address) - y = c.submit(mul, b'1', int(BANDWIDTH * 1.5), workers=b.address) + s.task_duration["slowadd"] = 0.001 + x = c.submit(mul, b"0", int(BANDWIDTH), workers=a.address) + y = c.submit(mul, b"1", int(BANDWIDTH * 1.5), workers=b.address) z = c.submit(slowadd, x, y, delay=1) while z.key not in s.tasks or not s.tasks[z.key].processing_on: @@ -896,15 +939,15 @@ def test_worker_arrives_with_processing_data(c, s, a, b): yield gen.sleep(0.01) assert s.get_task_status(keys={x.key, y.key, z.key}) == { - x.key: 'released', - y.key: 'memory', - z.key: 'processing', + x.key: "released", + y.key: "memory", + z.key: "processing", } yield w._close() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_worker_breaks_and_returns(c, s, a): future = c.submit(slowinc, 1, delay=0.1) for i in range(10): @@ -922,7 +965,7 @@ def test_worker_breaks_and_returns(c, s, a): assert end - start < 1 states = frequencies(ts.state for ts in s.tasks.values()) - assert states == {'memory': 1, 'released': 10} + assert states == {"memory": 1, "released": 10} @gen_cluster(client=True, ncores=[]) @@ -947,9 +990,9 @@ def test_no_workers_to_memory(c, s): yield gen.sleep(0.01) assert s.get_task_status(keys={x.key, y.key, z.key}) == { - x.key: 'released', - y.key: 'memory', - z.key: 'processing', + x.key: "released", + y.key: "memory", + z.key: "processing", } yield w._close() @@ -961,12 +1004,12 @@ def test_no_worker_to_memory_restrictions(c, s, a, b): y = delayed(slowinc)(x, delay=0.4) z = delayed(slowinc)(y, delay=0.4) - yy, zz = c.persist([y, z], workers={(x, y, z): 'alice'}) + yy, zz = c.persist([y, z], workers={(x, y, z): "alice"}) while not s.tasks: yield gen.sleep(0.01) - w = Worker(s.ip, s.port, ncores=1, name='alice') + w = Worker(s.ip, s.port, ncores=1, name="alice") w.put_key_in_memory(y.key, 3) yield w @@ -976,9 +1019,9 @@ def test_no_worker_to_memory_restrictions(c, s, a, b): yield gen.sleep(0.3) assert s.get_task_status(keys={x.key, y.key, z.key}) == { - x.key: 'released', - y.key: 'memory', - z.key: 'processing', + x.key: "released", + y.key: "memory", + z.key: "processing", } yield w._close() @@ -989,9 +1032,9 @@ def f(dask_scheduler=None): return dask_scheduler.address with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: address = 
c.run_on_scheduler(f) - assert address == s['address'] + assert address == s["address"] with pytest.raises(ZeroDivisionError): c.run_on_scheduler(div, 1, 0) @@ -1046,7 +1089,7 @@ def test_close_nanny(c, s, a, b): assert not a.is_alive() assert a.pid is None - while a.status != 'closed': + while a.status != "closed": yield gen.sleep(0.05) assert time() < start + 10 @@ -1055,7 +1098,7 @@ def test_close_nanny(c, s, a, b): def test_retire_workers_close(c, s, a, b): yield s.retire_workers(close_workers=True) assert not s.workers - while a.status != 'closed' and b.status != 'closed': + while a.status != "closed" and b.status != "closed": yield gen.sleep(0.01) @@ -1067,7 +1110,7 @@ def test_retire_nannies_close(c, s, a, b): start = time() - while any(n.status != 'closed' for n in nannies): + while any(n.status != "closed" for n in nannies): yield gen.sleep(0.05) assert time() < start + 10 @@ -1075,16 +1118,15 @@ def test_retire_nannies_close(c, s, a, b): assert not s.workers -@gen_cluster(client=True, ncores=[('127.0.0.1', 2)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 2)]) def test_fifo_submission(c, s, w): futures = [] for i in range(20): - future = c.submit(slowinc, i, delay=0.1, key='inc-%02d' % i, - fifo_timeout=0.01) + future = c.submit(slowinc, i, delay=0.1, key="inc-%02d" % i, fifo_timeout=0.01) futures.append(future) yield gen.sleep(0.02) yield wait(futures[-1]) - assert futures[10].status == 'finished' + assert futures[10].status == "finished" @gen_test() @@ -1094,30 +1136,29 @@ def test_scheduler_file(): s.start(0) with open(fn) as f: data = json.load(f) - assert data['address'] == s.address + assert data["address"] == s.address c = yield Client(scheduler_file=fn, loop=s.loop, asynchronous=True) yield s.close() -@pytest.mark.xfail(reason='') +@pytest.mark.xfail(reason="") @gen_cluster(client=True, ncores=[]) def test_non_existent_worker(c, s): - with dask.config.set({'distributed.comm.timeouts.connect': '100ms'}): - s.add_worker(address='127.0.0.1:5738', ncores=2, nbytes={}, host_info={}) + with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): + s.add_worker(address="127.0.0.1:5738", ncores=2, nbytes={}, host_info={}) futures = c.map(inc, range(10)) yield gen.sleep(0.300) assert not s.workers - assert all(ts.state == 'no-worker' for ts in s.tasks.values()) + assert all(ts.state == "no-worker" for ts in s.tasks.values()) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_correct_bad_time_estimate(c, s, *workers): future = c.submit(slowinc, 1, delay=0) yield wait(future) - futures = [c.submit(slowinc, future, delay=0.1, pure=False) - for i in range(20)] + futures = [c.submit(slowinc, future, delay=0.1, pure=False) for i in range(20)] yield gen.sleep(0.5) @@ -1128,39 +1169,40 @@ def test_correct_bad_time_estimate(c, s, *workers): @gen_test() def test_service_hosts(): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") from distributed.bokeh.scheduler import BokehScheduler port = 0 for url, expected in [ - ('tcp://0.0.0.0', ('::', '0.0.0.0')), - ('tcp://127.0.0.1', '127.0.0.1'), - ('tcp://127.0.0.1:38275', '127.0.0.1')]: - services = {('bokeh', port): BokehScheduler} + ("tcp://0.0.0.0", ("::", "0.0.0.0")), + ("tcp://127.0.0.1", "127.0.0.1"), + ("tcp://127.0.0.1:38275", "127.0.0.1"), + ]: + services = {("bokeh", port): BokehScheduler} s = Scheduler(services=services) yield s.start(url) - sock = first(s.services['bokeh'].server._http._sockets.values()) + sock = 
first(s.services["bokeh"].server._http._sockets.values()) if isinstance(expected, tuple): assert sock.getsockname()[0] in expected else: assert sock.getsockname()[0] == expected yield s.close() - port = ('127.0.0.1', 0) - for url in ['tcp://0.0.0.0', 'tcp://127.0.0.1', 'tcp://127.0.0.1:38275']: - services = {('bokeh', port): BokehScheduler} + port = ("127.0.0.1", 0) + for url in ["tcp://0.0.0.0", "tcp://127.0.0.1", "tcp://127.0.0.1:38275"]: + services = {("bokeh", port): BokehScheduler} s = Scheduler(services=services) yield s.start(url) - sock = first(s.services['bokeh'].server._http._sockets.values()) - assert sock.getsockname()[0] == '127.0.0.1' + sock = first(s.services["bokeh"].server._http._sockets.values()) + assert sock.getsockname()[0] == "127.0.0.1" yield s.close() -@gen_cluster(client=True, worker_kwargs={'profile_cycle_interval': 100}) +@gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) def test_profile_metadata(c, s, a, b): start = time() - 1 futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) @@ -1170,12 +1212,12 @@ def test_profile_metadata(c, s, a, b): meta = yield s.get_profile_metadata(profile_cycle_interval=0.100) now = time() + 1 assert meta - assert all(start < t < now for t, count in meta['counts']) - assert all(0 <= count < 30 for t, count in meta['counts'][:4]) - assert not meta['counts'][-1][1] + assert all(start < t < now for t, count in meta["counts"]) + assert all(0 <= count < 30 for t, count in meta["counts"][:4]) + assert not meta["counts"][-1][1] -@gen_cluster(client=True, worker_kwargs={'profile_cycle_interval': 100}) +@gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) def test_profile_metadata_keys(c, s, a, b): start = time() - 1 x = c.map(slowinc, range(10), delay=0.05) @@ -1183,8 +1225,8 @@ def test_profile_metadata_keys(c, s, a, b): yield wait(x + y) meta = yield s.get_profile_metadata(profile_cycle_interval=0.100) - assert set(meta['keys']) == {'slowinc', 'slowdec'} - assert len(meta['counts']) == len(meta['keys']['slowinc']) + assert set(meta["keys"]) == {"slowinc", "slowdec"} + assert len(meta["counts"]) == len(meta["keys"]["slowinc"]) @gen_cluster(client=True) @@ -1198,7 +1240,7 @@ def test_cancel_fire_and_forget(c, s, a, b): yield gen.sleep(0.05) yield future.cancel(force=True) - assert future.status == 'cancelled' + assert future.status == "cancelled" assert not s.tasks @@ -1206,10 +1248,10 @@ def test_cancel_fire_and_forget(c, s, a, b): def test_log_tasks_during_restart(c, s, a, b): future = c.submit(sys.exit, 0) yield wait(future) - assert 'exit' in str(s.events) + assert "exit" in str(s.events) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_reschedule(c, s, a, b): yield c.submit(slowinc, -1, delay=0.1) # learn cost x = c.map(slowinc, range(4), delay=0.1) @@ -1235,18 +1277,19 @@ def test_get_task_status(c, s, a, b): yield wait(future) result = yield a.scheduler.get_task_status(keys=[future.key]) - assert result == {future.key: 'memory'} + assert result == {future.key: "memory"} def test_deque_handler(): from distributed.scheduler import logger + s = Scheduler() deque_handler = s._deque_handler - logger.info('foo123') + logger.info("foo123") assert len(deque_handler.deque) >= 1 msg = deque_handler.deque[-1] - assert 'distributed.scheduler' in deque_handler.format(msg) - assert any(msg.msg == 'foo123' for msg in deque_handler.deque) + assert "distributed.scheduler" in deque_handler.format(msg) + assert any(msg.msg 
== "foo123" for msg in deque_handler.deque) @gen_cluster(client=True) @@ -1277,10 +1320,10 @@ def test_retries(c, s, a, b): @pytest.mark.xfail(reason="second worker also errant for some reason") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3, timeout=5) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3, timeout=5) def test_mising_data_errant_worker(c, s, w1, w2, w3): - with dask.config.set({'distributed.comm.timeouts.connect': '1s'}): - np = pytest.importorskip('numpy') + with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): + np = pytest.importorskip("numpy") x = c.submit(np.random.random, 10000000, workers=w1.address) yield wait(x) @@ -1295,8 +1338,8 @@ def test_mising_data_errant_worker(c, s, w1, w2, w3): @gen_cluster(client=True) def test_dont_recompute_if_persisted(c, s, a, b): - x = delayed(inc)(1, dask_key_name='x') - y = delayed(inc)(x, dask_key_name='y') + x = delayed(inc)(1, dask_key_name="x") + y = delayed(inc)(x, dask_key_name="y") yy = y.persist() yield wait(yy) @@ -1312,28 +1355,28 @@ def test_dont_recompute_if_persisted(c, s, a, b): @gen_cluster(client=True) def test_dont_recompute_if_persisted_2(c, s, a, b): - x = delayed(inc)(1, dask_key_name='x') - y = delayed(inc)(x, dask_key_name='y') - z = delayed(inc)(y, dask_key_name='z') + x = delayed(inc)(1, dask_key_name="x") + y = delayed(inc)(x, dask_key_name="y") + z = delayed(inc)(y, dask_key_name="z") yy = y.persist() yield wait(yy) - old = s.story('x', 'y') + old = s.story("x", "y") zz = z.persist() yield wait(zz) yield gen.sleep(0.100) - assert s.story('x', 'y') == old + assert s.story("x", "y") == old @gen_cluster(client=True) def test_dont_recompute_if_persisted_3(c, s, a, b): - x = delayed(inc)(1, dask_key_name='x') - y = delayed(inc)(2, dask_key_name='y') - z = delayed(inc)(y, dask_key_name='z') - w = delayed(add)(x, z, dask_key_name='w') + x = delayed(inc)(1, dask_key_name="x") + y = delayed(inc)(2, dask_key_name="y") + z = delayed(inc)(y, dask_key_name="z") + w = delayed(add)(x, z, dask_key_name="w") ww = w.persist() yield wait(ww) @@ -1348,44 +1391,44 @@ def test_dont_recompute_if_persisted_3(c, s, a, b): @gen_cluster(client=True) def test_dont_recompute_if_persisted_4(c, s, a, b): - x = delayed(inc)(1, dask_key_name='x') - y = delayed(inc)(x, dask_key_name='y') - z = delayed(inc)(x, dask_key_name='z') + x = delayed(inc)(1, dask_key_name="x") + y = delayed(inc)(x, dask_key_name="y") + z = delayed(inc)(x, dask_key_name="z") yy = y.persist() yield wait(yy) - old = s.story('x') + old = s.story("x") - while s.tasks['x'].state == 'memory': + while s.tasks["x"].state == "memory": yield gen.sleep(0.01) yyy, zzz = dask.persist(y, z) yield wait([yyy, zzz]) - new = s.story('x') + new = s.story("x") assert len(new) > len(old) @gen_cluster(client=True) def test_dont_forget_released_keys(c, s, a, b): - x = c.submit(inc, 1, key='x') - y = c.submit(inc, x, key='y') - z = c.submit(dec, x, key='z') + x = c.submit(inc, 1, key="x") + y = c.submit(inc, x, key="y") + z = c.submit(dec, x, key="z") del x yield wait([y, z]) del z - while 'z' in s.tasks: + while "z" in s.tasks: yield gen.sleep(0.01) - assert 'x' in s.tasks + assert "x" in s.tasks @gen_cluster(client=True) def test_dont_recompute_if_erred(c, s, a, b): - x = delayed(inc)(1, dask_key_name='x') - y = delayed(div)(x, 0, dask_key_name='y') + x = delayed(inc)(1, dask_key_name="x") + y = delayed(div)(x, 0, dask_key_name="y") yy = y.persist() yield wait(yy) @@ -1404,15 +1447,16 @@ def test_closing_scheduler_closes_workers(s, a, b): yield s.close() start 
= time() - while a.status != 'closed' or b.status != 'closed': + while a.status != "closed" or b.status != "closed": yield gen.sleep(0.01) assert time() < start + 2 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)], - worker_kwargs={'resources': {'A': 1}}) +@gen_cluster( + client=True, ncores=[("127.0.0.1", 1)], worker_kwargs={"resources": {"A": 1}} +) def test_resources_reset_after_cancelled_task(c, s, w): - future = c.submit(sleep, 0.2, resources={'A': 1}) + future = c.submit(sleep, 0.2, resources={"A": 1}) while not w.executing: yield gen.sleep(0.01) @@ -1422,33 +1466,33 @@ def test_resources_reset_after_cancelled_task(c, s, w): while w.executing: yield gen.sleep(0.01) - assert not s.workers[w.address].used_resources['A'] - assert w.available_resources == {'A': 1} + assert not s.workers[w.address].used_resources["A"] + assert w.available_resources == {"A": 1} - yield c.submit(inc, 1, resources={'A': 1}) + yield c.submit(inc, 1, resources={"A": 1}) @gen_cluster(client=True) def test_gh2187(c, s, a, b): def foo(): - return 'foo' + return "foo" def bar(x): - return x + 'bar' + return x + "bar" def baz(x): - return x + 'baz' + return x + "baz" def qux(x): sleep(0.1) - return x + 'qux' + return x + "qux" - w = c.submit(foo, key='w') - x = c.submit(bar, w, key='x') - y = c.submit(baz, x, key='y') + w = c.submit(foo, key="w") + x = c.submit(bar, w, key="x") + y = c.submit(baz, x, key="y") yield y - z = c.submit(qux, y, key='z') + z = c.submit(qux, y, key="z") del y yield gen.sleep(0.1) - f = c.submit(bar, x, key='y') + f = c.submit(bar, x, key="y") yield f diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 7ebd414ca24..8e82db1308e 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -16,21 +16,21 @@ from distributed.utils_test import new_config, get_cert, gen_test -ca_file = get_cert('tls-ca-cert.pem') +ca_file = get_cert("tls-ca-cert.pem") -cert1 = get_cert('tls-cert.pem') -key1 = get_cert('tls-key.pem') -keycert1 = get_cert('tls-key-cert.pem') +cert1 = get_cert("tls-cert.pem") +key1 = get_cert("tls-key.pem") +keycert1 = get_cert("tls-key-cert.pem") # Note this cipher uses RSA auth as this matches our test certs -FORCED_CIPHER = 'ECDHE-RSA-AES128-GCM-SHA256' +FORCED_CIPHER = "ECDHE-RSA-AES128-GCM-SHA256" TLS_13_CIPHERS = [ - 'TLS_AES_128_GCM_SHA256', - 'TLS_AES_256_GCM_SHA384', - 'TLS_CHACHA20_POLY1305_SHA256', - 'TLS_AES_128_CCM_SHA256', - 'TLS_AES_128_CCM_8_SHA256', + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_128_CCM_SHA256", + "TLS_AES_128_CCM_8_SHA256", ] @@ -50,7 +50,7 @@ def test_defaults(): def test_attribute_error(): sec = Security() - assert hasattr(sec, 'tls_ca_file') + assert hasattr(sec, "tls_ca_file") with pytest.raises(AttributeError): sec.tls_foobar with pytest.raises(AttributeError): @@ -59,103 +59,93 @@ def test_attribute_error(): def test_from_config(): c = { - 'tls': { - 'ca-file': 'ca.pem', - 'scheduler': { - 'key': 'skey.pem', - 'cert': 'scert.pem', - }, - 'worker': { - 'cert': 'wcert.pem', - }, - 'ciphers': FORCED_CIPHER, + "tls": { + "ca-file": "ca.pem", + "scheduler": {"key": "skey.pem", "cert": "scert.pem"}, + "worker": {"cert": "wcert.pem"}, + "ciphers": FORCED_CIPHER, }, - 'require-encryption': True, + "require-encryption": True, } with new_config(c): sec = Security() assert sec.require_encryption is True - assert sec.tls_ca_file == 'ca.pem' + assert sec.tls_ca_file == "ca.pem" assert sec.tls_ciphers == FORCED_CIPHER assert 
sec.tls_client_key is None assert sec.tls_client_cert is None - assert sec.tls_scheduler_key == 'skey.pem' - assert sec.tls_scheduler_cert == 'scert.pem' + assert sec.tls_scheduler_key == "skey.pem" + assert sec.tls_scheduler_cert == "scert.pem" assert sec.tls_worker_key is None - assert sec.tls_worker_cert == 'wcert.pem' + assert sec.tls_worker_cert == "wcert.pem" def test_kwargs(): c = { - 'tls': { - 'ca-file': 'ca.pem', - 'scheduler': { - 'key': 'skey.pem', - 'cert': 'scert.pem', - }, - }, + "tls": { + "ca-file": "ca.pem", + "scheduler": {"key": "skey.pem", "cert": "scert.pem"}, + } } with new_config(c): - sec = Security(tls_scheduler_cert='newcert.pem', - require_encryption=True, - tls_ca_file=None) + sec = Security( + tls_scheduler_cert="newcert.pem", require_encryption=True, tls_ca_file=None + ) assert sec.require_encryption is True # None value didn't override default - assert sec.tls_ca_file == 'ca.pem' + assert sec.tls_ca_file == "ca.pem" assert sec.tls_ciphers is None assert sec.tls_client_key is None assert sec.tls_client_cert is None - assert sec.tls_scheduler_key == 'skey.pem' - assert sec.tls_scheduler_cert == 'newcert.pem' + assert sec.tls_scheduler_key == "skey.pem" + assert sec.tls_scheduler_cert == "newcert.pem" assert sec.tls_worker_key is None assert sec.tls_worker_cert is None def test_repr(): with new_config({}): - sec = Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem') - assert repr(sec) == "Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')" + sec = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem") + assert ( + repr(sec) + == "Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')" + ) def test_tls_config_for_role(): c = { - 'tls': { - 'ca-file': 'ca.pem', - 'scheduler': { - 'key': 'skey.pem', - 'cert': 'scert.pem', - }, - 'worker': { - 'cert': 'wcert.pem', - }, - 'ciphers': FORCED_CIPHER, - }, + "tls": { + "ca-file": "ca.pem", + "scheduler": {"key": "skey.pem", "cert": "scert.pem"}, + "worker": {"cert": "wcert.pem"}, + "ciphers": FORCED_CIPHER, + } } with new_config(c): sec = Security() - t = sec.get_tls_config_for_role('scheduler') + t = sec.get_tls_config_for_role("scheduler") assert t == { - 'ca_file': 'ca.pem', - 'key': 'skey.pem', - 'cert': 'scert.pem', - 'ciphers': FORCED_CIPHER, + "ca_file": "ca.pem", + "key": "skey.pem", + "cert": "scert.pem", + "ciphers": FORCED_CIPHER, } - t = sec.get_tls_config_for_role('worker') + t = sec.get_tls_config_for_role("worker") assert t == { - 'ca_file': 'ca.pem', - 'key': None, - 'cert': 'wcert.pem', - 'ciphers': FORCED_CIPHER, + "ca_file": "ca.pem", + "key": None, + "cert": "wcert.pem", + "ciphers": FORCED_CIPHER, } - t = sec.get_tls_config_for_role('client') + t = sec.get_tls_config_for_role("client") assert t == { - 'ca_file': 'ca.pem', - 'key': None, - 'cert': None, - 'ciphers': FORCED_CIPHER, + "ca_file": "ca.pem", + "key": None, + "cert": None, + "ciphers": FORCED_CIPHER, } with pytest.raises(ValueError): - sec.get_tls_config_for_role('supervisor') + sec.get_tls_config_for_role("supervisor") def test_connection_args(): @@ -168,51 +158,46 @@ def many_ciphers(ctx): assert len(ctx.get_ciphers()) > 2 # Most likely c = { - 'tls': { - 'ca-file': ca_file, - 'scheduler': { - 'key': key1, - 'cert': cert1, - }, - 'worker': { - 'cert': keycert1, - }, - }, + "tls": { + "ca-file": ca_file, + "scheduler": {"key": key1, "cert": cert1}, + "worker": {"cert": keycert1}, + } } with new_config(c): sec = Security() - d = sec.get_connection_args('scheduler') - assert not d['require_encryption'] 
- ctx = d['ssl_context'] + d = sec.get_connection_args("scheduler") + assert not d["require_encryption"] + ctx = d["ssl_context"] basic_checks(ctx) many_ciphers(ctx) - d = sec.get_connection_args('worker') - ctx = d['ssl_context'] + d = sec.get_connection_args("worker") + ctx = d["ssl_context"] basic_checks(ctx) many_ciphers(ctx) # No cert defined => no TLS - d = sec.get_connection_args('client') - assert d.get('ssl_context') is None + d = sec.get_connection_args("client") + assert d.get("ssl_context") is None # With more settings - c['tls']['ciphers'] = FORCED_CIPHER - c['require-encryption'] = True + c["tls"]["ciphers"] = FORCED_CIPHER + c["require-encryption"] = True with new_config(c): sec = Security() - d = sec.get_listen_args('scheduler') - assert d['require_encryption'] - ctx = d['ssl_context'] + d = sec.get_listen_args("scheduler") + assert d["require_encryption"] + ctx = d["ssl_context"] basic_checks(ctx) if sys.version_info >= (3, 6): supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if c['protocol'] == 'TLSv1.2'] + tls_12_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.2"] assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if c['protocol'] == 'TLSv1.3'] + tls_13_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.3"] if len(tls_13_ciphers): assert len(tls_13_ciphers) == 3 @@ -227,51 +212,46 @@ def many_ciphers(ctx): assert len(ctx.get_ciphers()) > 2 # Most likely c = { - 'tls': { - 'ca-file': ca_file, - 'scheduler': { - 'key': key1, - 'cert': cert1, - }, - 'worker': { - 'cert': keycert1, - }, - }, + "tls": { + "ca-file": ca_file, + "scheduler": {"key": key1, "cert": cert1}, + "worker": {"cert": keycert1}, + } } with new_config(c): sec = Security() - d = sec.get_listen_args('scheduler') - assert not d['require_encryption'] - ctx = d['ssl_context'] + d = sec.get_listen_args("scheduler") + assert not d["require_encryption"] + ctx = d["ssl_context"] basic_checks(ctx) many_ciphers(ctx) - d = sec.get_listen_args('worker') - ctx = d['ssl_context'] + d = sec.get_listen_args("worker") + ctx = d["ssl_context"] basic_checks(ctx) many_ciphers(ctx) # No cert defined => no TLS - d = sec.get_listen_args('client') - assert d.get('ssl_context') is None + d = sec.get_listen_args("client") + assert d.get("ssl_context") is None # With more settings - c['tls']['ciphers'] = FORCED_CIPHER - c['require-encryption'] = True + c["tls"]["ciphers"] = FORCED_CIPHER + c["require-encryption"] = True with new_config(c): sec = Security() - d = sec.get_listen_args('scheduler') - assert d['require_encryption'] - ctx = d['ssl_context'] + d = sec.get_listen_args("scheduler") + assert d["require_encryption"] + ctx = d["ssl_context"] basic_checks(ctx) if sys.version_info >= (3, 6): supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if c['protocol'] == 'TLSv1.2'] + tls_12_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.2"] assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if c['protocol'] == 'TLSv1.3'] + tls_13_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.3"] if len(tls_13_ciphers): assert len(tls_13_ciphers) == 3 @@ -281,49 +261,51 @@ def test_tls_listen_connect(): """ Functional test for TLS connection args. 
""" + @gen.coroutine def handle_comm(comm): peer_addr = comm.peer_address - assert peer_addr.startswith('tls://') - yield comm.write('hello') + assert peer_addr.startswith("tls://") + yield comm.write("hello") yield comm.close() c = { - 'tls': { - 'ca-file': ca_file, - 'scheduler': { - 'key': key1, - 'cert': cert1, - }, - 'worker': { - 'cert': keycert1, - }, - }, + "tls": { + "ca-file": ca_file, + "scheduler": {"key": key1, "cert": cert1}, + "worker": {"cert": keycert1}, + } } with new_config(c): sec = Security() - c['tls']['ciphers'] = FORCED_CIPHER + c["tls"]["ciphers"] = FORCED_CIPHER with new_config(c): forced_cipher_sec = Security() - with listen('tls://', handle_comm, - connection_args=sec.get_listen_args('scheduler')) as listener: - comm = yield connect(listener.contact_address, - connection_args=sec.get_connection_args('worker')) + with listen( + "tls://", handle_comm, connection_args=sec.get_listen_args("scheduler") + ) as listener: + comm = yield connect( + listener.contact_address, connection_args=sec.get_connection_args("worker") + ) msg = yield comm.read() - assert msg == 'hello' + assert msg == "hello" comm.abort() # No SSL context for client with pytest.raises(TypeError): - yield connect(listener.contact_address, - connection_args=sec.get_connection_args('client')) + yield connect( + listener.contact_address, + connection_args=sec.get_connection_args("client"), + ) # Check forced cipher - comm = yield connect(listener.contact_address, - connection_args=forced_cipher_sec.get_connection_args('worker')) - cipher, _, _, = comm.extra_info['cipher'] + comm = yield connect( + listener.contact_address, + connection_args=forced_cipher_sec.get_connection_args("worker"), + ) + cipher, _, _, = comm.extra_info["cipher"] assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS comm.abort() @@ -333,39 +315,41 @@ def test_require_encryption(): """ Functional test for "require_encryption" setting. 
""" + @gen.coroutine def handle_comm(comm): comm.abort() c = { - 'tls': { - 'ca-file': ca_file, - 'scheduler': { - 'key': key1, - 'cert': cert1, - }, - 'worker': { - 'cert': keycert1, - }, - }, + "tls": { + "ca-file": ca_file, + "scheduler": {"key": key1, "cert": cert1}, + "worker": {"cert": keycert1}, + } } with new_config(c): sec = Security() - c['require-encryption'] = True + c["require-encryption"] = True with new_config(c): sec2 = Security() - for listen_addr in ['inproc://', 'tls://']: - with listen(listen_addr, handle_comm, - connection_args=sec.get_listen_args('scheduler')) as listener: - comm = yield connect(listener.contact_address, - connection_args=sec2.get_connection_args('worker')) + for listen_addr in ["inproc://", "tls://"]: + with listen( + listen_addr, handle_comm, connection_args=sec.get_listen_args("scheduler") + ) as listener: + comm = yield connect( + listener.contact_address, + connection_args=sec2.get_connection_args("worker"), + ) comm.abort() - with listen(listen_addr, handle_comm, - connection_args=sec2.get_listen_args('scheduler')) as listener: - comm = yield connect(listener.contact_address, - connection_args=sec2.get_connection_args('worker')) + with listen( + listen_addr, handle_comm, connection_args=sec2.get_listen_args("scheduler") + ) as listener: + comm = yield connect( + listener.contact_address, + connection_args=sec2.get_connection_args("worker"), + ) comm.abort() @contextmanager @@ -374,17 +358,25 @@ def check_encryption_error(): yield assert "encryption required" in str(excinfo.value) - for listen_addr in ['tcp://']: - with listen(listen_addr, handle_comm, - connection_args=sec.get_listen_args('scheduler')) as listener: - comm = yield connect(listener.contact_address, - connection_args=sec.get_connection_args('worker')) + for listen_addr in ["tcp://"]: + with listen( + listen_addr, handle_comm, connection_args=sec.get_listen_args("scheduler") + ) as listener: + comm = yield connect( + listener.contact_address, + connection_args=sec.get_connection_args("worker"), + ) comm.abort() with pytest.raises(RuntimeError): - yield connect(listener.contact_address, - connection_args=sec2.get_connection_args('worker')) + yield connect( + listener.contact_address, + connection_args=sec2.get_connection_args("worker"), + ) with pytest.raises(RuntimeError): - listen(listen_addr, handle_comm, - connection_args=sec2.get_listen_args('scheduler')) + listen( + listen_addr, + handle_comm, + connection_args=sec2.get_listen_args("scheduler"), + ) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index caaa939b665..f93022e6d81 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -15,10 +15,15 @@ from distributed.config import config from distributed.metrics import time from distributed.scheduler import BANDWIDTH, key_split -from distributed.utils_test import (slowinc, slowadd, inc, gen_cluster, - slowidentity, captured_logger) -from distributed.utils_test import (nodebug_setup_module, - nodebug_teardown_module) +from distributed.utils_test import ( + slowinc, + slowadd, + inc, + gen_cluster, + slowidentity, + captured_logger, +) +from distributed.utils_test import nodebug_setup_module, nodebug_teardown_module from distributed.worker import TOTAL_MEMORY import pytest @@ -29,10 +34,10 @@ teardown_module = nodebug_teardown_module -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster(client=True, ncores=[('127.0.0.1', 2), ('127.0.0.2', 2)], - timeout=20) 
+@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster(client=True, ncores=[("127.0.0.1", 2), ("127.0.0.2", 2)], timeout=20) def test_work_stealing(c, s, a, b): [x] = yield c._scatter([1], workers=a.address) futures = c.map(slowadd, range(50), [x] * 50) @@ -41,37 +46,40 @@ def test_work_stealing(c, s, a, b): assert len(b.data) > 10 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_dont_steal_expensive_data_fast_computation(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = c.submit(np.arange, 1000000, workers=a.address) yield wait([x]) future = c.submit(np.sum, [1], workers=a.address) # learn that sum is fast yield wait([future]) - cheap = [c.submit(np.sum, x, pure=False, workers=a.address, - allow_other_workers=True) for i in range(10)] + cheap = [ + c.submit(np.sum, x, pure=False, workers=a.address, allow_other_workers=True) + for i in range(10) + ] yield wait(cheap) assert len(s.who_has[x.key]) == 1 assert len(b.data) == 0 assert len(a.data) == 12 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_steal_cheap_data_slow_computation(c, s, a, b): x = c.submit(slowinc, 100, delay=0.1) # learn that slowinc is slow yield wait(x) - futures = c.map(slowinc, range(10), delay=0.1, workers=a.address, - allow_other_workers=True) + futures = c.map( + slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True + ) yield wait(futures) assert abs(len(a.data) - len(b.data)) <= 5 @pytest.mark.avoid_travis -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_steal_expensive_data_slow_computation(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = c.submit(slowinc, 100, delay=0.2, workers=a.address) yield wait(x) # learn that slowinc is slow @@ -86,7 +94,7 @@ def test_steal_expensive_data_slow_computation(c, s, a, b): assert b.data # not empty -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_worksteal_many_thieves(c, s, *workers): x = c.submit(slowinc, -1, delay=0.1) yield x @@ -102,7 +110,7 @@ def test_worksteal_many_thieves(c, s, *workers): assert sum(map(len, s.has_what.values())) < 150 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_dont_steal_unknown_functions(c, s, a, b): futures = c.map(inc, [1, 2], workers=a.address, allow_other_workers=True) yield wait(futures) @@ -110,20 +118,22 @@ def test_dont_steal_unknown_functions(c, s, a, b): assert len(b.data) == 0 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_eventually_steal_unknown_functions(c, s, a, b): - futures = c.map(slowinc, range(10), delay=0.1, workers=a.address, - allow_other_workers=True) + futures = c.map( + slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True + ) yield wait(futures) assert len(a.data) >= 3 assert len(b.data) >= 3 -@pytest.mark.skip(reason='') -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@pytest.mark.skip(reason="") +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_steal_related_tasks(e, s, a, b, c): - futures = e.map(slowinc, range(20), delay=0.05, workers=a.address, - 
allow_other_workers=True) + futures = e.map( + slowinc, range(20), delay=0.05, workers=a.address, allow_other_workers=True + ) yield wait(futures) @@ -135,9 +145,9 @@ def test_steal_related_tasks(e, s, a, b, c): assert nearby > 10 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10, timeout=1000) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10, timeout=1000) def test_dont_steal_fast_tasks(c, s, *workers): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = c.submit(np.random.random, 10000000, workers=workers[0].address) def do_nothing(x, y=None): @@ -153,7 +163,7 @@ def do_nothing(x, y=None): assert len(s.has_what[workers[0].address]) == 1001 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)], timeout=20) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)], timeout=20) def test_new_worker_steals(c, s, a): yield wait(c.submit(slowinc, 1, delay=0.01)) @@ -179,8 +189,9 @@ def test_new_worker_steals(c, s, a): def test_work_steal_no_kwargs(c, s, a, b): yield wait(c.submit(slowinc, 1, delay=0.05)) - futures = c.map(slowinc, range(100), workers=a.address, - allow_other_workers=True, delay=0.05) + futures = c.map( + slowinc, range(100), workers=a.address, allow_other_workers=True, delay=0.05 + ) yield wait(futures) @@ -193,7 +204,7 @@ def test_work_steal_no_kwargs(c, s, a, b): assert result == sum(map(inc, range(100))) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1), ('127.0.0.1', 2)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1), ("127.0.0.1", 2)]) def test_dont_steal_worker_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) yield future @@ -206,7 +217,7 @@ def test_dont_steal_worker_restrictions(c, s, a, b): assert len(a.task_state) == 100 assert len(b.task_state) == 0 - result = s.extensions['stealing'].balance() + result = s.extensions["stealing"].balance() yield gen.sleep(0.1) @@ -214,58 +225,59 @@ def test_dont_steal_worker_restrictions(c, s, a, b): assert len(b.task_state) == 0 -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1), ('127.0.0.2', 1)]) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1), ("127.0.0.2", 1)]) def test_dont_steal_host_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) yield future - futures = c.map(slowinc, range(100), delay=0.1, workers='127.0.0.1') + futures = c.map(slowinc, range(100), delay=0.1, workers="127.0.0.1") while len(a.task_state) < 10: yield gen.sleep(0.01) assert len(a.task_state) == 100 assert len(b.task_state) == 0 - result = s.extensions['stealing'].balance() + result = s.extensions["stealing"].balance() yield gen.sleep(0.1) assert len(a.task_state) == 100 assert len(b.task_state) == 0 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 2}}), - ('127.0.0.1', 1)]) +@gen_cluster( + client=True, ncores=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)] +) def test_dont_steal_resource_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) yield future - futures = c.map(slowinc, range(100), delay=0.1, resources={'A': 1}) + futures = c.map(slowinc, range(100), delay=0.1, resources={"A": 1}) while len(a.task_state) < 10: yield gen.sleep(0.01) assert len(a.task_state) == 100 assert len(b.task_state) == 0 - result = s.extensions['stealing'].balance() + result = 
s.extensions["stealing"].balance() yield gen.sleep(0.1) assert len(a.task_state) == 100 assert len(b.task_state) == 0 -@pytest.mark.skip(reason='no stealing of resources') -@gen_cluster(client=True, ncores=[('127.0.0.1', 1, {'resources': {'A': 2}})], - timeout=3) +@pytest.mark.skip(reason="no stealing of resources") +@gen_cluster(client=True, ncores=[("127.0.0.1", 1, {"resources": {"A": 2}})], timeout=3) def test_steal_resource_restrictions(c, s, a): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) yield future - futures = c.map(slowinc, range(100), delay=0.2, resources={'A': 1}) + futures = c.map(slowinc, range(100), delay=0.2, resources={"A": 1}) while len(a.task_state) < 101: yield gen.sleep(0.01) assert len(a.task_state) == 101 - b = yield Worker(s.ip, s.port, loop=s.loop, ncores=1, resources={'A': 4}) + b = yield Worker(s.ip, s.port, loop=s.loop, ncores=1, resources={"A": 4}) start = time() while not b.task_state or len(a.task_state) == 101: @@ -278,14 +290,15 @@ def test_steal_resource_restrictions(c, s, a): yield b._close() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 5, timeout=20) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 5, timeout=20) def test_balance_without_dependencies(c, s, *workers): - s.extensions['stealing']._pc.callback_time = 20 + s.extensions["stealing"]._pc.callback_time = 20 def slow(x): y = random.random() * 0.1 sleep(y) return y + futures = c.map(slow, range(100)) yield wait(futures) @@ -293,22 +306,23 @@ def slow(x): assert max(durations) / min(durations) < 3 -@gen_cluster(client=True, ncores=[('127.0.0.1', 4)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 4)] * 2) def test_dont_steal_executing_tasks(c, s, a, b): - futures = c.map(slowinc, range(4), delay=0.1, workers=a.address, - allow_other_workers=True) + futures = c.map( + slowinc, range(4), delay=0.1, workers=a.address, allow_other_workers=True + ) yield wait(futures) assert len(a.data) == 4 assert len(b.data) == 0 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): - s.extensions['stealing']._pc.callback_time = 20 - x = c.submit(mul, b'0', 100000000, workers=a.address) # 100 MB + s.extensions["stealing"]._pc.callback_time = 20 + x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB yield wait(x) - s.task_duration['slowidentity'] = 0.2 + s.task_duration["slowidentity"] = 0.2 futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(2)] @@ -318,16 +332,18 @@ def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): assert not any(w.task_state for w in rest) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10, - worker_kwargs={'memory_limit': TOTAL_MEMORY}) +@gen_cluster( + client=True, + ncores=[("127.0.0.1", 1)] * 10, + worker_kwargs={"memory_limit": TOTAL_MEMORY}, +) def test_steal_when_more_tasks(c, s, a, *rest): - s.extensions['stealing']._pc.callback_time = 20 - x = c.submit(mul, b'0', 50000000, workers=a.address) # 50 MB + s.extensions["stealing"]._pc.callback_time = 20 + x = c.submit(mul, b"0", 50000000, workers=a.address) # 50 MB yield wait(x) - s.task_duration['slowidentity'] = 0.2 + s.task_duration["slowidentity"] = 0.2 - futures = [c.submit(slowidentity, x, pure=False, delay=0.2) - for i in range(20)] + futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(20)] start = time() while not any(w.task_state for w in rest): @@ -335,22 +351,20 @@ def 
test_steal_when_more_tasks(c, s, a, *rest): assert time() < start + 1 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) def test_steal_more_attractive_tasks(c, s, a, *rest): - def slow2(x): sleep(1) return x - s.extensions['stealing']._pc.callback_time = 20 - x = c.submit(mul, b'0', 100000000, workers=a.address) # 100 MB + s.extensions["stealing"]._pc.callback_time = 20 + x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB yield wait(x) - s.task_duration['slowidentity'] = 0.2 - s.task_duration['slow2'] = 1 + s.task_duration["slowidentity"] = 0.2 + s.task_duration["slow2"] = 1 - futures = [c.submit(slowidentity, x, pure=False, delay=0.2) - for i in range(10)] + futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(10)] future = c.submit(slow2, x, priority=-1) while not any(w.task_state for w in rest): @@ -365,7 +379,7 @@ def func(x): def assert_balanced(inp, expected, c, s, *workers): - steal = s.extensions['stealing'] + steal = s.extensions["stealing"] steal._pc.stop() counter = itertools.count() @@ -387,9 +401,15 @@ def assert_balanced(inp, expected, c, s, *workers): dat = 123 s.task_duration[str(int(t))] = 1 i = next(counter) - f = c.submit(func, dat, key='%d-%d' % (int(t), i), - workers=w.address, allow_other_workers=True, - pure=False, priority=-i) + f = c.submit( + func, + dat, + key="%d-%d" % (int(t), i), + workers=w.address, + allow_other_workers=True, + pure=False, + priority=-i, + ) futures.append(f) while len(s.rprocessing) < len(futures): @@ -401,90 +421,71 @@ def assert_balanced(inp, expected, c, s, *workers): while steal.in_flight: yield gen.sleep(0.001) - result = [sorted([int(key_split(k)) for k in s.processing[w.address]], - reverse=True) - for w in workers] + result = [ + sorted([int(key_split(k)) for k in s.processing[w.address]], reverse=True) + for w in workers + ] result2 = sorted(result, reverse=True) expected2 = sorted(expected, reverse=True) - if config.get('pdb-on-err'): + if config.get("pdb-on-err"): if result2 != expected2: import pdb + pdb.set_trace() if result2 == expected2: return - raise Exception('Expected: {}; got: {}'.format(str(expected2), str(result2))) - - -@pytest.mark.parametrize('inp,expected', [ - ([[1], []], # don't move unnecessarily - [[1], []]), - - ([[0, 0], []], # balance - [[0], [0]]), - - ([[0.1, 0.1], []], # balance even if results in even - [[0], [0]]), - - ([[0, 0, 0], []], # don't over balance - [[0, 0], [0]]), - - ([[0, 0], [0, 0, 0], []], # move from larger - [[0, 0], [0, 0], [0]]), - - ([[0, 0, 0], [0], []], # move to smaller - [[0, 0], [0], [0]]), - - ([[0, 1], []], # choose easier first - [[1], [0]]), - - ([[0, 0, 0, 0], [], []], # spread evenly - [[0, 0], [0], [0]]), - - ([[1, 0, 2, 0], [], []], # move easier - [[2, 1], [0], [0]]), - - ([[1, 1, 1], []], # be willing to move costly items - [[1, 1], [1]]), - - ([[1, 1, 1, 1], []], # but don't move too many - [[1, 1, 1], [1]]), - - ([[0, 0], [0, 0], [0, 0], []], # no one clearly saturated - [[0, 0], [0, 0], [0], [0]]), - - ([[4, 2, 2, 2, 2, 1, 1], - [4, 2, 1, 1], - [], - [], - []], - [[4, 2, 2, 2, 2], - [4, 2, 1], - [1], - [1], - [1]]), - - pytest.param([[1, 1, 1, 1, 1, 1, 1], [1, 1], [1, 1], [1, 1], []], - [[1, 1, 1, 1, 1], [1, 1], [1, 1], [1, 1], [1, 1]], - marks=pytest.mark.xfail(reason="Some uncertainty based on executing stolen task")) -]) + raise Exception("Expected: {}; got: {}".format(str(expected2), str(result2))) + + +@pytest.mark.parametrize( + "inp,expected", + [ + ([[1], 
[]], [[1], []]), # don't move unnecessarily + ([[0, 0], []], [[0], [0]]), # balance + ([[0.1, 0.1], []], [[0], [0]]), # balance even if results in even + ([[0, 0, 0], []], [[0, 0], [0]]), # don't over balance + ([[0, 0], [0, 0, 0], []], [[0, 0], [0, 0], [0]]), # move from larger + ([[0, 0, 0], [0], []], [[0, 0], [0], [0]]), # move to smaller + ([[0, 1], []], [[1], [0]]), # choose easier first + ([[0, 0, 0, 0], [], []], [[0, 0], [0], [0]]), # spread evenly + ([[1, 0, 2, 0], [], []], [[2, 1], [0], [0]]), # move easier + ([[1, 1, 1], []], [[1, 1], [1]]), # be willing to move costly items + ([[1, 1, 1, 1], []], [[1, 1, 1], [1]]), # but don't move too many + ( + [[0, 0], [0, 0], [0, 0], []], # no one clearly saturated + [[0, 0], [0, 0], [0], [0]], + ), + ( + [[4, 2, 2, 2, 2, 1, 1], [4, 2, 1, 1], [], [], []], + [[4, 2, 2, 2, 2], [4, 2, 1], [1], [1], [1]], + ), + pytest.param( + [[1, 1, 1, 1, 1, 1, 1], [1, 1], [1, 1], [1, 1], []], + [[1, 1, 1, 1, 1], [1, 1], [1, 1], [1, 1], [1, 1]], + marks=pytest.mark.xfail( + reason="Some uncertainty based on executing stolen task" + ), + ), + ], +) def test_balance(inp, expected): test = lambda *args, **kwargs: assert_balanced(inp, expected, *args, **kwargs) - test = gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * len(inp))(test) + test = gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * len(inp))(test) test() -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2, Worker=Nanny, - timeout=20) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=20) def test_restart(c, s, a, b): - futures = c.map(slowinc, range(100), delay=0.1, workers=a.address, - allow_other_workers=True) + futures = c.map( + slowinc, range(100), delay=0.1, workers=a.address, allow_other_workers=True + ) while not s.processing[b.worker_address]: yield gen.sleep(0.01) - steal = s.extensions['stealing'] + steal = s.extensions["stealing"] assert any(st for st in steal.stealable_all) assert any(x for L in steal.stealable.values() for x in L) @@ -496,14 +497,23 @@ def test_restart(c, s, a, b): @gen_cluster(client=True) def test_steal_communication_heavy_tasks(c, s, a, b): - steal = s.extensions['stealing'] - s.task_duration['slowadd'] = 0.001 - x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address) - y = c.submit(mul, b'1', int(BANDWIDTH), workers=b.address) - - futures = [c.submit(slowadd, x, y, delay=1, pure=False, workers=a.address, - allow_other_workers=True) - for i in range(10)] + steal = s.extensions["stealing"] + s.task_duration["slowadd"] = 0.001 + x = c.submit(mul, b"0", int(BANDWIDTH), workers=a.address) + y = c.submit(mul, b"1", int(BANDWIDTH), workers=b.address) + + futures = [ + c.submit( + slowadd, + x, + y, + delay=1, + pure=False, + workers=a.address, + allow_other_workers=True, + ) + for i in range(10) + ] while not any(f.key in s.rprocessing for f in futures): yield gen.sleep(0.01) @@ -533,8 +543,10 @@ def test_steal_twice(c, s, a, b): has_what = dict(s.has_what) # take snapshot empty_workers = [w for w, keys in has_what.items() if not len(keys)] if len(empty_workers) > 2: - pytest.fail("Too many workers without keys (%d out of %d)" - % (len(empty_workers), len(has_what))) + pytest.fail( + "Too many workers without keys (%d out of %d)" + % (len(empty_workers), len(has_what)) + ) assert max(map(len, has_what.values())) < 30 yield c._close() @@ -543,20 +555,21 @@ def test_steal_twice(c, s, a, b): @gen_cluster(client=True) def test_dont_steal_executing_tasks(c, s, a, b): - steal = s.extensions['stealing'] + steal = s.extensions["stealing"] 
future = c.submit(slowinc, 1, delay=0.5, workers=a.address) while not a.executing: yield gen.sleep(0.01) - steal.move_task_request(s.tasks[future.key], - s.workers[a.address], s.workers[b.address]) + steal.move_task_request( + s.tasks[future.key], s.workers[a.address], s.workers[b.address] + ) yield gen.sleep(0.1) assert future.key in a.executing assert not b.executing -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_dont_steal_long_running_tasks(c, s, a, b): def long(delay): with worker_client() as c: @@ -565,13 +578,12 @@ def long(delay): yield c.submit(long, 0.1) # learn duration yield c.submit(inc, 1) # learn duration - long_tasks = c.map(long, [0.5, 0.6], workers=a.address, - allow_other_workers=True) + long_tasks = c.map(long, [0.5, 0.6], workers=a.address, allow_other_workers=True) while sum(map(len, s.processing.values())) < 2: # let them start yield gen.sleep(0.01) start = time() - while any(t.key in s.extensions['stealing'].key_stealable for t in long_tasks): + while any(t.key in s.extensions["stealing"].key_stealable for t in long_tasks): yield gen.sleep(0.01) assert time() < start + 1 @@ -585,21 +597,24 @@ def long(delay): yield wait(long_tasks) for t in long_tasks: - assert (sum(log[1] == 'executing' for log in a.story(t)) + - sum(log[1] == 'executing' for log in b.story(t))) <= 1 + assert ( + sum(log[1] == "executing" for log in a.story(t)) + + sum(log[1] == "executing" for log in b.story(t)) + ) <= 1 -@gen_cluster(client=True, ncores=[('127.0.0.1', 5)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 5)] * 2) def test_cleanup_repeated_tasks(c, s, a, b): class Foo(object): pass - s.extensions['stealing']._pc.callback_time = 20 + s.extensions["stealing"]._pc.callback_time = 20 yield c.submit(slowidentity, -1, delay=0.1) objects = [c.submit(Foo, pure=False, workers=a.address) for _ in range(50)] - x = c.map(slowidentity, objects, workers=a.address, allow_other_workers=True, - delay=0.05) + x = c.map( + slowidentity, objects, workers=a.address, allow_other_workers=True, delay=0.05 + ) del objects yield wait(x) assert a.data and b.data @@ -620,15 +635,21 @@ class Foo(object): assert not list(ws) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_lose_task(c, s, a, b): - with captured_logger('distributed.stealing') as log: - s.periodic_callbacks['stealing'].interval = 1 + with captured_logger("distributed.stealing") as log: + s.periodic_callbacks["stealing"].interval = 1 for i in range(100): - futures = c.map(slowinc, range(10), delay=0.01, pure=False, - workers=a.address, allow_other_workers=True) + futures = c.map( + slowinc, + range(10), + delay=0.01, + pure=False, + workers=a.address, + allow_other_workers=True, + ) yield gen.sleep(0.01) del futures out = log.getvalue() - assert 'Error' not in out + assert "Error" not in out diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 50edba30d81..8a36b8b3b94 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -15,9 +15,21 @@ from distributed.config import config from distributed.metrics import time from distributed.utils import All -from distributed.utils_test import (gen_cluster, cluster, inc, slowinc, - slowadd, slow, slowsum, bump_rlimit) -from distributed.utils_test import (loop, nodebug_setup_module, nodebug_teardown_module) # noqa: F401 +from distributed.utils_test import ( + gen_cluster, + cluster, + inc, 
+ slowinc, + slowadd, + slow, + slowsum, + bump_rlimit, +) +from distributed.utils_test import ( # noqa: F401 + loop, + nodebug_setup_module, + nodebug_teardown_module, +) from distributed.client import wait from tornado import gen @@ -29,21 +41,20 @@ @gen_cluster(client=True) def test_stress_1(c, s, a, b): - n = 2**6 + n = 2 ** 6 seq = c.map(inc, range(n)) while len(seq) > 1: yield gen.sleep(0.1) - seq = [c.submit(add, seq[i], seq[i + 1]) - for i in range(0, len(seq), 2)] + seq = [c.submit(add, seq[i], seq[i + 1]) for i in range(0, len(seq), 2)] result = yield seq[0] assert result == sum(map(inc, range(n))) -@pytest.mark.parametrize(('func', 'n'), [(slowinc, 100), (inc, 1000)]) +@pytest.mark.parametrize(("func", "n"), [(slowinc, 100), (inc, 1000)]) def test_stress_gc(loop, func, n): with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: x = c.submit(func, 1) for i in range(n): x = c.submit(func, x) @@ -51,11 +62,12 @@ def test_stress_gc(loop, func, n): assert x.result() == n + 2 -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="test can leave dangling RPC objects") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 8, timeout=None) +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="test can leave dangling RPC objects" +) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 8, timeout=None) def test_cancel_stress(c, s, *workers): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random((50, 50), chunks=(2, 2)) x = c.persist(x) yield wait([x]) @@ -69,10 +81,10 @@ def test_cancel_stress(c, s, *workers): def test_cancel_stress_sync(loop): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random((50, 50), chunks=(2, 2)) with cluster(active_rpc_timeout=10) as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with Client(s["address"], loop=loop) as c: x = c.persist(x) y = (x.sum(axis=0) + x.sum(axis=1) + 1).std() wait(x) @@ -86,7 +98,7 @@ def test_cancel_stress_sync(loop): def test_stress_creation_and_deletion(c, s): # Assertions are handled by the validate mechanism in the scheduler s.allowed_failures = 100000 - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random(size=(2000, 2000), chunks=(100, 100)) y = (x + 1).T + (x * 2) - x.mean(axis=1) @@ -105,28 +117,35 @@ def create_and_destroy_worker(delay): yield n._close() print("Killed nanny") - yield gen.with_timeout(timedelta(minutes=1), - All([create_and_destroy_worker(0.1 * i) for i in - range(20)])) + yield gen.with_timeout( + timedelta(minutes=1), + All([create_and_destroy_worker(0.1 * i) for i in range(20)]), + ) -@gen_cluster(ncores=[('127.0.0.1', 1)] * 10, client=True, timeout=60) +@gen_cluster(ncores=[("127.0.0.1", 1)] * 10, client=True, timeout=60) def test_stress_scatter_death(c, s, *workers): import random + s.allowed_failures = 1000 - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") L = yield c.scatter([np.random.random(10000) for i in range(len(workers))]) yield c._replicate(L, n=2) - adds = [delayed(slowadd, pure=True)(random.choice(L), - random.choice(L), - delay=0.05, - dask_key_name='slowadd-1-%d' % i) - for i in range(50)] - - adds = [delayed(slowadd, pure=True)(a, b, delay=0.02, - dask_key_name='slowadd-2-%d' % i) - for i, (a, b) in enumerate(sliding_window(2, adds))] + adds = [ + delayed(slowadd, pure=True)( + random.choice(L), + random.choice(L), + delay=0.05, + 
dask_key_name="slowadd-1-%d" % i, + ) + for i in range(50) + ] + + adds = [ + delayed(slowadd, pure=True)(a, b, delay=0.02, dask_key_name="slowadd-2-%d" % i) + for i, (a, b) in enumerate(sliding_window(2, adds)) + ] futures = c.compute(adds) L = adds = None @@ -141,8 +160,9 @@ def test_stress_scatter_death(c, s, *workers): s.validate_state() except Exception as c: logger.exception(c) - if config.get('log-on-err'): + if config.get("log-on-err"): import pdb + pdb.set_trace() else: raise @@ -153,7 +173,7 @@ def test_stress_scatter_death(c, s, *workers): try: yield gen.with_timeout(timedelta(seconds=25), c._gather(futures)) except gen.TimeoutError: - ws = {w.address: w for w in workers if w.status != 'closed'} + ws = {w.address: w for w in workers if w.status != "closed"} print(s.processing) print(ws) print(futures) @@ -161,8 +181,9 @@ def test_stress_scatter_death(c, s, *workers): worker = [w for w in ws.values() if w.waiting_for_data][0] except Exception: pass - if config.get('log-on-err'): + if config.get("log-on-err"): import pdb + pdb.set_trace() else: raise @@ -178,18 +199,18 @@ def vsum(*args): @pytest.mark.avoid_travis @slow -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 80, timeout=1000) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 80, timeout=1000) def test_stress_communication(c, s, *workers): s.validate = False # very slow otherwise - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") # Test consumes many file descriptors and can hang if the limit is too low - resource = pytest.importorskip('resource') + resource = pytest.importorskip("resource") bump_rlimit(resource.RLIMIT_NOFILE, 8192) n = 20 xs = [da.random.random((100, 100), chunks=(5, 5)) for i in range(n)] ys = [x + x.T for x in xs] - z = da.atop(vsum, 'ij', *concat(zip(ys, ['ij'] * n)), dtype='float64') + z = da.atop(vsum, "ij", *concat(zip(ys, ["ij"] * n)), dtype="float64") future = c.compute(z.sum()) @@ -198,7 +219,7 @@ def test_stress_communication(c, s, *workers): @pytest.mark.skip -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10, timeout=60) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10, timeout=60) def test_stress_steal(c, s, *workers): s.validate = False for w in workers: @@ -207,13 +228,12 @@ def test_stress_steal(c, s, *workers): dinc = delayed(slowinc) L = [delayed(slowinc)(i, delay=0.005) for i in range(100)] for i in range(5): - L = [delayed(slowsum)(part, delay=0.005) - for part in sliding_window(5, L)] + L = [delayed(slowsum)(part, delay=0.005) for part in sliding_window(5, L)] total = delayed(sum)(L) future = c.compute(total) - while future.status != 'finished': + while future.status != "finished": yield gen.sleep(0.1) for i in range(3): a = random.choice(workers) @@ -225,9 +245,9 @@ def test_stress_steal(c, s, *workers): @slow -@gen_cluster(ncores=[('127.0.0.1', 1)] * 10, client=True, timeout=120) +@gen_cluster(ncores=[("127.0.0.1", 1)] * 10, client=True, timeout=120) def test_close_connections(c, s, *workers): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") x = da.random.random(size=(1000, 1000), chunks=(1000, 1)) for i in range(3): x = x.rechunk((1, 1000)) @@ -246,12 +266,14 @@ def test_close_connections(c, s, *workers): yield wait(future) -@pytest.mark.xfail(reason="IOStream._handle_write blocks on large write_buffer" - " https://github.com/tornadoweb/tornado/issues/2110") -@gen_cluster(client=True, timeout=20, ncores=[('127.0.0.1', 1)]) +@pytest.mark.xfail( + reason="IOStream._handle_write blocks on large 
write_buffer" + " https://github.com/tornadoweb/tornado/issues/2110" +) +@gen_cluster(client=True, timeout=20, ncores=[("127.0.0.1", 1)]) def test_no_delay_during_large_transfer(c, s, w): - pytest.importorskip('crick') - np = pytest.importorskip('numpy') + pytest.importorskip("crick") + np = pytest.importorskip("numpy") x = np.random.random(100000000) x_nbytes = x.nbytes @@ -273,7 +295,7 @@ def test_no_delay_during_large_transfer(c, s, w): x = None # lose ref for server in [s, w]: - assert server.digests['tick-duration'].components[0].max() < 0.5 + assert server.digests["tick-duration"].components[0].max() < 0.5 nbytes = np.array([t.mem for t in rprof.results]) nbytes -= nbytes[0] diff --git a/distributed/tests/test_submit_cli.py b/distributed/tests/test_submit_cli.py index 7d84ce8cc12..04267a28e2b 100644 --- a/distributed/tests/test_submit_cli.py +++ b/distributed/tests/test_submit_cli.py @@ -4,34 +4,40 @@ from tornado import gen from tornado.ioloop import IOLoop from distributed.submit import RemoteClient, _submit, _remote -from distributed.utils_test import (valid_python_script, invalid_python_script, loop) # noqa: F401 +from distributed.utils_test import ( # noqa: F401 + valid_python_script, + invalid_python_script, + loop, +) -def test_dask_submit_cli_writes_result_to_stdout(loop, tmpdir, - valid_python_script): +def test_dask_submit_cli_writes_result_to_stdout(loop, tmpdir, valid_python_script): @gen.coroutine def test(): - remote_client = RemoteClient(ip='127.0.0.1', local_dir=str(tmpdir)) + remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) yield remote_client._start() - out, err = yield _submit('127.0.0.1:{0}'.format(remote_client.port), - str(valid_python_script)) - assert b'hello world!' in out + out, err = yield _submit( + "127.0.0.1:{0}".format(remote_client.port), str(valid_python_script) + ) + assert b"hello world!" 
in out yield remote_client._close() loop.run_sync(test, timeout=5) -def test_dask_submit_cli_writes_traceback_to_stdout(loop, tmpdir, - invalid_python_script): +def test_dask_submit_cli_writes_traceback_to_stdout( + loop, tmpdir, invalid_python_script +): @gen.coroutine def test(): - remote_client = RemoteClient(ip='127.0.0.1', local_dir=str(tmpdir)) + remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) yield remote_client._start() - out, err = yield _submit('127.0.0.1:{0}'.format(remote_client.port), - str(invalid_python_script)) - assert b'Traceback' in err + out, err = yield _submit( + "127.0.0.1:{0}".format(remote_client.port), str(invalid_python_script) + ) + assert b"Traceback" in err yield remote_client._close() loop.run_sync(test, timeout=5) @@ -41,9 +47,9 @@ def test_cli_runs_remote_client(): mock_remote_client = Mock(spec=RemoteClient) mock_ioloop = Mock(spec=IOLoop.current()) - _remote('127.0.0.1:8799', 8788, loop=mock_ioloop, client=mock_remote_client) + _remote("127.0.0.1:8799", 8788, loop=mock_ioloop, client=mock_remote_client) - mock_remote_client.assert_called_once_with(ip='127.0.0.1', loop=mock_ioloop) + mock_remote_client.assert_called_once_with(ip="127.0.0.1", loop=mock_ioloop) mock_remote_client().start.assert_called_once_with(port=8799) assert mock_ioloop.start.called diff --git a/distributed/tests/test_submit_remote_client.py b/distributed/tests/test_submit_remote_client.py index d74c3952497..e6527d8319b 100644 --- a/distributed/tests/test_submit_remote_client.py +++ b/distributed/tests/test_submit_remote_client.py @@ -4,19 +4,25 @@ from distributed import rpc from distributed.submit import RemoteClient -from distributed.utils_test import (loop, valid_python_script, invalid_python_script) # noqa: F401 +from distributed.utils_test import ( # noqa: F401 + loop, + valid_python_script, + invalid_python_script, +) def test_remote_client_uploads_a_file(loop, tmpdir): @gen.coroutine def test(): - remote_client = RemoteClient(ip='127.0.0.1', local_dir=str(tmpdir)) + remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) yield remote_client._start(0) remote_process = rpc(remote_client.address) - upload = yield remote_process.upload_file(filename='script.py', file_payload='x=1') + upload = yield remote_process.upload_file( + filename="script.py", file_payload="x=1" + ) - assert upload == {'status': 'OK', 'nbytes': 3} - assert tmpdir.join('script.py').read() == "x=1" + assert upload == {"status": "OK", "nbytes": 3} + assert tmpdir.join("script.py").read() == "x=1" yield remote_client._close() @@ -26,14 +32,14 @@ def test(): def test_remote_client_execution_outputs_to_stdout(loop, tmpdir): @gen.coroutine def test(): - remote_client = RemoteClient(ip='127.0.0.1', local_dir=str(tmpdir)) + remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) yield remote_client._start(0) rr = rpc(remote_client.address) - yield rr.upload_file(filename='script.py', file_payload='print("hello world!")') + yield rr.upload_file(filename="script.py", file_payload='print("hello world!")') - message = yield rr.execute(filename='script.py') - assert message['stdout'] == b'hello world!' + os.linesep.encode() - assert message['returncode'] == 0 + message = yield rr.execute(filename="script.py") + assert message["stdout"] == b"hello world!" 
+ os.linesep.encode() + assert message["returncode"] == 0 yield remote_client._close() @@ -43,14 +49,14 @@ def test(): def test_remote_client_execution_outputs_stderr(loop, tmpdir, invalid_python_script): @gen.coroutine def test(): - remote_client = RemoteClient(ip='127.0.0.1', local_dir=str(tmpdir)) + remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) yield remote_client._start(0) rr = rpc(remote_client.address) - yield rr.upload_file(filename='script.py', file_payload='a+1') + yield rr.upload_file(filename="script.py", file_payload="a+1") - message = yield rr.execute(filename='script.py') - assert b'\'a\' is not defined' in message['stderr'] - assert message['returncode'] == 1 + message = yield rr.execute(filename="script.py") + assert b"'a' is not defined" in message["stderr"] + assert message["returncode"] == 1 yield remote_client._close() diff --git a/distributed/tests/test_system_monitor.py b/distributed/tests/test_system_monitor.py index 9c3e284dd36..f42fb8e3e08 100644 --- a/distributed/tests/test_system_monitor.py +++ b/distributed/tests/test_system_monitor.py @@ -18,7 +18,7 @@ def test_SystemMonitor(): assert all(wb >= 0 for wb in sm.write_bytes) assert all(len(q) == 3 for q in sm.quantities.values()) - assert 'cpu' in repr(sm) + assert "cpu" in repr(sm) def test_count(): diff --git a/distributed/tests/test_threadpoolexecutor.py b/distributed/tests/test_threadpoolexecutor.py index 8777e574282..8b807512168 100644 --- a/distributed/tests/test_threadpoolexecutor.py +++ b/distributed/tests/test_threadpoolexecutor.py @@ -111,6 +111,7 @@ def f(): def test_rejoin_idempotent(): with ThreadPoolExecutor(2) as e: + def f(): secede() for i in range(5): diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 14f349545ba..6b71941257c 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -10,16 +10,15 @@ from distributed import Nanny, worker_client, Queue from distributed.client import wait -from distributed.utils_test import (gen_tls_cluster, inc, double, slowinc, - slowadd) +from distributed.utils_test import gen_tls_cluster, inc, double, slowinc, slowadd @gen_tls_cluster(client=True) def test_Queue(c, s, a, b): - assert s.address.startswith('tls://') + assert s.address.startswith("tls://") - x = Queue('x') - y = Queue('y') + x = Queue("x") + y = Queue("y") size = yield x.qsize() assert size == 0 @@ -34,7 +33,7 @@ def test_Queue(c, s, a, b): @gen_tls_cluster(client=True, timeout=None) def test_client_submit(c, s, a, b): - assert s.address.startswith('tls://') + assert s.address.startswith("tls://") x = c.submit(inc, 10) result = yield x @@ -49,7 +48,7 @@ def test_client_submit(c, s, a, b): @gen_tls_cluster(client=True) def test_gather(c, s, a, b): - assert s.address.startswith('tls://') + assert s.address.startswith("tls://") x = c.submit(inc, 10) y = c.submit(inc, x) @@ -58,29 +57,29 @@ def test_gather(c, s, a, b): assert result == 11 result = yield c._gather([x]) assert result == [11] - result = yield c._gather({'x': x, 'y': [y]}) - assert result == {'x': 11, 'y': [12]} + result = yield c._gather({"x": x, "y": [y]}) + assert result == {"x": 11, "y": [12]} @gen_tls_cluster(client=True) def test_scatter(c, s, a, b): - assert s.address.startswith('tls://') + assert s.address.startswith("tls://") - d = yield c._scatter({'y': 20}) - ts = s.tasks['y'] + d = yield c._scatter({"y": 20}) + ts = s.tasks["y"] assert ts.who_has assert ts.nbytes > 0 - yy = yield c._gather([d['y']]) + yy = 
yield c._gather([d["y"]]) assert yy == [20] @gen_tls_cluster(client=True, Worker=Nanny) def test_nanny(c, s, a, b): - assert s.address.startswith('tls://') + assert s.address.startswith("tls://") for n in [a, b]: assert isinstance(n, Nanny) - assert n.address.startswith('tls://') - assert n.worker_address.startswith('tls://') + assert n.address.startswith("tls://") + assert n.worker_address.startswith("tls://") assert s.ncores == {n.worker_address: n.ncores for n in [a, b]} x = c.submit(inc, 10) @@ -100,7 +99,7 @@ def test_rebalance(c, s, a, b): assert len(b.data) == 1 -@gen_tls_cluster(client=True, ncores=[('tls://127.0.0.1', 2)] * 2) +@gen_tls_cluster(client=True, ncores=[("tls://127.0.0.1", 2)] * 2) def test_work_stealing(c, s, a, b): [x] = yield c._scatter([1], workers=a.address) futures = c.map(slowadd, range(50), [x] * 50, delay=0.1) @@ -126,12 +125,12 @@ def func(x): assert yy == 20 + 1 + (20 + 1) * 2 -@gen_tls_cluster(client=True, ncores=[('tls://127.0.0.1', 1)] * 2) +@gen_tls_cluster(client=True, ncores=[("tls://127.0.0.1", 1)] * 2) def test_worker_client_gather(c, s, a, b): a_address = a.address b_address = b.address - assert a_address.startswith('tls://') - assert b_address.startswith('tls://') + assert a_address.startswith("tls://") + assert b_address.startswith("tls://") assert a_address != b_address def func(): diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 45de450fc44..f4423d26e4a 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -16,15 +16,33 @@ import dask from distributed.compatibility import Queue, Empty, isqueue, PY2, Iterator from distributed.metrics import time -from distributed.utils import (All, sync, is_kernel, ensure_ip, str_graph, - truncate_exception, get_traceback, queue_to_iterator, - iterator_to_queue, _maybe_complex, read_block, seek_delimiter, - funcname, ensure_bytes, open_port, get_ip_interface, nbytes, - set_thread_state, thread_state, LoopRunner, - parse_bytes, parse_timedelta, warn_on_duration) +from distributed.utils import ( + All, + sync, + is_kernel, + ensure_ip, + str_graph, + truncate_exception, + get_traceback, + queue_to_iterator, + iterator_to_queue, + _maybe_complex, + read_block, + seek_delimiter, + funcname, + ensure_bytes, + open_port, + get_ip_interface, + nbytes, + set_thread_state, + thread_state, + LoopRunner, + parse_bytes, + parse_timedelta, + warn_on_duration, +) from distributed.utils_test import loop, loop_in_thread # noqa: F401 -from distributed.utils_test import (div, has_ipv6, inc, throws, gen_test, - captured_logger) +from distributed.utils_test import div, has_ipv6, inc, throws, gen_test, captured_logger def test_All(loop): @@ -65,10 +83,10 @@ def test_sync_error(loop_in_thread): result = sync(loop, throws, 1) except Exception as exc: f = exc - assert 'hello' in str(exc) + assert "hello" in str(exc) tb = get_traceback() L = traceback.format_tb(tb) - assert any('throws' in line for line in L) + assert any("throws" in line for line in L) def function1(x): return function2(x) @@ -79,11 +97,11 @@ def function2(x): try: result = sync(loop, function1, 1) except Exception as exc: - assert 'hello' in str(exc) + assert "hello" in str(exc) tb = get_traceback() L = traceback.format_tb(tb) - assert any('function1' in line for line in L) - assert any('function2' in line for line in L) + assert any("function1" in line for line in L) + assert any("function2" in line for line in L) def test_sync_timeout(loop_in_thread): @@ -104,47 +122,47 @@ def 
test_sync_closed_loop(): def test_is_kernel(): - pytest.importorskip('IPython') + pytest.importorskip("IPython") assert is_kernel() is False -#@pytest.mark.leaking('fds') -#def test_zzz_leaks(l=[]): - #import os, subprocess - #l.append(b"x" * (17 * 1024**2)) - #os.open(__file__, os.O_RDONLY) - #subprocess.Popen('sleep 100', shell=True, stdin=subprocess.DEVNULL) +# @pytest.mark.leaking('fds') +# def test_zzz_leaks(l=[]): +# import os, subprocess +# l.append(b"x" * (17 * 1024**2)) +# os.open(__file__, os.O_RDONLY) +# subprocess.Popen('sleep 100', shell=True, stdin=subprocess.DEVNULL) def test_ensure_ip(): - assert ensure_ip('localhost') in ('127.0.0.1', '::1') - assert ensure_ip('123.123.123.123') == '123.123.123.123' - assert ensure_ip('8.8.8.8') == '8.8.8.8' + assert ensure_ip("localhost") in ("127.0.0.1", "::1") + assert ensure_ip("123.123.123.123") == "123.123.123.123" + assert ensure_ip("8.8.8.8") == "8.8.8.8" if has_ipv6(): - assert ensure_ip('2001:4860:4860::8888') == '2001:4860:4860::8888' - assert ensure_ip('::1') == '::1' + assert ensure_ip("2001:4860:4860::8888") == "2001:4860:4860::8888" + assert ensure_ip("::1") == "::1" def test_get_ip_interface(): - if sys.platform == 'darwin': - assert get_ip_interface('lo0') == '127.0.0.1' - elif sys.platform.startswith('linux'): - assert get_ip_interface('lo') == '127.0.0.1' + if sys.platform == "darwin": + assert get_ip_interface("lo0") == "127.0.0.1" + elif sys.platform.startswith("linux"): + assert get_ip_interface("lo") == "127.0.0.1" else: pytest.skip("test needs to be enhanced for platform %r" % (sys.platform,)) with pytest.raises(KeyError): - get_ip_interface('__non-existent-interface') + get_ip_interface("__non-existent-interface") def test_truncate_exception(): - e = ValueError('a' * 1000) + e = ValueError("a" * 1000) assert len(str(e)) >= 1000 f = truncate_exception(e, 100) assert type(f) == type(e) assert len(str(f)) < 200 - assert 'aaaa' in str(f) + assert "aaaa" in str(f) - e = ValueError('a') + e = ValueError("a") assert truncate_exception(e) is e @@ -162,7 +180,7 @@ def c(x): c(1) except Exception as e: tb = get_traceback() - assert type(tb).__name__ == 'traceback' + assert type(tb).__name__ == "traceback" def test_queue_to_iterator(): @@ -185,20 +203,23 @@ def test_iterator_to_queue(): def test_str_graph(): - dsk = {'x': 1} + dsk = {"x": 1} assert str_graph(dsk) == dsk - dsk = {('x', 1): (inc, 1)} - assert str_graph(dsk) == {str(('x', 1)): (inc, 1)} + dsk = {("x", 1): (inc, 1)} + assert str_graph(dsk) == {str(("x", 1)): (inc, 1)} - dsk = {('x', 1): (inc, 1), ('x', 2): (inc, ('x', 1))} - assert str_graph(dsk) == {str(('x', 1)): (inc, 1), - str(('x', 2)): (inc, str(('x', 1)))} + dsk = {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))} + assert str_graph(dsk) == { + str(("x", 1)): (inc, 1), + str(("x", 2)): (inc, str(("x", 1))), + } - dsks = [{'x': 1}, - {('x', 1): (inc, 1), ('x', 2): (inc, ('x', 1))}, - {('x', 1): (sum, [1, 2, 3]), - ('x', 2): (sum, [('x', 1), ('x', 1)])}] + dsks = [ + {"x": 1}, + {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}, + {("x", 1): (sum, [1, 2, 3]), ("x", 2): (sum, [("x", 1), ("x", 1)])}, + ] for dsk in dsks: sdsk = str_graph(dsk) keys = list(dsk) @@ -209,59 +230,58 @@ def test_str_graph(): def test_maybe_complex(): assert not _maybe_complex(1) - assert not _maybe_complex('x') + assert not _maybe_complex("x") assert _maybe_complex((inc, 1)) assert _maybe_complex([(inc, 1)]) assert _maybe_complex([(inc, 1)]) - assert _maybe_complex({'x': (inc, 1)}) + assert _maybe_complex({"x": (inc, 1)}) def 
test_read_block(): - delimiter = b'\n' - data = delimiter.join([b'123', b'456', b'789']) + delimiter = b"\n" + data = delimiter.join([b"123", b"456", b"789"]) f = io.BytesIO(data) - assert read_block(f, 1, 2) == b'23' - assert read_block(f, 0, 1, delimiter=b'\n') == b'123\n' - assert read_block(f, 0, 2, delimiter=b'\n') == b'123\n' - assert read_block(f, 0, 3, delimiter=b'\n') == b'123\n' - assert read_block(f, 0, 5, delimiter=b'\n') == b'123\n456\n' - assert read_block(f, 0, 8, delimiter=b'\n') == b'123\n456\n789' - assert read_block(f, 0, 100, delimiter=b'\n') == b'123\n456\n789' - assert read_block(f, 1, 1, delimiter=b'\n') == b'' - assert read_block(f, 1, 5, delimiter=b'\n') == b'456\n' - assert read_block(f, 1, 8, delimiter=b'\n') == b'456\n789' - - for ols in [[(0, 3), (3, 3), (6, 3), (9, 2)], - [(0, 4), (4, 4), (8, 4)]]: - out = [read_block(f, o, l, b'\n') for o, l in ols] + assert read_block(f, 1, 2) == b"23" + assert read_block(f, 0, 1, delimiter=b"\n") == b"123\n" + assert read_block(f, 0, 2, delimiter=b"\n") == b"123\n" + assert read_block(f, 0, 3, delimiter=b"\n") == b"123\n" + assert read_block(f, 0, 5, delimiter=b"\n") == b"123\n456\n" + assert read_block(f, 0, 8, delimiter=b"\n") == b"123\n456\n789" + assert read_block(f, 0, 100, delimiter=b"\n") == b"123\n456\n789" + assert read_block(f, 1, 1, delimiter=b"\n") == b"" + assert read_block(f, 1, 5, delimiter=b"\n") == b"456\n" + assert read_block(f, 1, 8, delimiter=b"\n") == b"456\n789" + + for ols in [[(0, 3), (3, 3), (6, 3), (9, 2)], [(0, 4), (4, 4), (8, 4)]]: + out = [read_block(f, o, l, b"\n") for o, l in ols] assert b"".join(filter(None, out)) == data def test_seek_delimiter_endline(): - f = io.BytesIO(b'123\n456\n789') + f = io.BytesIO(b"123\n456\n789") # if at zero, stay at zero - seek_delimiter(f, b'\n', 5) + seek_delimiter(f, b"\n", 5) assert f.tell() == 0 # choose the first block for bs in [1, 5, 100]: f.seek(1) - seek_delimiter(f, b'\n', blocksize=bs) + seek_delimiter(f, b"\n", blocksize=bs) assert f.tell() == 4 # handle long delimiters well, even with short blocksizes - f = io.BytesIO(b'123abc456abc789') + f = io.BytesIO(b"123abc456abc789") for bs in [1, 2, 3, 4, 5, 6, 10]: f.seek(1) - seek_delimiter(f, b'abc', blocksize=bs) + seek_delimiter(f, b"abc", blocksize=bs) assert f.tell() == 6 # End at the end - f = io.BytesIO(b'123\n456') + f = io.BytesIO(b"123\n456") f.seek(5) - seek_delimiter(f, b'\n', 5) + seek_delimiter(f, b"\n", 5) assert f.tell() == 7 @@ -269,19 +289,19 @@ def test_funcname(): def f(): pass - assert funcname(f) == 'f' - assert funcname(partial(f)) == 'f' - assert funcname(partial(partial(f))) == 'f' + assert funcname(f) == "f" + assert funcname(partial(f)) == "f" + assert funcname(partial(partial(f))) == "f" def test_ensure_bytes(): - data = [b'1', '1', memoryview(b'1'), bytearray(b'1')] + data = [b"1", "1", memoryview(b"1"), bytearray(b"1")] if PY2: - data.append(buffer(b'1')) # noqa: F821 + data.append(buffer(b"1")) # noqa: F821 for d in data: result = ensure_bytes(d) assert isinstance(result, bytes) - assert result == b'1' + assert result == b"1" def test_nbytes(): @@ -289,8 +309,8 @@ def check(obj, expected): assert nbytes(obj) == expected assert nbytes(memoryview(obj)) == expected - check(b'123', 3) - check(bytearray(b'4567'), 4) + check(b"123", 3) + check(bytearray(b"4567"), 4) multi_dim = np.ones(shape=(10, 10)) scalar = np.array(1) @@ -302,7 +322,7 @@ def check(obj, expected): def test_open_port(): port = open_port() s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('', port)) 
+ s.bind(("", port)) s.close() @@ -310,7 +330,7 @@ def test_set_thread_state(): with set_thread_state(x=1): assert thread_state.x == 1 - assert not hasattr(thread_state, 'x') + assert not hasattr(thread_state, "x") def assert_running(loop): @@ -474,70 +494,73 @@ def test_loop_runner_gen(): def test_parse_bytes(): - assert parse_bytes('100') == 100 - assert parse_bytes('100 MB') == 100000000 - assert parse_bytes('100M') == 100000000 - assert parse_bytes('5kB') == 5000 - assert parse_bytes('5.4 kB') == 5400 - assert parse_bytes('1kiB') == 1024 - assert parse_bytes('1Mi') == 2**20 - assert parse_bytes('1e6') == 1000000 - assert parse_bytes('1e6 kB') == 1000000000 - assert parse_bytes('MB') == 1000000 + assert parse_bytes("100") == 100 + assert parse_bytes("100 MB") == 100000000 + assert parse_bytes("100M") == 100000000 + assert parse_bytes("5kB") == 5000 + assert parse_bytes("5.4 kB") == 5400 + assert parse_bytes("1kiB") == 1024 + assert parse_bytes("1Mi") == 2 ** 20 + assert parse_bytes("1e6") == 1000000 + assert parse_bytes("1e6 kB") == 1000000000 + assert parse_bytes("MB") == 1000000 def test_parse_timedelta(): - for text, value in [('1s', 1), - ('100ms', 0.1), - ('5S', 5), - ('5.5s', 5.5), - ('5.5 s', 5.5), - ('1 second', 1), - ('3.3 seconds', 3.3), - ('3.3 milliseconds', 0.0033), - ('3500 us', 0.0035), - ('1 ns', 1e-9), - ('2m', 120), - ('2 minutes', 120), - (datetime.timedelta(seconds=2), 2), - (datetime.timedelta(milliseconds=100), 0.1)]: + for text, value in [ + ("1s", 1), + ("100ms", 0.1), + ("5S", 5), + ("5.5s", 5.5), + ("5.5 s", 5.5), + ("1 second", 1), + ("3.3 seconds", 3.3), + ("3.3 milliseconds", 0.0033), + ("3500 us", 0.0035), + ("1 ns", 1e-9), + ("2m", 120), + ("2 minutes", 120), + (datetime.timedelta(seconds=2), 2), + (datetime.timedelta(milliseconds=100), 0.1), + ]: result = parse_timedelta(text) assert abs(result - value) < 1e-14 - assert parse_timedelta('1ms', default='seconds') == 0.001 - assert parse_timedelta('1', default='seconds') == 1 - assert parse_timedelta('1', default='ms') == 0.001 - assert parse_timedelta(1, default='ms') == 0.001 + assert parse_timedelta("1ms", default="seconds") == 0.001 + assert parse_timedelta("1", default="seconds") == 1 + assert parse_timedelta("1", default="ms") == 0.001 + assert parse_timedelta(1, default="ms") == 0.001 @gen_test() def test_all_exceptions_logging(): @gen.coroutine def throws(): - raise Exception('foo1234') + raise Exception("foo1234") - with captured_logger('') as sio: + with captured_logger("") as sio: try: - yield All([throws() for _ in range(5)], - quiet_exceptions=Exception) + yield All([throws() for _ in range(5)], quiet_exceptions=Exception) except Exception: pass - import gc; gc.collect() + import gc + + gc.collect() yield gen.sleep(0.1) - assert 'foo1234' not in sio.getvalue() + assert "foo1234" not in sio.getvalue() def test_warn_on_duration(): with pytest.warns(None) as record: - with warn_on_duration('10s', 'foo'): + with warn_on_duration("10s", "foo"): pass assert not record with pytest.warns(None) as record: - with warn_on_duration('1ms', 'foo'): + with warn_on_duration("1ms", "foo"): sleep(0.100) assert record - assert any('foo' in str(rec.message) for rec in record) + assert any("foo" in str(rec.message) for rec in record) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 1e69eef6a03..c9750891dd7 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -4,23 +4,24 @@ from distributed.core import rpc from 
distributed.utils_test import gen_cluster -from distributed.utils_comm import (pack_data, gather_from_workers) +from distributed.utils_comm import pack_data, gather_from_workers def test_pack_data(): - data = {'x': 1} - assert pack_data(('x', 'y'), data) == (1, 'y') - assert pack_data({'a': 'x', 'b': 'y'}, data) == {'a': 1, 'b': 'y'} - assert pack_data({'a': ['x'], 'b': 'y'}, data) == {'a': [1], 'b': 'y'} + data = {"x": 1} + assert pack_data(("x", "y"), data) == (1, "y") + assert pack_data({"a": "x", "b": "y"}, data) == {"a": 1, "b": "y"} + assert pack_data({"a": ["x"], "b": "y"}, data) == {"a": [1], "b": "y"} -@pytest.mark.xfail(reason='rpc now needs to be a connection pool') +@pytest.mark.xfail(reason="rpc now needs to be a connection pool") @gen_cluster(client=True) def test_gather_from_workers_permissive(c, s, a, b): - x = yield c.scatter({'x': 1}, workers=a.address) + x = yield c.scatter({"x": 1}, workers=a.address) data, missing, bad_workers = yield gather_from_workers( - {'x': [a.address], 'y': [b.address]}, rpc=rpc) + {"x": [a.address], "y": [b.address]}, rpc=rpc + ) - assert data == {'x': 1} - assert list(missing) == ['y'] + assert data == {"x": 1} + assert list(missing) == ["y"] diff --git a/distributed/tests/test_utils_perf.py b/distributed/tests/test_utils_perf.py index e8e4cae9e37..55b250273c0 100644 --- a/distributed/tests/test_utils_perf.py +++ b/distributed/tests/test_utils_perf.py @@ -10,8 +10,7 @@ from distributed.compatibility import PY2 from distributed.metrics import thread_time -from distributed.utils_perf import (FractionalTimer, GCDiagnosis, - disable_gc_diagnosis) +from distributed.utils_perf import FractionalTimer, GCDiagnosis, disable_gc_diagnosis from distributed.utils_test import captured_logger, run_for @@ -43,8 +42,9 @@ def check_fraction(timer, ft): # sum of last N "measurement" intervals over the sum of last # 2N intervals (not 2N - 1 or 2N + 1) actual = ft.running_fraction - expected = (sum(timer.durations[1][-N:]) / - (sum(timer.durations[0][-N:] + timer.durations[1][-N:]))) + expected = sum(timer.durations[1][-N:]) / ( + sum(timer.durations[0][-N:] + timer.durations[1][-N:]) + ) assert actual == pytest.approx(expected) timer = RandomTimer() @@ -68,13 +68,12 @@ def check_fraction(timer, ft): @contextlib.contextmanager -def enable_gc_diagnosis_and_log(diag, level='INFO'): +def enable_gc_diagnosis_and_log(diag, level="INFO"): disable_gc_diagnosis(force=True) # just in case if gc.callbacks: print("Unexpected gc.callbacks", gc.callbacks) - with captured_logger('distributed.utils_perf', level=level, - propagate=False) as sio: + with captured_logger("distributed.utils_perf", level=level, propagate=False) as sio: gc.disable() gc.collect() # drain any leftover from previous tests diag.enable() @@ -90,7 +89,7 @@ def test_gc_diagnosis_cpu_time(): diag = GCDiagnosis(warn_over_frac=0.75) diag.N_SAMPLES = 3 # shorten tests - with enable_gc_diagnosis_and_log(diag, level='WARN') as sio: + with enable_gc_diagnosis_and_log(diag, level="WARN") as sio: # Spend some CPU time doing only full GCs for i in range(diag.N_SAMPLES): gc.collect() @@ -99,10 +98,12 @@ def test_gc_diagnosis_cpu_time(): lines = sio.getvalue().splitlines() assert len(lines) == 1 # Between 80% and 100% - assert re.match(r"full garbage collections took (100|[89][0-9])% " - r"CPU time recently", lines[0]) + assert re.match( + r"full garbage collections took (100|[89][0-9])% " r"CPU time recently", + lines[0], + ) - with enable_gc_diagnosis_and_log(diag, level='WARN') as sio: + with 
enable_gc_diagnosis_and_log(diag, level="WARN") as sio: # Spend half the CPU time doing full GCs for i in range(diag.N_SAMPLES + 1): t1 = thread_time() @@ -113,7 +114,7 @@ def test_gc_diagnosis_cpu_time(): assert not sio.getvalue() -@pytest.mark.xfail(reason='unknown') +@pytest.mark.xfail(reason="unknown") @pytest.mark.skipif(PY2, reason="requires Python 3") def test_gc_diagnosis_rss_win(): diag = GCDiagnosis(info_over_rss_win=10e6) @@ -137,5 +138,8 @@ def make_refcycle(nbytes): lines = sio.getvalue().splitlines() assert len(lines) == 1 # Several MB released, and at least 1 reference cycles - assert re.match(r"full garbage collection released [\d\.]+ MB " - r"from [1-9]\d* reference cycles", lines[0]) + assert re.match( + r"full garbage collection released [\d\.]+ MB " + r"from [1-9]\d* reference cycles", + lines[0], + ) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 84118750595..6f704c23f5b 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -12,11 +12,22 @@ from distributed import Scheduler, Worker, Client, config, default_client from distributed.core import rpc from distributed.metrics import time -from distributed.utils_test import (cluster, gen_cluster, inc, - gen_test, wait_for_port, new_config) - -from distributed.utils_test import (loop, tls_only_security, # noqa: F401 - security, tls_client, tls_cluster) +from distributed.utils_test import ( # noqa: F401 + cluster, + gen_cluster, + inc, + gen_test, + wait_for_port, + new_config, +) + +from distributed.utils_test import ( # noqa: F401 + loop, + tls_only_security, + security, + tls_client, + tls_cluster, +) from distributed.utils import get_ip @@ -27,10 +38,10 @@ def test_bare_cluster(loop): def test_cluster(loop): with cluster() as (s, [a, b]): - with rpc(s['address']) as s: + with rpc(s["address"]) as s: ident = loop.run_sync(s.identity) - assert ident['type'] == 'Scheduler' - assert len(ident['workers']) == 2 + assert ident["type"] == "Scheduler" + assert len(ident["workers"]) == 2 @gen_cluster(client=True) @@ -45,16 +56,17 @@ def test_gen_cluster(c, s, a, b): @pytest.mark.skip(reason="This hangs on travis") def test_gen_cluster_cleans_up_client(loop): import dask.context - assert not dask.config.get('get', None) + + assert not dask.config.get("get", None) @gen_cluster(client=True) def f(c, s, a, b): - assert dask.config.get('get', None) + assert dask.config.get("get", None) yield c.submit(inc, 1) f() - assert not dask.config.get('get', None) + assert not dask.config.get("get", None) @gen_cluster(client=False) @@ -65,16 +77,19 @@ def test_gen_cluster_without_client(s, a, b): assert s.ncores == {w.address: w.ncores for w in [a, b]} -@gen_cluster(client=True, scheduler='tls://127.0.0.1', - ncores=[('tls://127.0.0.1', 1), ('tls://127.0.0.1', 2)], - security=tls_only_security()) +@gen_cluster( + client=True, + scheduler="tls://127.0.0.1", + ncores=[("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)], + security=tls_only_security(), +) def test_gen_cluster_tls(e, s, a, b): assert isinstance(e, Client) assert isinstance(s, Scheduler) - assert s.address.startswith('tls://') + assert s.address.startswith("tls://") for w in [a, b]: assert isinstance(w, Worker) - assert w.address.startswith('tls://') + assert w.address.startswith("tls://") assert s.ncores == {w.address: w.ncores for w in [a, b]} @@ -132,11 +147,11 @@ def test_wait_for_port(): def test_new_config(): c = config.copy() - with new_config({'xyzzy': 5}): - config['xyzzy'] == 5 + with 
new_config({"xyzzy": 5}): + config["xyzzy"] == 5 assert config == c - assert 'xyzzy' not in config + assert "xyzzy" not in config def test_lingering_client(): @@ -152,7 +167,7 @@ def f(s, a, b): def test_lingering_client(loop): with cluster() as (s, [a, b]): - client = Client(s['address'], loop=loop) + client = Client(s["address"], loop=loop) def test_tls_cluster(tls_client): @@ -162,8 +177,8 @@ def test_tls_cluster(tls_client): def test_tls_scheduler(security, loop): s = Scheduler(security=security, loop=loop) - s.start('localhost') - assert s.address.startswith('tls') + s.start("localhost") + assert s.address.startswith("tls") s.close() diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 0d0898923c4..5ae94d037c5 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -9,14 +9,14 @@ from distributed import Client, Variable, worker_client, Nanny, wait from distributed.metrics import time -from distributed.utils_test import (gen_cluster, inc, slow, div) -from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 +from distributed.utils_test import gen_cluster, inc, slow, div +from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @gen_cluster(client=True) def test_variable(c, s, a, b): - x = Variable('x') - xx = Variable('x') + x = Variable("x") + xx = Variable("x") assert x.client is c future = c.submit(inc, 1) @@ -40,20 +40,20 @@ def test_variable(c, s, a, b): @gen_cluster(client=True) def test_queue_with_data(c, s, a, b): - x = Variable('x') - xx = Variable('x') + x = Variable("x") + xx = Variable("x") assert x.client is c - yield x.set((1, 'hello')) + yield x.set((1, "hello")) data = yield xx.get() - assert data == (1, 'hello') + assert data == (1, "hello") def test_sync(client): future = client.submit(lambda x: x + 1, 10) - x = Variable('x') - xx = Variable('x') + x = Variable("x") + xx = Variable("x") x.set(future) future2 = xx.get() @@ -64,7 +64,7 @@ def test_sync(client): def test_hold_futures(s, a, b): c1 = yield Client(s.address, asynchronous=True) future = c1.submit(lambda x: x + 1, 10) - x1 = Variable('x') + x1 = Variable("x") yield x1.set(future) del x1 yield c1.close() @@ -72,7 +72,7 @@ def test_hold_futures(s, a, b): yield gen.sleep(0.1) c2 = yield Client(s.address, asynchronous=True) - x2 = Variable('x') + x2 = Variable("x") future2 = yield x2.get() result = yield future2 @@ -82,7 +82,7 @@ def test_hold_futures(s, a, b): @gen_cluster(client=True) def test_timeout(c, s, a, b): - v = Variable('v') + v = Variable("v") start = time() with pytest.raises(gen.TimeoutError): @@ -92,7 +92,7 @@ def test_timeout(c, s, a, b): def test_timeout_sync(client): - v = Variable('v') + v = Variable("v") start = time() with pytest.raises(gen.TimeoutError): v.get(timeout=0.1) @@ -102,8 +102,8 @@ def test_timeout_sync(client): @gen_cluster(client=True) def test_cleanup(c, s, a, b): - v = Variable('v') - vv = Variable('v') + v = Variable("v") + vv = Variable("v") x = c.submit(lambda x: x + 1, 10) y = c.submit(lambda x: x + 1, 20) @@ -124,7 +124,7 @@ def test_cleanup(c, s, a, b): def test_pickleable(client): - v = Variable('v') + v = Variable("v") def f(x): v.set(x + 1) @@ -135,27 +135,26 @@ def f(x): @gen_cluster(client=True) def test_timeout_get(c, s, a, b): - v = Variable('v') + v = Variable("v") tornado_future = v.get() - vv = Variable('v') + vv = Variable("v") yield vv.set(1) result = yield tornado_future assert result == 1 -@pytest.mark.skipif(sys.version_info[0] == 2, 
reason='Multi-client issues') +@pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") @slow -@gen_cluster(client=True, ncores=[('127.0.0.1', 2)] * 5, Worker=Nanny, - timeout=None) +@gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): NITERS = 50 def f(i): with worker_client() as c: - v = Variable('x', client=c) + v = Variable("x", client=c) for _ in range(NITERS): future = v.get() x = future.result() @@ -166,7 +165,7 @@ def f(i): sleep(0.1) # allow fire-and-forget messages to clear return result - v = Variable('x', client=c) + v = Variable("x", client=c) x = yield c.scatter(1) yield v.set(x) @@ -175,7 +174,7 @@ def f(i): assert all(r > NITERS * 0.8 for r in results) start = time() - while len(s.wants_what['variable-x']) != 1: + while len(s.wants_what["variable-x"]) != 1: yield gen.sleep(0.01) assert time() - start < 2 @@ -183,20 +182,20 @@ def f(i): @gen_cluster(client=True) def test_Future_knows_status_immediately(c, s, a, b): x = yield c.scatter(123) - v = Variable('x') + v = Variable("x") yield v.set(x) c2 = yield Client(s.address, asynchronous=True) - v2 = Variable('x', client=c2) + v2 = Variable("x", client=c2) future = yield v2.get() - assert future.status == 'finished' + assert future.status == "finished" x = c.submit(div, 1, 0) yield wait(x) yield v.set(x) future2 = yield v2.get() - assert future2.status == 'error' + assert future2.status == "error" with pytest.raises(Exception): yield future2 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 3c9b51f2baa..05b61a997f4 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -19,8 +19,7 @@ from tornado import gen from tornado.ioloop import TimeoutError -from distributed import (Nanny, get_client, wait, default_client, - get_worker, Reschedule) +from distributed import Nanny, get_client, wait, default_client, get_worker, Reschedule from distributed.compatibility import WINDOWS, cache_from_source from distributed.core import rpc from distributed.client import wait @@ -28,14 +27,32 @@ from distributed.metrics import time from distributed.worker import Worker, error_message, logger, TOTAL_MEMORY from distributed.utils import tmpfile, format_bytes -from distributed.utils_test import (inc, mul, gen_cluster, div, dec, - slow, slowinc, gen_test, captured_logger) -from distributed.utils_test import client, loop, nodebug, cluster_fixture, s, a, b # noqa: F401 +from distributed.utils_test import ( + inc, + mul, + gen_cluster, + div, + dec, + slow, + slowinc, + gen_test, + captured_logger, +) +from distributed.utils_test import ( # noqa: F401 + client, + loop, + nodebug, + cluster_fixture, + s, + a, + b, +) def test_worker_ncores(): from distributed.worker import _ncores - w = Worker('127.0.0.1', 8019) + + w = Worker("127.0.0.1", 8019) try: assert w.executor._max_workers == _ncores finally: @@ -52,12 +69,12 @@ def test_str(s, a, b): def test_identity(): - w = Worker('127.0.0.1', 8019) + w = Worker("127.0.0.1", 8019) ident = w.identity(None) - assert 'Worker' in ident['type'] - assert ident['scheduler'] == 'tcp://127.0.0.1:8019' - assert isinstance(ident['ncores'], int) - assert isinstance(ident['memory_limit'], Number) + assert "Worker" in ident["type"] + assert ident["scheduler"] == "tcp://127.0.0.1:8019" + assert isinstance(ident["ncores"], int) + assert isinstance(ident["memory_limit"], Number) @gen_cluster(client=True) @@ -91,11 +108,11 @@ def emit(self, record): def reset(self): 
self.messages = { - 'debug': [], - 'info': [], - 'warning': [], - 'error': [], - 'critical': [], + "debug": [], + "info": [], + "warning": [], + "error": [], + "critical": [], } hdlr = MockLoggingHandler() @@ -106,7 +123,7 @@ def reset(self): yield wait(y) assert not b.executing - assert y.status == 'error' + assert y.status == "error" # Make sure job died because of bad func and not because of bad # argument. with pytest.raises(ZeroDivisionError): @@ -114,10 +131,10 @@ def reset(self): if sys.version_info[0] >= 3: tb = yield y._traceback() - assert any('1 / 0' in line - for line in pluck(3, traceback.extract_tb(tb)) - if line) - assert "Compute Failed" in hdlr.messages['warning'][0] + assert any( + "1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line + ) + assert "Compute Failed" in hdlr.messages["warning"][0] logger.setLevel(old_level) # Now we check that both workers are still alive. @@ -133,18 +150,18 @@ def reset(self): @slow @gen_cluster() def dont_test_delete_data_with_missing_worker(c, a, b): - bad = '127.0.0.1:9001' # this worker doesn't exist - c.who_has['z'].add(bad) - c.who_has['z'].add(a.address) - c.has_what[bad].add('z') - c.has_what[a.address].add('z') - a.data['z'] = 5 + bad = "127.0.0.1:9001" # this worker doesn't exist + c.who_has["z"].add(bad) + c.who_has["z"].add(a.address) + c.has_what[bad].add("z") + c.has_what[a.address].add("z") + a.data["z"] = 5 cc = rpc(ip=c.ip, port=c.port) - yield cc.delete_data(keys=['z']) # TODO: this hangs for a while - assert 'z' not in a.data - assert not c.who_has['z'] + yield cc.delete_data(keys=["z"]) # TODO: this hangs for a while + assert "z" not in a.data + assert not c.who_has["z"] assert not c.has_what[bad] assert not c.has_what[a.address] @@ -153,20 +170,23 @@ def dont_test_delete_data_with_missing_worker(c, a, b): @gen_cluster(client=True) def test_upload_file(c, s, a, b): - assert not os.path.exists(os.path.join(a.local_dir, 'foobar.py')) - assert not os.path.exists(os.path.join(b.local_dir, 'foobar.py')) + assert not os.path.exists(os.path.join(a.local_dir, "foobar.py")) + assert not os.path.exists(os.path.join(b.local_dir, "foobar.py")) assert a.local_dir != b.local_dir aa = rpc(a.address) bb = rpc(b.address) - yield [aa.upload_file(filename='foobar.py', data=b'x = 123'), - bb.upload_file(filename='foobar.py', data='x = 123')] + yield [ + aa.upload_file(filename="foobar.py", data=b"x = 123"), + bb.upload_file(filename="foobar.py", data="x = 123"), + ] - assert os.path.exists(os.path.join(a.local_dir, 'foobar.py')) - assert os.path.exists(os.path.join(b.local_dir, 'foobar.py')) + assert os.path.exists(os.path.join(a.local_dir, "foobar.py")) + assert os.path.exists(os.path.join(b.local_dir, "foobar.py")) def g(): import foobar + return foobar.x future = c.submit(g, workers=a.address) @@ -177,27 +197,29 @@ def g(): yield b._close() aa.close_rpc() bb.close_rpc() - assert not os.path.exists(os.path.join(a.local_dir, 'foobar.py')) + assert not os.path.exists(os.path.join(a.local_dir, "foobar.py")) @pytest.mark.skip(reason="don't yet support uploading pyc files") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_upload_file_pyc(c, s, w): with tmpfile() as dirname: os.mkdir(dirname) - with open(os.path.join(dirname, 'foo.py'), mode='w') as f: - f.write('def f():\n return 123') + with open(os.path.join(dirname, "foo.py"), mode="w") as f: + f.write("def f():\n return 123") sys.path.append(dirname) try: import foo + assert foo.f() == 123 - pyc = 
cache_from_source(os.path.join(dirname, 'foo.py')) + pyc = cache_from_source(os.path.join(dirname, "foo.py")) assert os.path.exists(pyc) yield c.upload_file(pyc) def g(): import foo + return foo.x future = c.submit(g) @@ -209,8 +231,8 @@ def g(): @gen_cluster(client=True) def test_upload_egg(c, s, a, b): - eggname = 'testegg-1.0.0-py3.4.egg' - local_file = __file__.replace('test_worker.py', eggname) + eggname = "testegg-1.0.0-py3.4.egg" + local_file = __file__.replace("test_worker.py", eggname) assert not os.path.exists(os.path.join(a.local_dir, eggname)) assert not os.path.exists(os.path.join(b.local_dir, eggname)) assert a.local_dir != b.local_dir @@ -222,6 +244,7 @@ def test_upload_egg(c, s, a, b): def g(x): import testegg + return testegg.inc(x) future = c.submit(g, 10, workers=a.address) @@ -235,8 +258,8 @@ def g(x): @gen_cluster(client=True) def test_upload_pyz(c, s, a, b): - pyzname = 'mytest.pyz' - local_file = __file__.replace('test_worker.py', pyzname) + pyzname = "mytest.pyz" + local_file = __file__.replace("test_worker.py", pyzname) assert not os.path.exists(os.path.join(a.local_dir, pyzname)) assert not os.path.exists(os.path.join(b.local_dir, pyzname)) assert a.local_dir != b.local_dir @@ -248,6 +271,7 @@ def test_upload_pyz(c, s, a, b): def g(x): from mytest import mytest + return mytest.inc(x) future = c.submit(g, 10, workers=a.address) @@ -259,22 +283,22 @@ def g(x): assert not os.path.exists(os.path.join(a.local_dir, pyzname)) -@pytest.mark.xfail(reason='Still lose time to network I/O') +@pytest.mark.xfail(reason="Still lose time to network I/O") @gen_cluster(client=True) def test_upload_large_file(c, s, a, b): - pytest.importorskip('crick') + pytest.importorskip("crick") yield gen.sleep(0.05) with rpc(a.address) as aa: - yield aa.upload_file(filename='myfile.dat', data=b'0' * 100000000) + yield aa.upload_file(filename="myfile.dat", data=b"0" * 100000000) yield gen.sleep(0.05) - assert a.digests['tick-duration'].components[0].max() < 0.050 + assert a.digests["tick-duration"].components[0].max() < 0.050 @gen_cluster() def test_broadcast(s, a, b): with rpc(s.address) as cc: - results = yield cc.broadcast(msg={'op': 'ping'}) - assert results == {a.address: b'pong', b.address: b'pong'} + results = yield cc.broadcast(msg={"op": "ping"}) + assert results == {a.address: b"pong", b.address: b"pong"} @gen_test() @@ -292,7 +316,7 @@ def test_worker_with_port_zero(): def test_worker_waits_for_center_to_come_up(loop): @gen.coroutine def f(): - w = yield Worker('127.0.0.1', 8007) + w = yield Worker("127.0.0.1", 8007) try: loop.run_sync(f, timeout=4) @@ -300,7 +324,7 @@ def f(): pass -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_worker_task_data(c, s, w): x = delayed(2) xx = c.persist(x) @@ -316,20 +340,20 @@ def __init__(self, a, b): def __str__(self): return "MyException(%s)" % self.args - msg = error_message(MyException('Hello', 'World!')) - assert 'Hello' in str(msg['exception']) + msg = error_message(MyException("Hello", "World!")) + assert "Hello" in str(msg["exception"]) @gen_cluster() def test_gather(s, a, b): - b.data['x'] = 1 - b.data['y'] = 2 + b.data["x"] = 1 + b.data["y"] = 2 with rpc(a.address) as aa: - resp = yield aa.gather(who_has={'x': [b.address], 'y': [b.address]}) - assert resp['status'] == 'OK' + resp = yield aa.gather(who_has={"x": [b.address], "y": [b.address]}) + assert resp["status"] == "OK" - assert a.data['x'] == b.data['x'] - assert a.data['y'] == b.data['y'] + assert a.data["x"] == 
b.data["x"] + assert a.data["y"] == b.data["y"] def test_io_loop(loop): @@ -342,19 +366,24 @@ def test_io_loop(loop): @gen_cluster(client=True, ncores=[]) def test_spill_to_disk(c, s): - np = pytest.importorskip('numpy') - w = yield Worker(s.address, loop=s.loop, memory_limit=1200 / 0.6, - memory_pause_fraction=None, memory_spill_fraction=None) - - x = c.submit(np.random.randint, 0, 255, size=500, dtype='u1', key='x') + np = pytest.importorskip("numpy") + w = yield Worker( + s.address, + loop=s.loop, + memory_limit=1200 / 0.6, + memory_pause_fraction=None, + memory_spill_fraction=None, + ) + + x = c.submit(np.random.randint, 0, 255, size=500, dtype="u1", key="x") yield wait(x) - y = c.submit(np.random.randint, 0, 255, size=500, dtype='u1', key='y') + y = c.submit(np.random.randint, 0, 255, size=500, dtype="u1", key="y") yield wait(y) assert set(w.data) == {x.key, y.key} assert set(w.data.fast) == {x.key, y.key} - z = c.submit(np.random.randint, 0, 255, size=500, dtype='u1', key='z') + z = c.submit(np.random.randint, 0, 255, size=500, dtype="u1", key="z") yield wait(z) assert set(w.data) == {x.key, y.key, z.key} assert set(w.data.fast) == {y.key, z.key} @@ -370,11 +399,12 @@ def test_spill_to_disk(c, s): def test_access_key(c, s, a, b): def f(i): from distributed.worker import thread_state + return thread_state.key - futures = [c.submit(f, i, key='x-%d' % i) for i in range(20)] + futures = [c.submit(f, i, key="x-%d" % i) for i in range(20)] results = yield c._gather(futures) - assert list(results) == ['x-%d' % i for i in range(20)] + assert list(results) == ["x-%d" % i for i in range(20)] @gen_cluster(client=True) @@ -417,32 +447,31 @@ def test_Executor(c, s): @pytest.mark.skip(reason="Leaks a large amount of memory") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)], timeout=30) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)], timeout=30) def test_spill_by_default(c, s, w): - da = pytest.importorskip('dask.array') - x = da.ones(int(TOTAL_MEMORY * 0.7), chunks=10000000, dtype='u1') + da = pytest.importorskip("dask.array") + x = da.ones(int(TOTAL_MEMORY * 0.7), chunks=10000000, dtype="u1") y = c.persist(x) yield wait(y) assert len(w.data.slow) # something is on disk del x, y -@gen_cluster(ncores=[('127.0.0.1', 1)], - worker_kwargs={'reconnect': False}) +@gen_cluster(ncores=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False}) def test_close_on_disconnect(s, w): yield s.close() start = time() - while w.status != 'closed': + while w.status != "closed": yield gen.sleep(0.01) assert time() < start + 5 def test_memory_limit_auto(): - a = Worker('127.0.0.1', 8099, ncores=1) - b = Worker('127.0.0.1', 8099, ncores=2) - c = Worker('127.0.0.1', 8099, ncores=100) - d = Worker('127.0.0.1', 8099, ncores=200) + a = Worker("127.0.0.1", 8099, ncores=1) + b = Worker("127.0.0.1", 8099, ncores=2) + c = Worker("127.0.0.1", 8099, ncores=100) + d = Worker("127.0.0.1", 8099, ncores=200) assert isinstance(a.memory_limit, Number) assert isinstance(b.memory_limit, Number) @@ -468,8 +497,17 @@ def test_clean(c, s, a, b): yield y - collections = [a.tasks, a.task_state, a.startstops, a.data, a.nbytes, - a.durations, a.priorities, a.types, a.threads] + collections = [ + a.tasks, + a.task_state, + a.startstops, + a.data, + a.nbytes, + a.durations, + a.priorities, + a.types, + a.threads, + ] for c in collections: assert c @@ -489,15 +527,15 @@ def test_message_breakup(c, s, a, b): n = 100000 a.target_message_size = 10 * n b.target_message_size = 10 * n - xs = [c.submit(mul, b'%d' % i, n, workers=a.address) for i 
in range(30)] + xs = [c.submit(mul, b"%d" % i, n, workers=a.address) for i in range(30)] y = c.submit(lambda *args: None, xs, workers=b.address) yield y assert 2 <= len(b.incoming_transfer_log) <= 20 assert 2 <= len(a.outgoing_transfer_log) <= 20 - assert all(msg['who'] == b.address for msg in a.outgoing_transfer_log) - assert all(msg['who'] == a.address for msg in a.incoming_transfer_log) + assert all(msg["who"] == b.address for msg in a.outgoing_transfer_log) + assert all(msg["who"] == a.address for msg in a.incoming_transfer_log) @gen_cluster(client=True) @@ -528,13 +566,14 @@ def test_system_monitor(s, a, b): b.monitor.update() -@gen_cluster(client=True, ncores=[('127.0.0.1', 2, {'resources': {'A': 1}}), - ('127.0.0.1', 1)]) +@gen_cluster( + client=True, ncores=[("127.0.0.1", 2, {"resources": {"A": 1}}), ("127.0.0.1", 1)] +) def test_restrictions(c, s, a, b): # Resource restrictions - x = c.submit(inc, 1, resources={'A': 1}) + x = c.submit(inc, 1, resources={"A": 1}) yield x - assert a.resource_restrictions == {x.key: {'A': 1}} + assert a.resource_restrictions == {x.key: {"A": 1}} yield c._cancel(x) while x.key in a.task_state: @@ -558,7 +597,7 @@ def test_clean_nbytes(c, s, a, b): assert len(a.nbytes) + len(b.nbytes) == 1 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 20) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 20) def test_gather_many_small(c, s, a, *workers): a.total_out_connections = 2 futures = yield c._scatter(list(range(100))) @@ -572,14 +611,14 @@ def f(*args): yield wait(future) types = list(pluck(0, a.log)) - req = [i for i, t in enumerate(types) if t == 'request-dep'] - recv = [i for i, t in enumerate(types) if t == 'receive-dep'] + req = [i for i, t in enumerate(types) if t == "request-dep"] + recv = [i for i, t in enumerate(types) if t == "receive-dep"] assert min(recv) > max(req) assert a.comm_nbytes == 0 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_multiple_transfers(c, s, w1, w2, w3): x = c.submit(inc, 1, workers=w1.address) y = c.submit(inc, 2, workers=w2.address) @@ -588,14 +627,14 @@ def test_multiple_transfers(c, s, w1, w2, w3): yield wait(z) r = w3.startstops[z.key] - transfers = [t for t in r if t[0] == 'transfer'] + transfers = [t for t in r if t[0] == "transfer"] assert len(transfers) == 2 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_share_communication(c, s, w1, w2, w3): - x = c.submit(mul, b'1', int(w3.target_message_size + 1), workers=w1.address) - y = c.submit(mul, b'2', int(w3.target_message_size + 1), workers=w2.address) + x = c.submit(mul, b"1", int(w3.target_message_size + 1), workers=w1.address) + y = c.submit(mul, b"2", int(w3.target_message_size + 1), workers=w2.address) yield wait([x, y]) yield c._replicate([x, y], workers=[w1.address, w2.address]) z = c.submit(add, x, y, workers=w3.address) @@ -607,15 +646,15 @@ def test_share_communication(c, s, w1, w2, w3): @gen_cluster(client=True) def test_dont_overlap_communications_to_same_worker(c, s, a, b): - x = c.submit(mul, b'1', int(b.target_message_size + 1), workers=a.address) - y = c.submit(mul, b'2', int(b.target_message_size + 1), workers=a.address) + x = c.submit(mul, b"1", int(b.target_message_size + 1), workers=a.address) + y = c.submit(mul, b"2", int(b.target_message_size + 1), workers=a.address) yield wait([x, y]) z = c.submit(add, x, y, workers=b.address) yield wait(z) assert len(b.incoming_transfer_log) == 2 l1, 
l2 = b.incoming_transfer_log - assert l1['stop'] < l2['start'] + assert l1["stop"] < l2["start"] @pytest.mark.avoid_travis @@ -625,6 +664,7 @@ def test_log_exception_on_failed_task(c, s, a, b): fh = logging.FileHandler(fn) try: from distributed.worker import logger + logger.addHandler(fh) future = c.submit(div, 1, 0) @@ -677,12 +717,12 @@ def test_hold_onto_dependents(c, s, a, b): @slow @gen_cluster(client=False, ncores=[]) def test_worker_death_timeout(s): - with dask.config.set({'distributed.comm.timeouts.connect': '1s'}): + with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): yield s.close() w = yield Worker(s.address, death_timeout=1) yield gen.sleep(2) - assert w.status == 'closed' + assert w.status == "closed" @gen_cluster(client=True) @@ -698,14 +738,14 @@ def test_stop_doing_unnecessary_work(c, s, a, b): assert time() - start < 0.5 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) def test_priorities(c, s, w): values = [] for i in range(10): - a = delayed(slowinc)(i, dask_key_name='a-%d' % i, delay=0.01) - a1 = delayed(inc)(a, dask_key_name='a1-%d' % i) - a2 = delayed(inc)(a1, dask_key_name='a2-%d' % i) - b1 = delayed(dec)(a, dask_key_name='b1-%d' % i) # <<-- least favored + a = delayed(slowinc)(i, dask_key_name="a-%d" % i, delay=0.01) + a1 = delayed(inc)(a, dask_key_name="a1-%d" % i) + a2 = delayed(inc)(a1, dask_key_name="a2-%d" % i) + b1 = delayed(dec)(a, dask_key_name="b1-%d" % i) # <<-- least favored values.append(a2) values.append(b1) @@ -713,29 +753,31 @@ def test_priorities(c, s, w): futures = c.compute(values) yield wait(futures) - log = [t[0] for t in w.log - if t[1] == 'executing' - and t[2] == 'memory' - and not t[0].startswith('finalize')] + log = [ + t[0] + for t in w.log + if t[1] == "executing" and t[2] == "memory" and not t[0].startswith("finalize") + ] - assert any(key.startswith('b1') for key in log[:len(log) // 2]) + assert any(key.startswith("b1") for key in log[: len(log) // 2]) @gen_cluster(client=True) def test_heartbeats(c, s, a, b): x = s.workers[a.address].last_seen start = time() - yield gen.sleep(a.periodic_callbacks['heartbeat'].callback_time / 1000 + 0.1) + yield gen.sleep(a.periodic_callbacks["heartbeat"].callback_time / 1000 + 0.1) while s.workers[a.address].last_seen == x: yield gen.sleep(0.01) assert time() < start + 2 - assert a.periodic_callbacks['heartbeat'].callback_time < 1000 + assert a.periodic_callbacks["heartbeat"].callback_time < 1000 -@pytest.mark.parametrize('worker', [Worker, Nanny]) +@pytest.mark.parametrize("worker", [Worker, Nanny]) def test_worker_dir(worker): with tmpfile() as fn: - @gen_cluster(client=True, worker_kwargs={'local_dir': fn}) + + @gen_cluster(client=True, worker_kwargs={"local_dir": fn}) def test_worker_dir(c, s, a, b): directories = [w.local_directory for w in s.workers.values()] assert all(d.startswith(fn) for d in directories) @@ -751,7 +793,7 @@ def __init__(self, data): self.data = data def __sizeof__(self): - raise TypeError('Hello') + raise TypeError("Hello") future = c.submit(BadSize, 123) result = yield future @@ -770,7 +812,7 @@ def __sizeof__(self): future = c.submit(Bad) yield wait(future) - assert future.status == 'error' + assert future.status == "error" with pytest.raises(TypeError): yield future @@ -781,8 +823,9 @@ def __sizeof__(self): @pytest.mark.skip(reason="Our logic here is faulty") -@gen_cluster(ncores=[('127.0.0.1', 2)], client=True, - worker_kwargs={'memory_limit': 10e9}) +@gen_cluster( + ncores=[("127.0.0.1", 2)], 
client=True, worker_kwargs={"memory_limit": 10e9} +) def test_fail_write_many_to_disk(c, s, a): a.validate = False yield gen.sleep(0.1) @@ -863,8 +906,7 @@ def f(): raise gen.Return(result) results = yield c.run(f) - assert results == {a.address: 11, - b.address: 11} + assert results == {a.address: 11, b.address: 11} def test_get_client_coroutine_sync(client, s, a, b): @@ -876,13 +918,13 @@ def f(): raise gen.Return(result) results = client.run(f) - assert results == {a['address']: 11, - b['address']: 11} + assert results == {a["address"]: 11, b["address"]: 11} @gen_cluster() def test_global_workers(s, a, b): from distributed.worker import _global_workers + n = len(_global_workers) w = _global_workers[-1]() assert w is a or w is b @@ -894,7 +936,7 @@ def test_global_workers(s, a, b): @pytest.mark.skipif(WINDOWS, reason="file descriptors") @gen_cluster(ncores=[]) def test_worker_fds(s): - psutil = pytest.importorskip('psutil') + psutil = pytest.importorskip("psutil") yield gen.sleep(0.05) start = psutil.Process().num_fds() @@ -916,39 +958,41 @@ def test_worker_fds(s): @gen_cluster(ncores=[]) def test_service_hosts_match_worker(s): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") from distributed.bokeh.worker import BokehWorker - services = {('bokeh', ':0'): BokehWorker} - w = Worker(s.address, services={('bokeh', ':0'): BokehWorker}) - yield w._start('tcp://0.0.0.0') - sock = first(w.services['bokeh'].server._http._sockets.values()) - assert sock.getsockname()[0] in ('::', '0.0.0.0') + services = {("bokeh", ":0"): BokehWorker} + + w = Worker(s.address, services={("bokeh", ":0"): BokehWorker}) + yield w._start("tcp://0.0.0.0") + sock = first(w.services["bokeh"].server._http._sockets.values()) + assert sock.getsockname()[0] in ("::", "0.0.0.0") yield w._close() - w = Worker(s.address, services={('bokeh', ':0'): BokehWorker}) - yield w._start('tcp://127.0.0.1') - sock = first(w.services['bokeh'].server._http._sockets.values()) - assert sock.getsockname()[0] in ('::', '0.0.0.0') + w = Worker(s.address, services={("bokeh", ":0"): BokehWorker}) + yield w._start("tcp://127.0.0.1") + sock = first(w.services["bokeh"].server._http._sockets.values()) + assert sock.getsockname()[0] in ("::", "0.0.0.0") yield w._close() - w = Worker(s.address, services={('bokeh', 0): BokehWorker}) - yield w._start('tcp://127.0.0.1') - sock = first(w.services['bokeh'].server._http._sockets.values()) - assert sock.getsockname()[0] == '127.0.0.1' + w = Worker(s.address, services={("bokeh", 0): BokehWorker}) + yield w._start("tcp://127.0.0.1") + sock = first(w.services["bokeh"].server._http._sockets.values()) + assert sock.getsockname()[0] == "127.0.0.1" yield w._close() @gen_cluster(ncores=[]) def test_start_services(s): - pytest.importorskip('bokeh') + pytest.importorskip("bokeh") from distributed.bokeh.worker import BokehWorker - services = {('bokeh', ':1234'): BokehWorker} + + services = {("bokeh", ":1234"): BokehWorker} w = Worker(s.address, services=services) yield w._start() - assert w.services['bokeh'].server.port == 1234 + assert w.services["bokeh"].server.port == 1234 yield w._close() @@ -968,7 +1012,7 @@ def test_scheduler_delay(c, s, a, b): old = a.scheduler_delay assert abs(a.scheduler_delay) < 0.3 assert abs(b.scheduler_delay) < 0.3 - yield gen.sleep(a.periodic_callbacks['heartbeat'].callback_time / 1000 + .3) + yield gen.sleep(a.periodic_callbacks["heartbeat"].callback_time / 1000 + 0.3) assert a.scheduler_delay != old @@ -977,26 +1021,30 @@ def test_statistical_profiling(c, s, a, b): futures = 
c.map(slowinc, range(10), delay=0.1) yield wait(futures) - profile = a.profile_keys['slowinc'] - assert profile['count'] + profile = a.profile_keys["slowinc"] + assert profile["count"] @nodebug @gen_cluster(client=True) def test_statistical_profiling_2(c, s, a, b): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") for i in range(5): x = da.random.random(1000000, chunks=(10000,)) y = (x + x * 2) - x.sum().persist() yield wait(y) profile = a.get_profile() - assert profile['count'] - assert 'sum' in str(profile) or 'random' in str(profile) + assert profile["count"] + assert "sum" in str(profile) or "random" in str(profile) -@gen_cluster(ncores=[('127.0.0.1', 1)], client=True, worker_kwargs={'memory_monitor_interval': 10}) +@gen_cluster( + ncores=[("127.0.0.1", 1)], + client=True, + worker_kwargs={"memory_monitor_interval": 10}, +) def test_robust_to_bad_sizeof_estimates(c, s, a): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") memory = psutil.Process().memory_info().rss a.memory_limit = memory / 0.7 + 400e6 @@ -1008,7 +1056,7 @@ def __sizeof__(self): return 10 def f(n): - x = np.ones(int(n), dtype='u1') + x = np.ones(int(n), dtype="u1") result = BadAccounting(x) return result @@ -1021,42 +1069,48 @@ def f(n): @pytest.mark.slow -@gen_cluster(ncores=[('127.0.0.1', 2)], - client=True, - worker_kwargs={'memory_monitor_interval': 10, - 'memory_spill_fraction': False, # don't spill - 'memory_target_fraction': False, - 'memory_pause_fraction': 0.5}, - timeout=20) +@gen_cluster( + ncores=[("127.0.0.1", 2)], + client=True, + worker_kwargs={ + "memory_monitor_interval": 10, + "memory_spill_fraction": False, # don't spill + "memory_target_fraction": False, + "memory_pause_fraction": 0.5, + }, + timeout=20, +) def test_pause_executor(c, s, a): memory = psutil.Process().memory_info().rss a.memory_limit = memory / 0.5 + 200e6 - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") def f(): - x = np.ones(int(400e6), dtype='u1') + x = np.ones(int(400e6), dtype="u1") sleep(1) - with captured_logger(logging.getLogger('distributed.worker')) as logger: + with captured_logger(logging.getLogger("distributed.worker")) as logger: future = c.submit(f) futures = c.map(slowinc, range(30), delay=0.1) start = time() while not a.paused: yield gen.sleep(0.01) - assert time() < start + 4, (format_bytes(psutil.Process().memory_info().rss), - format_bytes(a.memory_limit), - len(a.data)) + assert time() < start + 4, ( + format_bytes(psutil.Process().memory_info().rss), + format_bytes(a.memory_limit), + len(a.data), + ) out = logger.getvalue() - assert 'memory' in out.lower() - assert 'pausing' in out.lower() + assert "memory" in out.lower() + assert "pausing" in out.lower() - assert sum(f.status == 'finished' for f in futures) < 4 + assert sum(f.status == "finished" for f in futures) < 4 yield wait(futures) -@gen_cluster(client=True, worker_kwargs={'profile_cycle_interval': '50 ms'}) +@gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": "50 ms"}) def test_statistical_profiling_cycle(c, s, a, b): futures = c.map(slowinc, range(20), delay=0.05) yield wait(futures) @@ -1065,30 +1119,29 @@ def test_statistical_profiling_cycle(c, s, a, b): assert len(a.profile_history) > 3 x = a.get_profile(start=time() + 10, stop=time() + 20) - assert not x['count'] + assert not x["count"] x = a.get_profile(start=0, stop=time()) - actual = sum(p['count'] for _, p in a.profile_history) + a.profile_recent['count'] + actual = sum(p["count"] for _, p in 
a.profile_history) + a.profile_recent["count"] x2 = a.get_profile(start=0, stop=time()) - assert x['count'] <= actual <= x2['count'] + assert x["count"] <= actual <= x2["count"] y = a.get_profile(start=end - 0.300, stop=time()) - assert 0 < y['count'] <= x['count'] + assert 0 < y["count"] <= x["count"] @gen_cluster(client=True) def test_get_current_task(c, s, a, b): - def some_name(): return get_worker().get_current_task() result = yield c.submit(some_name) - assert result.startswith('some_name') + assert result.startswith("some_name") -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_reschedule(c, s, a, b): - s.extensions['stealing']._pc.stop() + s.extensions["stealing"]._pc.stop() a_address = a.address def f(x): @@ -1105,21 +1158,23 @@ def f(x): def test_deque_handler(): from distributed.worker import logger - w = Worker('127.0.0.1', 8019) + + w = Worker("127.0.0.1", 8019) deque_handler = w._deque_handler - logger.info('foo456') + logger.info("foo456") assert deque_handler.deque msg = deque_handler.deque[-1] - assert 'distributed.worker' in deque_handler.format(msg) - assert any(msg.msg == 'foo456' for msg in deque_handler.deque) + assert "distributed.worker" in deque_handler.format(msg) + assert any(msg.msg == "foo456" for msg in deque_handler.deque) @gen_cluster(ncores=[], client=True) def test_avoid_memory_monitor_if_zero_limit(c, s): - worker = yield Worker(s.address, loop=s.loop, memory_limit=0, - memory_monitor_interval=10) + worker = yield Worker( + s.address, loop=s.loop, memory_limit=0, memory_monitor_interval=10 + ) assert type(worker.data) is dict - assert 'memory' not in worker.periodic_callbacks + assert "memory" not in worker.periodic_callbacks future = c.submit(inc, 1) assert (yield future) == 2 @@ -1130,9 +1185,13 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): yield worker._close() -@gen_cluster(ncores=[('127.0.0.1', 1)], - config={'distributed.worker.memory.spill': False, - 'distributed.worker.memory.target': False}) +@gen_cluster( + ncores=[("127.0.0.1", 1)], + config={ + "distributed.worker.memory.spill": False, + "distributed.worker.memory.target": False, + }, +) def test_dict_data_if_no_spill_to_disk(s, w): assert type(w.data) is dict @@ -1147,20 +1206,19 @@ def func(dask_scheduler): return list(dask_scheduler.clients) start = time() - while not any('worker' in n for n in client.run_on_scheduler(func)): + while not any("worker" in n for n in client.run_on_scheduler(func)): sleep(0.1) assert time() < start + 10 -@gen_cluster(ncores=[('127.0.0.1', 1)], - worker_kwargs={'memory_limit': '2e3 MB'}) +@gen_cluster(ncores=[("127.0.0.1", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) def test_parse_memory_limit(s, w): assert w.memory_limit == 2e9 @gen_cluster(ncores=[], client=True) def test_scheduler_address_config(c, s): - with dask.config.set({'scheduler-address': s.address}): + with dask.config.set({"scheduler-address": s.address}): worker = yield Worker(loop=s.loop) assert worker.scheduler.address == s.address yield worker._close() @@ -1169,7 +1227,7 @@ def test_scheduler_address_config(c, s): @slow @gen_cluster(client=True) def test_wait_for_outgoing(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = np.random.random(10000000) future = yield c.scatter(x, workers=a.address) @@ -1177,36 +1235,39 @@ def test_wait_for_outgoing(c, s, a, b): yield wait(y) assert len(b.incoming_transfer_log) == len(a.outgoing_transfer_log) == 1 - bb = 
b.incoming_transfer_log[0]['duration'] - aa = a.outgoing_transfer_log[0]['duration'] + bb = b.incoming_transfer_log[0]["duration"] + aa = a.outgoing_transfer_log[0]["duration"] ratio = aa / bb assert 1 / 3 < ratio < 3 -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason="Need 127.0.0.2 to mean localhost") -@gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 1), ('127.0.0.2', 1)], - client=True) +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster(ncores=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], client=True) def test_prefer_gather_from_local_address(c, s, w1, w2, w3): x = yield c.scatter(123, workers=[w1.address, w3.address], broadcast=True) y = c.submit(inc, x, workers=[w2.address]) yield wait(y) - assert any(d['who'] == w2.address for d in w1.outgoing_transfer_log) - assert not any(d['who'] == w2.address for d in w3.outgoing_transfer_log) + assert any(d["who"] == w2.address for d in w1.outgoing_transfer_log) + assert not any(d["who"] == w2.address for d in w3.outgoing_transfer_log) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 20, timeout=30, - config={'distributed.worker.connections.incoming': 1}) +@gen_cluster( + client=True, + ncores=[("127.0.0.1", 1)] * 20, + timeout=30, + config={"distributed.worker.connections.incoming": 1}, +) def test_avoid_oversubscription(c, s, *workers): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") x = c.submit(np.random.random, 1000000, workers=[workers[0].address]) yield wait(x) - futures = [c.submit(len, x, pure=False, workers=[w.address]) - for w in workers[1:]] + futures = [c.submit(len, x, pure=False, workers=[w.address]) for w in workers[1:]] yield wait(futures) @@ -1217,31 +1278,33 @@ def test_avoid_oversubscription(c, s, *workers): assert len([w for w in workers if len(w.outgoing_transfer_log) > 0]) >= 3 -@gen_cluster(client=True, worker_kwargs={'metrics': {'my_port': lambda w: w.port}}) +@gen_cluster(client=True, worker_kwargs={"metrics": {"my_port": lambda w: w.port}}) def test_custom_metrics(c, s, a, b): - assert s.workers[a.address].metrics['my_port'] == a.port - assert s.workers[b.address].metrics['my_port'] == b.port + assert s.workers[a.address].metrics["my_port"] == a.port + assert s.workers[b.address].metrics["my_port"] == b.port @gen_cluster(client=True) def test_register_worker_callbacks(c, s, a, b): - #preload function to run + # preload function to run def mystartup(dask_worker): dask_worker.init_variable = 1 def mystartup2(): import os - os.environ['MY_ENV_VALUE'] = 'WORKER_ENV_VALUE' + + os.environ["MY_ENV_VALUE"] = "WORKER_ENV_VALUE" return "Env set." 
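# Editor's sketch (not part of the patch): the surrounding test exercises
# Client.register_worker_callbacks and the ``metrics`` worker keyword.  A
# minimal, hedged example of the same APIs from user code; the scheduler
# address below is a placeholder, not taken from this diff.
from distributed import Client

def init_worker(dask_worker):                   # setup callback, mirrors the test above
    dask_worker.init_variable = 1

client = Client("tcp://127.0.0.1:8786")         # placeholder address
client.register_worker_callbacks(init_worker)   # runs on current and future workers
# Custom metrics are passed at worker creation, e.g.
#   Worker(scheduler_address, metrics={"my_port": lambda w: w.port})
# and then appear under scheduler.workers[addr].metrics, as asserted above.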
- #Check that preload function has been run + # Check that preload function has been run def test_import(dask_worker): - return hasattr(dask_worker, 'init_variable') + return hasattr(dask_worker, "init_variable") # and dask_worker.init_variable == 1 def test_startup2(): import os - return os.getenv('MY_ENV_VALUE', None) == 'WORKER_ENV_VALUE' + + return os.getenv("MY_ENV_VALUE", None) == "WORKER_ENV_VALUE" # Nothing has been run yet assert len(s.worker_setups) == 0 @@ -1309,7 +1372,7 @@ def __init__(self, x, y): self.x = x self.y = y - w = yield Worker(s.address, data=(Data, {'x': 123, 'y': 456})) + w = yield Worker(s.address, data=(Data, {"x": 123, "y": 456})) assert w.data.x == 123 assert w.data.y == 456 yield w._close() diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index b5d96edb238..b0dd338153f 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -10,8 +10,14 @@ import pytest from tornado import gen -from distributed import (worker_client, Client, as_completed, get_worker, wait, - get_client) +from distributed import ( + worker_client, + Client, + as_completed, + get_worker, + wait, + get_client, +) from distributed.metrics import time from distributed.utils_test import double, gen_cluster, inc from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -33,11 +39,10 @@ def func(x): assert yy == 20 + 1 + (20 + 1) * 2 assert len(s.transition_log) > 10 - assert len([id for id in s.wants_what - if id.lower().startswith('client')]) == 1 + assert len([id for id in s.wants_what if id.lower().startswith("client")]) == 1 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_scatter_from_worker(c, s, a, b): def func(): with worker_client() as c: @@ -64,8 +69,8 @@ def func(): correct &= type(futures) == type(data) o = object() - futures = c.scatter({'x': o}) - correct &= get_worker().data['x'] is o + futures = c.scatter({"x": o}) + correct &= get_worker().data["x"] is o return correct future = c.submit(func) @@ -78,9 +83,9 @@ def func(): assert time() < start + 5 -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_scatter_singleton(c, s, a, b): - np = pytest.importorskip('numpy') + np = pytest.importorskip("numpy") def func(): with worker_client() as c: @@ -91,7 +96,7 @@ def func(): yield c.submit(func) -@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) def test_gather_multi_machine(c, s, a, b): a_address = a.address b_address = b.address @@ -157,7 +162,7 @@ def mysum(): assert time() < start + 3 -@gen_cluster(client=True, ncores=[('127.0.0.1', 3)]) +@gen_cluster(client=True, ncores=[("127.0.0.1", 3)]) def test_separate_thread_false(c, s, a): a.count = 0 @@ -230,13 +235,13 @@ def func(x): yield wait(c.map(func, range(10))) yield a._close() - assert c.status == 'running' + assert c.status == "running" def test_timeout(client): def func(): with worker_client(timeout=0) as wc: - print('hello') + print("hello") future = client.submit(func) with pytest.raises(EnvironmentError): @@ -254,12 +259,13 @@ def test_secede_without_stealing_issue_1262(): # run the loop as an inner function so all workers are closed # and exceptions can be examined - @gen_cluster(client=True, scheduler_kwargs={'extensions': extensions}) + @gen_cluster(client=True, scheduler_kwargs={"extensions": 
extensions}) def secede_test(c, s, a, b): def func(x): with worker_client() as wc: y = wc.submit(lambda: 1 + x) return wc.gather(y) + f = yield c.gather(c.submit(func, 1)) raise gen.Return((c, s, a, b, f)) @@ -273,7 +279,6 @@ def func(x): @gen_cluster(client=True) def test_compute_within_worker_client(c, s, a, b): - @dask.delayed def f(): with worker_client(): @@ -298,9 +303,10 @@ def f(): @gen_cluster() def test_submit_different_names(s, a, b): # https://github.com/dask/distributed/issues/2058 - da = pytest.importorskip('dask.array') - c = yield Client('localhost:' + s.address.split(":")[-1], loop=s.loop, - asynchronous=True) + da = pytest.importorskip("dask.array") + c = yield Client( + "localhost:" + s.address.split(":")[-1], loop=s.loop, asynchronous=True + ) try: X = c.persist(da.random.uniform(size=(100, 10), chunks=50)) yield wait(X) diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index 84c08f447da..d2d4e3b7921 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -61,7 +61,7 @@ def _worker(executor, work_queue): return del executor except BaseException: - logger.critical('Exception in worker', exc_info=True) + logger.critical("Exception in worker", exc_info=True) finally: del thread_state.proceed del thread_state.executor @@ -75,13 +75,18 @@ def __init__(self, *args, **kwargs): super(ThreadPoolExecutor, self).__init__(*args, **kwargs) self._rejoin_list = [] self._rejoin_lock = threading.Lock() - self._thread_name_prefix = kwargs.get('thread_name_prefix', 'DaskThreadPoolExecutor') + self._thread_name_prefix = kwargs.get( + "thread_name_prefix", "DaskThreadPoolExecutor" + ) def _adjust_thread_count(self): if len(self._threads) < self._max_workers: - t = threading.Thread(target=_worker, - name=self._thread_name_prefix + "-%d-%d" % (os.getpid(), next(self._counter)), - args=(self, self._work_queue)) + t = threading.Thread( + target=_worker, + name=self._thread_name_prefix + + "-%d-%d" % (os.getpid(), next(self._counter)), + args=(self, self._work_queue), + ) t.daemon = True self._threads.add(t) t.start() diff --git a/distributed/utils.py b/distributed/utils.py index dbc27251758..5259e567358 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -40,6 +40,7 @@ import tornado from tornado import gen from tornado.ioloop import IOLoop + try: from tornado.ioloop import PollIOLoop except ImportError: @@ -57,17 +58,17 @@ logger = _logger = logging.getLogger(__name__) -no_default = '__no_default__' +no_default = "__no_default__" def _initialize_mp_context(): - if PY3 and not sys.platform.startswith('win') and 'PyPy' not in sys.version: - method = dask.config.get('distributed.worker.multiprocessing-method') + if PY3 and not sys.platform.startswith("win") and "PyPy" not in sys.version: + method = dask.config.get("distributed.worker.multiprocessing-method") ctx = multiprocessing.get_context(method) # Makes the test suite much faster - preload = ['distributed'] - if 'pkg_resources' in sys.modules: - preload.append('pkg_resources') + preload = ["distributed"] + if "pkg_resources" in sys.modules: + preload.append("pkg_resources") ctx.set_forkserver_preload(preload) else: ctx = multiprocessing @@ -80,7 +81,7 @@ def _initialize_mp_context(): def funcname(func): """Get the name of a function.""" - while hasattr(func, 'func'): + while hasattr(func, "func"): func = func.func try: return func.__name__ @@ -129,29 +130,31 @@ def _get_ip(host, port, family, default): return ip except EnvironmentError as e: # XXX Should first try 
getaddrinfo() on socket.gethostname() and getfqdn() - warnings.warn("Couldn't detect a suitable IP address for " - "reaching %r, defaulting to %r: %s" - % (host, default, e), RuntimeWarning) + warnings.warn( + "Couldn't detect a suitable IP address for " + "reaching %r, defaulting to %r: %s" % (host, default, e), + RuntimeWarning, + ) return default finally: sock.close() -def get_ip(host='8.8.8.8', port=80): +def get_ip(host="8.8.8.8", port=80): """ Get the local IP address through which the *host* is reachable. *host* defaults to a well-known Internet host (one of Google's public DNS servers). """ - return _get_ip(host, port, family=socket.AF_INET, default='127.0.0.1') + return _get_ip(host, port, family=socket.AF_INET, default="127.0.0.1") -def get_ipv6(host='2001:4860:4860::8888', port=80): +def get_ipv6(host="2001:4860:4860::8888", port=80): """ The same as get_ip(), but for IPv6. """ - return _get_ip(host, port, family=socket.AF_INET6, default='::1') + return _get_ip(host, port, family=socket.AF_INET6, default="::1") def get_ip_interface(ifname): @@ -163,6 +166,7 @@ def get_ip_interface(ifname): associated with it. """ import psutil + for info in psutil.net_if_addrs()[ifname]: if info.family == socket.AF_INET: return info.address @@ -213,6 +217,7 @@ def All(args, quiet_exceptions=()): try: result = yield tasks.next() except Exception: + @gen.coroutine def quiet(): """ Watch unfinished tasks @@ -226,6 +231,7 @@ def quiet(): yield task except quiet_exceptions: pass + quiet() raise @@ -251,6 +257,7 @@ def Any(args, quiet_exceptions=()): try: result = yield tasks.next() except Exception: + @gen.coroutine def quiet(): """ Watch unfinished tasks @@ -264,6 +271,7 @@ def quiet(): yield task except quiet_exceptions: pass + quiet() raise @@ -277,8 +285,10 @@ def sync(loop, func, *args, **kwargs): Run coroutine in loop running in separate thread. """ # Tornado's PollIOLoop doesn't raise when using closed, do it ourselves - if PollIOLoop and ((isinstance(loop, PollIOLoop) and getattr(loop, '_closing', False)) or - (hasattr(loop, 'asyncio_loop') and loop.asyncio_loop._closed)): + if PollIOLoop and ( + (isinstance(loop, PollIOLoop) and getattr(loop, "_closing", False)) + or (hasattr(loop, "asyncio_loop") and loop.asyncio_loop._closed) + ): raise RuntimeError("IOLoop is closed") try: if loop.asyncio_loop.is_closed(): # tornado 6 @@ -286,7 +296,7 @@ def sync(loop, func, *args, **kwargs): except AttributeError: pass - timeout = kwargs.pop('callback_timeout', None) + timeout = kwargs.pop("callback_timeout", None) e = threading.Event() main_tid = get_thread_identity() @@ -339,6 +349,7 @@ class LoopRunner(object): If true, the loop is meant to run in the thread this object is instantiated from, and will not be started automatically. 
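# Editor's sketch (not part of the patch): sync() and LoopRunner, touched in
# this hunk, are typically combined to run a Tornado coroutine on an IOLoop
# living in a background thread.  A hedged, self-contained example:
from tornado import gen
from distributed.utils import LoopRunner, sync

@gen.coroutine
def add(x, y):
    raise gen.Return(x + y)

runner = LoopRunner()                        # owns a fresh IOLoop in a thread
runner.start()
assert sync(runner.loop, add, 1, 2) == 3     # block until the coroutine finishes
runner.stop()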
""" + # All loops currently associated to loop runners _all_loops = weakref.WeakKeyDictionary() _lock = threading.Lock() @@ -376,7 +387,7 @@ def _start_unlocked(self): assert not self._started count, real_runner = self._all_loops[self._loop] - if (self._asynchronous or real_runner is not None or count > 0): + if self._asynchronous or real_runner is not None or count > 0: self._all_loops[self._loop] = count + 1, real_runner self._started = True return @@ -414,7 +425,9 @@ def run_loop(loop=self._loop): # Loop already running in other thread (user-launched) done_evt.wait(5) if not isinstance(start_exc[0], RuntimeError): - if not isinstance(start_exc[0], Exception): # track down infrequent error + if not isinstance( + start_exc[0], Exception + ): # track down infrequent error raise TypeError("not an exception", start_exc[0]) raise start_exc[0] self._all_loops[self._loop] = count + 1, None @@ -507,7 +520,7 @@ def set_thread_state(**kwargs): @contextmanager def tmp_text(filename, text): fn = os.path.join(tempfile.gettempdir(), filename) - with open(fn, 'w') as f: + with open(fn, "w") as f: f.write(text) try: @@ -529,14 +542,15 @@ def is_kernel(): False """ # http://stackoverflow.com/questions/34091701/determine-if-were-in-an-ipython-notebook-session - if 'IPython' not in sys.modules: # IPython hasn't been imported + if "IPython" not in sys.modules: # IPython hasn't been imported return False from IPython import get_ipython + # check for `kernel` attribute on the IPython instance - return getattr(get_ipython(), 'kernel', None) is not None + return getattr(get_ipython(), "kernel", None) is not None -hex_pattern = re.compile('[a-f]+') +hex_pattern = re.compile("[a-f]+") def key_split(s): @@ -571,25 +585,26 @@ def key_split(s): if type(s) is tuple: s = s[0] try: - words = s.split('-') + words = s.split("-") if not words[0][0].isalpha(): result = words[0].split(",")[0].strip("'(\"") else: result = words[0] for word in words[1:]: - if word.isalpha() and not (len(word) == 8 and - hex_pattern.match(word) is not None): - result += '-' + word + if word.isalpha() and not ( + len(word) == 8 and hex_pattern.match(word) is not None + ): + result += "-" + word else: break - if len(result) == 32 and re.match(r'[a-f0-9]{32}', result): - return 'data' + if len(result) == 32 and re.match(r"[a-f0-9]{32}", result): + return "data" else: - if result[0] == '<': - result = result.strip('<>').split()[0].split('.')[-1] + if result[0] == "<": + result = result.strip("<>").split()[0].split(".")[-1] return result except Exception: - return 'Other' + return "Other" try: @@ -601,6 +616,7 @@ def key_split(s): key_split = lru_cache(100000)(key_split) if PY3: + def key_split_group(x): """A more fine-grained version of key_split @@ -631,19 +647,22 @@ def key_split_group(x): if typ is tuple: return x[0] elif typ is str: - if x[0] == '(': - return x.split(',', 1)[0].strip('()"\'') - elif len(x) == 32 and re.match(r'[a-f0-9]{32}', x): - return 'data' - elif x[0] == '<': - return x.strip('<>').split()[0].split('.')[-1] + if x[0] == "(": + return x.split(",", 1)[0].strip("()\"'") + elif len(x) == 32 and re.match(r"[a-f0-9]{32}", x): + return "data" + elif x[0] == "<": + return x.strip("<>").split()[0].split(".")[-1] else: return x elif typ is bytes: return key_split_group(x.decode()) else: - return 'Other' + return "Other" + + else: + def key_split_group(x): """A more fine-grained version of key_split @@ -674,21 +693,22 @@ def key_split_group(x): if typ is tuple: return x[0] elif typ is str or typ is unicode: - if x[0] == '(': - 
return x.split(',', 1)[0].strip('()"\'') - elif len(x) == 32 and re.match(r'[a-f0-9]{32}', x): - return 'data' - elif x[0] == '<': - return x.strip('<>').split()[0].split('.')[-1] + if x[0] == "(": + return x.split(",", 1)[0].strip("()\"'") + elif len(x) == 32 and re.match(r"[a-f0-9]{32}", x): + return "data" + elif x[0] == "<": + return x.strip("<>").split()[0].split(".")[-1] else: return x else: - return 'Other' + return "Other" @contextmanager def log_errors(pdb=False): from .comm import CommClosedError + try: yield except (CommClosedError, gen.Return): @@ -700,11 +720,12 @@ def log_errors(pdb=False): pass if pdb: import pdb + pdb.set_trace() raise -def silence_logging(level, root='distributed'): +def silence_logging(level, root="distributed"): """ Force all existing loggers below *root* to the given level at least (or keep the existing level if less verbose). @@ -737,9 +758,9 @@ def ensure_ip(hostname): families = [socket.AF_INET, socket.AF_INET6] for fam in families: try: - results = socket.getaddrinfo(hostname, - 1234, # dummy port number - fam, socket.SOCK_STREAM) + results = socket.getaddrinfo( + hostname, 1234, fam, socket.SOCK_STREAM # dummy port number + ) except socket.gaierror as e: exc = e else: @@ -753,12 +774,15 @@ def ensure_ip(hostname): def get_traceback(): exc_type, exc_value, exc_traceback = sys.exc_info() - bad = [os.path.join('distributed', 'worker'), - os.path.join('distributed', 'scheduler'), - os.path.join('tornado', 'gen.py'), - os.path.join('concurrent', 'futures')] - while exc_traceback and any(b in exc_traceback.tb_frame.f_code.co_filename - for b in bad): + bad = [ + os.path.join("distributed", "worker"), + os.path.join("distributed", "scheduler"), + os.path.join("tornado", "gen.py"), + os.path.join("concurrent", "futures"), + ] + while exc_traceback and any( + b in exc_traceback.tb_frame.f_code.co_filename for b in bad + ): exc_traceback = exc_traceback.tb_next return exc_traceback @@ -767,25 +791,24 @@ def truncate_exception(e, n=10000): """ Truncate exception to be about a certain length """ if len(str(e)) > n: try: - return type(e)("Long error message", - str(e)[:n]) + return type(e)("Long error message", str(e)[:n]) except Exception: - return Exception("Long error message", - type(e), - str(e)[:n]) + return Exception("Long error message", type(e), str(e)[:n]) else: return e if sys.version_info >= (3,): # (re-)raising StopIteration is deprecated in 3.6+ - exec("""def queue_to_iterator(q): + exec( + """def queue_to_iterator(q): while True: result = q.get() if isinstance(result, StopIteration): return result.value yield result - """) + """ + ) else: # Returning non-None from generator is a syntax error in 2.x def queue_to_iterator(q): @@ -836,15 +859,18 @@ def validate_key(k): """ typ = type(k) if typ is not unicode and typ is not bytes: - raise TypeError("Unexpected key type %s (value: %r)" - % (typ, k)) + raise TypeError("Unexpected key type %s (value: %r)" % (typ, k)) def _maybe_complex(task): """ Possibly contains a nested task """ - return (istask(task) or - type(task) is list and any(map(_maybe_complex, task)) or - type(task) is dict and any(map(_maybe_complex, task.values()))) + return ( + istask(task) + or type(task) is list + and any(map(_maybe_complex, task)) + or type(task) is dict + and any(map(_maybe_complex, task.values())) + ) def convert(task, dsk, extra_values): @@ -884,7 +910,7 @@ def seek_delimiter(file, delimiter, blocksize): if file.tell() == 0: return - last = b'' + last = b"" while True: current = file.read(blocksize) if not current: 
@@ -896,7 +922,7 @@ def seek_delimiter(file, delimiter, blocksize): return except ValueError: pass - last = full[-len(delimiter):] + last = full[-len(delimiter) :] def read_block(f, offset, length, delimiter=None): @@ -935,12 +961,12 @@ def read_block(f, offset, length, delimiter=None): """ if delimiter: f.seek(offset) - seek_delimiter(f, delimiter, 2**16) + seek_delimiter(f, delimiter, 2 ** 16) start = f.tell() length -= start - offset f.seek(start + length) - seek_delimiter(f, delimiter, 2**16) + seek_delimiter(f, delimiter, 2 ** 16) end = f.tell() offset = start @@ -952,8 +978,8 @@ def read_block(f, offset, length, delimiter=None): @contextmanager -def tmpfile(extension=''): - extension = '.' + extension.lstrip('.') +def tmpfile(extension=""): + extension = "." + extension.lstrip(".") handle, filename = tempfile.mkstemp(extension) os.close(handle) os.remove(filename) @@ -984,10 +1010,9 @@ def ensure_bytes(s): return s.tobytes() if isinstance(s, bytearray) or PY2 and isinstance(s, buffer): # noqa: F821 return bytes(s) - if hasattr(s, 'encode'): + if hasattr(s, "encode"): return s.encode() - raise TypeError( - "Object %s is neither a bytes object nor has an encode method" % s) + raise TypeError("Object %s is neither a bytes object nor has an encode method" % s) def divide_n_among_bins(n, bins): @@ -1020,9 +1045,11 @@ def mean(seq): if hasattr(sys, "is_finalizing"): + def shutting_down(is_finalizing=sys.is_finalizing): return is_finalizing() + else: _shutting_down = [False] @@ -1043,7 +1070,7 @@ def shutting_down(l=_shutting_down): """ -def open_port(host=''): +def open_port(host=""): """ Return a probably-open port There is a chance that this port will be taken by the operating system soon @@ -1065,23 +1092,24 @@ def import_file(path): names_to_import = [] tmp_python_path = None - if ext in ('.py',): # , '.pyc'): + if ext in (".py",): # , '.pyc'): if directory not in sys.path: tmp_python_path = directory names_to_import.append(name) - if ext == '.py': # Ensure that no pyc file will be reused + if ext == ".py": # Ensure that no pyc file will be reused cache_file = cache_from_source(path) with ignoring(OSError): os.remove(cache_file) - if ext in ('.egg', '.zip', '.pyz'): + if ext in (".egg", ".zip", ".pyz"): if path not in sys.path: sys.path.insert(0, path) - if ext == '.egg': + if ext == ".egg": import pkg_resources + pkgs = pkg_resources.find_distributions(path) for pkg in pkgs: names_to_import.append(pkg.project_name) - elif ext in ('.zip', '.pyz'): + elif ext in (".zip", ".pyz"): names_to_import.append(name) loaded = [] @@ -1111,7 +1139,8 @@ class itemgetter(object): >>> get_1(data) 1 """ - __slots__ = ('index',) + + __slots__ = ("index",) def __init__(self, index): self.index = index @@ -1140,35 +1169,35 @@ def format_bytes(n): '1.23 PB' """ if n > 1e15: - return '%0.2f PB' % (n / 1e15) + return "%0.2f PB" % (n / 1e15) if n > 1e12: - return '%0.2f TB' % (n / 1e12) + return "%0.2f TB" % (n / 1e12) if n > 1e9: - return '%0.2f GB' % (n / 1e9) + return "%0.2f GB" % (n / 1e9) if n > 1e6: - return '%0.2f MB' % (n / 1e6) + return "%0.2f MB" % (n / 1e6) if n > 1e3: - return '%0.2f kB' % (n / 1000) - return '%d B' % n + return "%0.2f kB" % (n / 1000) + return "%d B" % n byte_sizes = { - 'kB': 10**3, - 'MB': 10**6, - 'GB': 10**9, - 'TB': 10**12, - 'PB': 10**15, - 'KiB': 2**10, - 'MiB': 2**20, - 'GiB': 2**30, - 'TiB': 2**40, - 'PiB': 2**50, - 'B': 1, - '': 1, + "kB": 10 ** 3, + "MB": 10 ** 6, + "GB": 10 ** 9, + "TB": 10 ** 12, + "PB": 10 ** 15, + "KiB": 2 ** 10, + "MiB": 2 ** 20, + "GiB": 2 ** 
30, + "TiB": 2 ** 40, + "PiB": 2 ** 50, + "B": 1, + "": 1, } byte_sizes = {k.lower(): v for k, v in byte_sizes.items()} -byte_sizes.update({k[0]: v for k, v in byte_sizes.items() if k and 'i' not in k}) -byte_sizes.update({k[:-1]: v for k, v in byte_sizes.items() if k and 'i' in k}) +byte_sizes.update({k[0]: v for k, v in byte_sizes.items() if k and "i" not in k}) +byte_sizes.update({k[:-1]: v for k, v in byte_sizes.items() if k and "i" in k}) def parse_bytes(s): @@ -1193,9 +1222,9 @@ def parse_bytes(s): >>> parse_bytes('MB') 1000000 """ - s = s.replace(' ', '') + s = s.replace(" ", "") if not s[0].isdigit(): - s = '1' + s + s = "1" + s for i in range(len(s) - 1, -1, -1): if not s[i].isalpha(): @@ -1214,30 +1243,30 @@ def parse_bytes(s): timedelta_sizes = { - 's': 1, - 'ms': 1e-3, - 'us': 1e-6, - 'ns': 1e-9, - 'm': 60, - 'h': 3600, - 'd': 3600 * 24, + "s": 1, + "ms": 1e-3, + "us": 1e-6, + "ns": 1e-9, + "m": 60, + "h": 3600, + "d": 3600 * 24, } tds2 = { - 'second': 1, - 'minute': 60, - 'hour': 60 * 60, - 'day': 60 * 60 * 24, - 'millisecond': 1e-3, - 'microsecond': 1e-6, - 'nanosecond': 1e-9, + "second": 1, + "minute": 60, + "hour": 60 * 60, + "day": 60 * 60 * 24, + "millisecond": 1e-3, + "microsecond": 1e-6, + "nanosecond": 1e-9, } -tds2.update({k + 's': v for k, v in tds2.items()}) +tds2.update({k + "s": v for k, v in tds2.items()}) timedelta_sizes.update(tds2) timedelta_sizes.update({k.upper(): v for k, v in timedelta_sizes.items()}) -def parse_timedelta(s, default='seconds'): +def parse_timedelta(s, default="seconds"): """ Parse timedelta string to number of seconds Examples @@ -1255,9 +1284,9 @@ def parse_timedelta(s, default='seconds'): return s.total_seconds() if isinstance(s, Number): s = str(s) - s = s.replace(' ', '') + s = s.replace(" ", "") if not s[0].isdigit(): - s = '1' + s + s = "1" + s for i in range(len(s) - 1, -1, -1): if not s[i].isalpha(): @@ -1290,16 +1319,16 @@ def asciitable(columns, rows): """ rows = [tuple(str(i) for i in r) for r in rows] columns = tuple(str(i) for i in columns) - widths = tuple(max(max(map(len, x)), len(c)) - for x, c in zip(zip(*rows), columns)) - row_template = ('|' + (' %%-%ds |' * len(columns))) % widths + widths = tuple(max(max(map(len, x)), len(c)) for x, c in zip(zip(*rows), columns)) + row_template = ("|" + (" %%-%ds |" * len(columns))) % widths header = row_template % tuple(columns) - bar = '+%s+' % '+'.join('-' * (w + 2) for w in widths) - data = '\n'.join(row_template % r for r in rows) - return '\n'.join([bar, header, bar, data, bar]) + bar = "+%s+" % "+".join("-" * (w + 2) for w in widths) + data = "\n".join(row_template % r for r in rows) + return "\n".join([bar, header, bar, data, bar]) if PY2: + def nbytes(frame, _bytes_like=(bytes, bytearray, buffer)): # noqa: F821 """ Number of bytes of a frame or memoryview """ if isinstance(frame, _bytes_like): @@ -1308,11 +1337,13 @@ def nbytes(frame, _bytes_like=(bytes, bytearray, buffer)): # noqa: F821 if frame.shape is None: return frame.itemsize else: - return functools.reduce(operator.mul, frame.shape, - frame.itemsize) + return functools.reduce(operator.mul, frame.shape, frame.itemsize) else: return frame.nbytes + + else: + def nbytes(frame, _bytes_like=(bytes, bytearray)): """ Number of bytes of a frame or memoryview """ if isinstance(frame, _bytes_like): @@ -1341,7 +1372,7 @@ def time_warn(duration, text): yield end = time() if end - start > duration: - print('TIME WARNING', text, end - start) + print("TIME WARNING", text, end - start) def json_load_robust(fn, load=json.load): @@ 
-1372,18 +1403,19 @@ def format_time(n): '123.46 s' """ if n >= 1: - return '%.2f s' % n + return "%.2f s" % n if n >= 1e-3: - return '%.2f ms' % (n * 1e3) - return '%.2f us' % (n * 1e6) + return "%.2f ms" % (n * 1e3) + return "%.2f us" % (n * 1e6) class DequeHandler(logging.Handler): """ A logging.Handler that records records into a deque """ + _instances = weakref.WeakSet() def __init__(self, *args, **kwargs): - n = kwargs.pop('n', 10000) + n = kwargs.pop("n", 10000) self.deque = deque(maxlen=n) super(DequeHandler, self).__init__(*args, **kwargs) self._instances.add(self) @@ -1417,22 +1449,25 @@ def reset_logger_locks(): # Only bother if asyncio has been loaded by Tornado -if 'asyncio' in sys.modules and tornado.version_info[0] >= 5: +if "asyncio" in sys.modules and tornado.version_info[0] >= 5: jupyter_event_loop_initialized = False - if 'notebook' in sys.modules: + if "notebook" in sys.modules: import traitlets from notebook.notebookapp import NotebookApp - jupyter_event_loop_initialized = ( - traitlets.config.Application.initialized() and - isinstance(traitlets.config.Application.instance(), NotebookApp) + + jupyter_event_loop_initialized = traitlets.config.Application.initialized() and isinstance( + traitlets.config.Application.instance(), NotebookApp ) if not jupyter_event_loop_initialized: import asyncio import tornado.platform.asyncio - asyncio.set_event_loop_policy(tornado.platform.asyncio.AnyThreadEventLoopPolicy()) + + asyncio.set_event_loop_policy( + tornado.platform.asyncio.AnyThreadEventLoopPolicy() + ) def has_keyword(func, keyword): @@ -1451,9 +1486,26 @@ def has_keyword(func, keyword): # from bokeh.palettes import viridis # palette = viridis(18) -palette = ['#440154', '#471669', '#472A79', '#433C84', '#3C4D8A', '#355D8C', - '#2E6C8E', '#287A8E', '#23898D', '#1E978A', '#20A585', '#2EB27C', - '#45BF6F', '#64CB5D', '#88D547', '#AFDC2E', '#D7E219', '#FDE724'] +palette = [ + "#440154", + "#471669", + "#472A79", + "#433C84", + "#3C4D8A", + "#355D8C", + "#2E6C8E", + "#287A8E", + "#23898D", + "#1E978A", + "#20A585", + "#2EB27C", + "#45BF6F", + "#64CB5D", + "#88D547", + "#AFDC2E", + "#D7E219", + "#FDE724", +] @toolz.memoize diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 63622044291..d2bd19908af 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -32,6 +32,7 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): _gather """ from .worker import get_data_from_worker + bad_addresses = set() missing_workers = set() original_who_has = who_has @@ -57,11 +58,17 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): rpcs = {addr: rpc(addr) for addr in d} try: - coroutines = {address: get_data_from_worker(rpc, keys, address, - who=who, - serializers=serializers, - max_connections=False) - for address, keys in d.items()} + coroutines = { + address: get_data_from_worker( + rpc, + keys, + address, + who=who, + serializers=serializers, + max_connections=False, + ) + for address, keys in d.items() + } response = {} for worker, c in coroutines.items(): try: @@ -69,7 +76,7 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): except EnvironmentError: missing_workers.add(worker) else: - response.update(r['data']) + response.update(r["data"]) finally: for r in rpcs.values(): r.close_rpc() @@ -91,6 +98,7 @@ class WrappedKey(object): only be accessed in a certain way. Schedulers may have particular needs that can only be addressed by additional metadata. 
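# Editor's sketch (not part of the patch): WrappedKey above is the minimal
# interface the scheduler expects for keys carrying extra metadata.  A hedged
# example of a hypothetical subclass; AnnotatedKey is illustrative only.
from distributed.utils_comm import WrappedKey

class AnnotatedKey(WrappedKey):
    def __init__(self, key, priority):
        super(AnnotatedKey, self).__init__(key)   # keeps the required .key attribute
        self.priority = priority                  # extra metadata for illustration

k = AnnotatedKey("x-1", priority=5)
assert k.key == "x-1" and k.priority == 5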
""" + def __init__(self, key): self.key = key @@ -122,19 +130,23 @@ def scatter_to_workers(ncores, data, rpc=rpc, report=True, serializers=None): L = list(zip(worker_iter, names, data)) d = groupby(0, L) - d = {worker: {key: value for _, key, value in v} - for worker, v in d.items()} + d = {worker: {key: value for _, key, value in v} for worker, v in d.items()} rpcs = {addr: rpc(addr) for addr in d} try: - out = yield All([rpcs[address].update_data(data=v, report=report, - serializers=serializers) - for address, v in d.items()]) + out = yield All( + [ + rpcs[address].update_data( + data=v, report=report, serializers=serializers + ) + for address, v in d.items() + ] + ) finally: for r in rpcs.values(): r.close_rpc() - nbytes = merge(o['nbytes'] for o in out) + nbytes = merge(o["nbytes"] for o in out) who_has = {k: [w for w, _, _ in v] for k, v in groupby(1, L).items()} @@ -184,15 +196,23 @@ def unpack_remotedata(o, byte_keys=False, myset=None): if type(o[0]) is SubgraphCallable: sc = o[0] futures = set() - dsk = {k: unpack_remotedata(v, byte_keys, futures) - for k, v in sc.dsk.items()} + dsk = { + k: unpack_remotedata(v, byte_keys, futures) for k, v in sc.dsk.items() + } args = tuple(unpack_remotedata(i, byte_keys, futures) for i in o[1:]) if futures: myset.update(futures) - futures = (tuple(tokey(f.key) for f in futures) - if byte_keys else tuple(f.key for f in futures)) + futures = ( + tuple(tokey(f.key) for f in futures) + if byte_keys + else tuple(f.key for f in futures) + ) inkeys = sc.inkeys + futures - return (SubgraphCallable(dsk, sc.outkey, inkeys, sc.name),) + args + futures + return ( + (SubgraphCallable(dsk, sc.outkey, inkeys, sc.name),) + + args + + futures + ) else: return o else: diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index c6c1d37b107..9f300c5f567 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -26,6 +26,7 @@ class ThrottledGC(object): to log a warning level message whenever an actual call to gc.collect() lasts too long. """ + def __init__(self, max_in_gc_frac=0.05, warn_if_longer=1, logger=None): self.max_in_gc_frac = max_in_gc_frac self.warn_if_longer = warn_if_longer @@ -41,25 +42,30 @@ def collect(self): collect_start = thread_time() elapsed = max(collect_start - self.last_collect, MIN_RUNTIME) if self.last_gc_duration / elapsed < self.max_in_gc_frac: - self.logger.debug("Calling gc.collect(). %0.3fs elapsed since " - "previous call.", elapsed) + self.logger.debug( + "Calling gc.collect(). %0.3fs elapsed since " "previous call.", elapsed + ) gc.collect() self.last_collect = collect_start self.last_gc_duration = max(thread_time() - collect_start, MIN_RUNTIME) if self.last_gc_duration > self.warn_if_longer: - self.logger.warning("gc.collect() took %0.3fs. This is usually" - " a sign that the some tasks handle too" - " many Python objects at the same time." - " Rechunking the work into smaller tasks" - " might help.", - self.last_gc_duration) + self.logger.warning( + "gc.collect() took %0.3fs. This is usually" + " a sign that the some tasks handle too" + " many Python objects at the same time." 
+ " Rechunking the work into smaller tasks" + " might help.", + self.last_gc_duration, + ) else: - self.logger.debug("gc.collect() took %0.3fs", - self.last_gc_duration) + self.logger.debug("gc.collect() took %0.3fs", self.last_gc_duration) else: - self.logger.debug("gc.collect() lasts %0.3fs but only %0.3fs " - "elapsed since last call: throttling.", - self.last_gc_duration, elapsed) + self.logger.debug( + "gc.collect() lasts %0.3fs but only %0.3fs " + "elapsed since last call: throttling.", + self.last_gc_duration, + elapsed, + ) class FractionalTimer(object): @@ -178,33 +184,42 @@ def __exit__(self, *args): def _gc_callback(self, phase, info): # Young generations are small and collected very often, # don't waste time measuring them - if info['generation'] != 2: + if info["generation"] != 2: return if self._proc is not None: rss = self._proc.memory_info().rss else: rss = 0 - if phase == 'start': + if phase == "start": self._fractional_timer.start_timing() self._gc_rss_before = rss return - assert phase == 'stop' + assert phase == "stop" self._fractional_timer.stop_timing() frac = self._fractional_timer.running_fraction if frac is not None and frac >= self._warn_over_frac: - logger.warning("full garbage collections took %d%% CPU time " - "recently (threshold: %d%%)", - 100 * frac, 100 * self._warn_over_frac) + logger.warning( + "full garbage collections took %d%% CPU time " + "recently (threshold: %d%%)", + 100 * frac, + 100 * self._warn_over_frac, + ) rss_saved = self._gc_rss_before - rss if rss_saved >= self._info_over_rss_win: - logger.info("full garbage collection released %s " - "from %d reference cycles (threshold: %s)", - format_bytes(rss_saved), info['collected'], - format_bytes(self._info_over_rss_win)) - if info['uncollectable'] > 0: + logger.info( + "full garbage collection released %s " + "from %d reference cycles (threshold: %s)", + format_bytes(rss_saved), + info["collected"], + format_bytes(self._info_over_rss_win), + ) + if info["uncollectable"] > 0: # This should ideally never happen on Python 3, but who knows? 
- logger.warning("garbage collector couldn't collect %d objects, " - "please look in gc.garbage", info['uncollectable']) + logger.warning( + "garbage collector couldn't collect %d objects, " + "please look in gc.garbage", + info["uncollectable"], + ) _gc_diagnosis = GCDiagnosis() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 11f4df047e9..a3f76e4c477 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -48,9 +48,18 @@ from .process import _cleanup_dangling from .proctitle import enable_proctitle_on_children from .security import Security -from .utils import (ignoring, log_errors, mp_context, get_ip, get_ipv6, - DequeHandler, reset_logger_locks, sync, - iscoroutinefunction, thread_state) +from .utils import ( + ignoring, + log_errors, + mp_context, + get_ip, + get_ipv6, + DequeHandler, + reset_logger_locks, + sync, + iscoroutinefunction, + thread_state, +) from .worker import Worker, TOTAL_MEMORY, _global_workers try: @@ -62,33 +71,38 @@ logger = logging.getLogger(__name__) -logging_levels = {name: logger.level for name, logger in - logging.root.manager.loggerDict.items() - if isinstance(logger, logging.Logger)} +logging_levels = { + name: logger.level + for name, logger in logging.root.manager.loggerDict.items() + if isinstance(logger, logging.Logger) +} offload(lambda: None).result() # create thread during import -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def valid_python_script(tmpdir_factory): - local_file = tmpdir_factory.mktemp('data').join('file.py') + local_file = tmpdir_factory.mktemp("data").join("file.py") local_file.write("print('hello world!')") return local_file -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def client_contract_script(tmpdir_factory): - local_file = tmpdir_factory.mktemp('data').join('distributed_script.py') - lines = ("from distributed import Client", "e = Client('127.0.0.1:8989')", - 'print(e)') - local_file.write('\n'.join(lines)) + local_file = tmpdir_factory.mktemp("data").join("distributed_script.py") + lines = ( + "from distributed import Client", + "e = Client('127.0.0.1:8989')", + "print(e)", + ) + local_file.write("\n".join(lines)) return local_file -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def invalid_python_script(tmpdir_factory): - local_file = tmpdir_factory.mktemp('data').join('file.py') + local_file = tmpdir_factory.mktemp("data").join("file.py") local_file.write("a+1") return local_file @@ -116,6 +130,7 @@ def start(): orig_start() finally: is_stopped.set() + loop.start = start yield loop @@ -151,8 +166,7 @@ def start(): @pytest.fixture def loop_in_thread(): with pristine_loop() as loop: - thread = threading.Thread(target=loop.start, - name="test IOLoop") + thread = threading.Thread(target=loop.start, name="test IOLoop") thread.daemon = True thread.start() loop_started = threading.Event() @@ -166,6 +180,7 @@ def loop_in_thread(): @pytest.fixture def zmq_ctx(): import zmq + ctx = zmq.Context.instance() yield ctx ctx.destroy(linger=0) @@ -193,6 +208,7 @@ def pristine_loop(): def mock_ipython(): import mock from distributed._ipython_utils import remote_magic + ip = mock.Mock() ip.user_ns = {} ip.kernel = None @@ -200,8 +216,9 @@ def mock_ipython(): def get_ip(): return ip - with mock.patch('IPython.get_ipython', get_ip), \ - mock.patch('distributed._ipython_utils.get_ipython', get_ip): + with mock.patch("IPython.get_ipython", get_ip), mock.patch( + "distributed._ipython_utils.get_ipython", get_ip + ): yield ip # cleanup 
remote_magic client cache for kc in remote_magic._clients.values(): @@ -285,7 +302,7 @@ def deep(n): def throws(x): - raise RuntimeError('hello!') + raise RuntimeError("hello!") def double(x): @@ -309,6 +326,7 @@ def slowdouble(x, delay=0.02): def randominc(x, scale=1): from random import random + sleep(random() * scale) return x + 1 @@ -324,7 +342,7 @@ def slowsum(seq, delay=0.02): def slowidentity(*args, **kwargs): - delay = kwargs.get('delay', 0.02) + delay = kwargs.get("delay", 0.02) sleep(delay) if len(args) == 1: return args[0] @@ -364,7 +382,7 @@ def varying(items): # used by *func* below, so we can't use `global `. # Instead look up the module by name to get the original namespace # and not a copy. - slot = _ModuleSlot(__name__, '_varying_dict') + slot = _ModuleSlot(__name__, "_varying_dict") key = next(_varying_key_gen) def func(): @@ -388,6 +406,7 @@ def map_varying(itemslists): Like *varying*, but return the full specification for a map() call on multiple items lists. """ + def apply(func, *args, **kwargs): return func(*args, **kwargs) @@ -403,17 +422,19 @@ def geninc(x, delay=0.02): def compile_snippet(code, dedent=True): if dedent: code = textwrap.dedent(code) - code = compile(code, '', 'exec') + code = compile(code, "", "exec") ns = globals() exec(code, ns, ns) if sys.version_info >= (3, 5): - compile_snippet(""" + compile_snippet( + """ async def asyncinc(x, delay=0.02): await gen.sleep(delay) return x + 1 - """) + """ + ) assert asyncinc # noqa: F821 else: asyncinc = None @@ -461,7 +482,7 @@ def run_scheduler(q, nputs, **kwargs): # so avoid inheriting the parent's IO loop. with pristine_loop() as loop: scheduler = Scheduler(validate=True, **kwargs) - done = scheduler.start('127.0.0.1') + done = scheduler.start("127.0.0.1") for i in range(nputs): q.put(scheduler.address) @@ -482,6 +503,7 @@ def run_worker(q, scheduler_q, **kwargs): loop.run_sync(lambda: worker._start(0)) q.put(worker.address) try: + @gen.coroutine def wait_until_closed(): yield worker._closed.wait() @@ -523,13 +545,17 @@ def check_active_rpc(loop, active_rpc_timeout=1): # (*) (example: gather_from_workers()) def fail(): - pytest.fail("some RPCs left active by test: %s" - % (set(rpc.active) - active_before)) + pytest.fail( + "some RPCs left active by test: %s" % (set(rpc.active) - active_before) + ) @gen.coroutine def wait(): - yield async_wait_for(lambda: len(set(rpc.active) - active_before) == 0, - timeout=active_rpc_timeout, fail_func=fail) + yield async_wait_for( + lambda: len(set(rpc.active) - active_before) == 0, + timeout=active_rpc_timeout, + fail_func=fail, + ) loop.run_sync(wait) @@ -561,27 +587,28 @@ def b(cluster_fixture): @pytest.fixture def client(loop, cluster_fixture): scheduler, workers = cluster_fixture - with Client(scheduler['address'], loop=loop) as client: + with Client(scheduler["address"], loop=loop) as client: yield client @pytest.fixture def client_secondary(loop, cluster_fixture): scheduler, workers = cluster_fixture - with Client(scheduler['address'], loop=loop) as client: + with Client(scheduler["address"], loop=loop) as client: yield client @contextmanager -def tls_cluster_context(worker_kwargs=None, scheduler_kwargs=None, - security=None, **kwargs): +def tls_cluster_context( + worker_kwargs=None, scheduler_kwargs=None, security=None, **kwargs +): security = security or tls_only_security() - worker_kwargs = assoc(worker_kwargs or {}, 'security', security) - scheduler_kwargs = assoc(scheduler_kwargs or {}, 'security', security) + worker_kwargs = assoc(worker_kwargs or {}, 
"security", security) + scheduler_kwargs = assoc(scheduler_kwargs or {}, "security", security) - with cluster(worker_kwargs=worker_kwargs, - scheduler_kwargs=scheduler_kwargs, - **kwargs) as (s, workers): + with cluster( + worker_kwargs=worker_kwargs, scheduler_kwargs=scheduler_kwargs, **kwargs + ) as (s, workers): yield s, workers @@ -594,7 +621,7 @@ def tls_cluster(loop, security): @pytest.fixture def tls_client(tls_cluster, loop, security): s, workers = tls_cluster - with Client(s['address'], security=security, loop=loop) as client: + with Client(s["address"], security=security, loop=loop) as client: yield client @@ -604,8 +631,9 @@ def security(): @contextmanager -def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, - scheduler_kwargs={}): +def cluster( + nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, scheduler_kwargs={} +): ws = weakref.WeakSet() reset_config() @@ -626,9 +654,11 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, scheduler_q = mp_context.Queue() # Launch scheduler - scheduler = mp_context.Process(target=run_scheduler, - args=(scheduler_q, nworkers + 1), - kwargs=scheduler_kwargs) + scheduler = mp_context.Process( + target=run_scheduler, + args=(scheduler_q, nworkers + 1), + kwargs=scheduler_kwargs, + ) ws.add(scheduler) scheduler.daemon = True scheduler.start() @@ -637,20 +667,22 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, workers = [] for i in range(nworkers): q = mp_context.Queue() - fn = '_test_worker-%s' % uuid.uuid4() - kwargs = merge({'ncores': 1, 'local_dir': fn, - 'memory_limit': TOTAL_MEMORY}, worker_kwargs) - proc = mp_context.Process(target=_run_worker, - args=(q, scheduler_q), - kwargs=kwargs) + fn = "_test_worker-%s" % uuid.uuid4() + kwargs = merge( + {"ncores": 1, "local_dir": fn, "memory_limit": TOTAL_MEMORY}, + worker_kwargs, + ) + proc = mp_context.Process( + target=_run_worker, args=(q, scheduler_q), kwargs=kwargs + ) ws.add(proc) - workers.append({'proc': proc, 'queue': q, 'dir': fn}) + workers.append({"proc": proc, "queue": q, "dir": fn}) for worker in workers: - worker['proc'].start() + worker["proc"].start() try: for worker in workers: - worker['address'] = worker['queue'].get(timeout=5) + worker["address"] = worker["queue"].get(timeout=5) except Empty: raise pytest.xfail.Exception("Worker failed to start in test") @@ -659,8 +691,10 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, start = time() try: try: - security = scheduler_kwargs['security'] - rpc_kwargs = {'connection_args': security.get_connection_args('client')} + security = scheduler_kwargs["security"] + rpc_kwargs = { + "connection_args": security.get_connection_args("client") + } except KeyError: rpc_kwargs = {} @@ -673,16 +707,23 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, raise Exception("Timeout on cluster creation") # avoid sending processes down to function - yield {'address': saddr}, [{'address': w['address'], - 'proc': weakref.ref(w['proc'])} - for w in workers] + yield {"address": saddr}, [ + {"address": w["address"], "proc": weakref.ref(w["proc"])} + for w in workers + ] finally: logger.debug("Closing out test cluster") - loop.run_sync(lambda: disconnect_all([w['address'] for w in workers], - timeout=0.5, - rpc_kwargs=rpc_kwargs)) - loop.run_sync(lambda: disconnect(saddr, timeout=0.5, rpc_kwargs=rpc_kwargs)) + loop.run_sync( + lambda: disconnect_all( + [w["address"] for w in workers], + timeout=0.5, + 
rpc_kwargs=rpc_kwargs, + ) + ) + loop.run_sync( + lambda: disconnect(saddr, timeout=0.5, rpc_kwargs=rpc_kwargs) + ) scheduler.terminate() scheduler_q.close() @@ -690,21 +731,21 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, scheduler_q._writer.close() for w in workers: - w['proc'].terminate() - w['queue'].close() - w['queue']._reader.close() - w['queue']._writer.close() + w["proc"].terminate() + w["queue"].close() + w["queue"]._reader.close() + w["queue"]._writer.close() scheduler.join(2) del scheduler - for proc in [w['proc'] for w in workers]: + for proc in [w["proc"] for w in workers]: proc.join(timeout=2) with ignoring(UnboundLocalError): del worker, w, proc del workers[:] - for fn in glob('_test_worker-*'): + for fn in glob("_test_worker-*"): with ignoring(OSError): shutil.rmtree(fn) @@ -718,7 +759,7 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, start = time() while list(ws): sleep(0.01) - assert time() < start + 1, 'Workers still around after one second' + assert time() < start + 1, "Workers still around after one second" @gen.coroutine @@ -758,6 +799,7 @@ def gen_test(timeout=10): def test_foo(): yield ... # use tornado coroutines """ + def _(func): def test_func(): with pristine_loop() as loop: @@ -769,7 +811,9 @@ def test_func(): loop.run_sync(cor, timeout=timeout) finally: loop.stop() + return test_func + return _ @@ -778,25 +822,38 @@ def test_func(): @gen.coroutine -def start_cluster(ncores, scheduler_addr, loop, security=None, - Worker=Worker, scheduler_kwargs={}, worker_kwargs={}): - s = Scheduler(loop=loop, validate=True, security=security, - **scheduler_kwargs) +def start_cluster( + ncores, + scheduler_addr, + loop, + security=None, + Worker=Worker, + scheduler_kwargs={}, + worker_kwargs={}, +): + s = Scheduler(loop=loop, validate=True, security=security, **scheduler_kwargs) done = s.start(scheduler_addr) - workers = [Worker(s.address, ncores=ncore[1], name=i, security=security, - loop=loop, validate=True, - **(merge(worker_kwargs, ncore[2]) - if len(ncore) > 2 - else worker_kwargs)) - for i, ncore in enumerate(ncores)] + workers = [ + Worker( + s.address, + ncores=ncore[1], + name=i, + security=security, + loop=loop, + validate=True, + **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs) + ) + for i, ncore in enumerate(ncores) + ] for w in workers: w.rpc = workers[0].rpc yield [w._start(ncore[0]) for ncore, w in zip(ncores, workers)] start = time() - while (len(s.workers) < len(ncores) or - any(comm.comm is None for comm in s.stream_comms.values())): + while len(s.workers) < len(ncores) or any( + comm.comm is None for comm in s.stream_comms.values() + ): yield gen.sleep(0.01) if time() - start > 5: yield [w._close(timeout=1) for w in workers] @@ -819,12 +876,22 @@ def end_worker(w): s.stop() -def gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 2)], - scheduler='127.0.0.1', timeout=10, security=None, - Worker=Worker, client=False, scheduler_kwargs={}, - worker_kwargs={}, client_kwargs={}, active_rpc_timeout=1, - config={}, check_new_threads=True): +def gen_cluster( + ncores=[("127.0.0.1", 1), ("127.0.0.1", 2)], + scheduler="127.0.0.1", + timeout=10, + security=None, + Worker=Worker, + client=False, + scheduler_kwargs={}, + worker_kwargs={}, + client_kwargs={}, + active_rpc_timeout=1, + config={}, + check_new_threads=True, +): from distributed import Client + """ Coroutine test with small cluster @gen_cluster() @@ -835,8 +902,9 @@ def test_foo(scheduler, worker1, worker2): start end """ 
- worker_kwargs = merge({'memory_limit': TOTAL_MEMORY, 'death_timeout': 5}, - worker_kwargs) + worker_kwargs = merge( + {"memory_limit": TOTAL_MEMORY, "death_timeout": 5}, worker_kwargs + ) def _(func): if not iscoroutinefunction(func): @@ -849,7 +917,7 @@ def test_func(): reset_config() - dask.config.set({'distributed.comm.timeouts.connect': '5s'}) + dask.config.set({"distributed.comm.timeouts.connect": "5s"}) # Restore default logging levels # XXX use pytest hooks/fixtures instead? for name, level in logging_levels.items(): @@ -860,6 +928,7 @@ def test_func(): with pristine_loop() as loop: with check_active_rpc(loop, active_rpc_timeout): + @gen.coroutine def coro(): with dask.config.set(config): @@ -867,11 +936,19 @@ def coro(): for i in range(5): try: s, ws = yield start_cluster( - ncores, scheduler, loop, security=security, - Worker=Worker, scheduler_kwargs=scheduler_kwargs, - worker_kwargs=worker_kwargs) + ncores, + scheduler, + loop, + security=security, + Worker=Worker, + scheduler_kwargs=scheduler_kwargs, + worker_kwargs=worker_kwargs, + ) except Exception as e: - logger.error("Failed to start gen_cluster, retrying", exc_info=True) + logger.error( + "Failed to start gen_cluster, retrying", + exc_info=True, + ) else: workers[:] = ws args = [s] + workers @@ -879,23 +956,30 @@ def coro(): if s is False: raise Exception("Could not start cluster") if client: - c = yield Client(s.address, loop=loop, security=security, - asynchronous=True, **client_kwargs) + c = yield Client( + s.address, + loop=loop, + security=security, + asynchronous=True, + **client_kwargs + ) args = [c] + args try: future = func(*args) if timeout: - future = gen.with_timeout(timedelta(seconds=timeout), - future) + future = gen.with_timeout( + timedelta(seconds=timeout), future + ) result = yield future if s.validate: s.validate_state() finally: if client: - yield c._close(fast=s.status == 'closed') + yield c._close(fast=s.status == "closed") yield end_cluster(s, workers) - yield gen.with_timeout(timedelta(seconds=1), - cleanup_global_workers()) + yield gen.with_timeout( + timedelta(seconds=1), cleanup_global_workers() + ) try: c = yield default_client() @@ -906,10 +990,12 @@ def coro(): raise gen.Return(result) - result = loop.run_sync(coro, timeout=timeout * 2 if timeout else timeout) + result = loop.run_sync( + coro, timeout=timeout * 2 if timeout else timeout + ) for w in workers: - if getattr(w, 'data', None): + if getattr(w, "data", None): try: w.data.clear() except EnvironmentError: @@ -921,24 +1007,28 @@ def coro(): for w in _global_workers: w = w() w._close(report=False, executor_wait=False) - if w.status == 'running': + if w.status == "running": w.close() del _global_workers[:] if PY3 and not WINDOWS and check_new_threads: start = time() while True: - bad = [t for t, v in threading._active.items() - if t not in active_threads_start and - "Threaded" not in v.name and - "watch message" not in v.name and - "TCP-Executor" not in v.name] + bad = [ + t + for t, v in threading._active.items() + if t not in active_threads_start + and "Threaded" not in v.name + and "watch message" not in v.name + and "TCP-Executor" not in v.name + ] if not bad: break else: sleep(0.01) if time() > start + 5: from distributed import profile + tid = bad[0] thread = threading._active[tid] call_stacks = profile.call_stack(sys._current_frames()[tid]) @@ -949,6 +1039,7 @@ def coro(): return result return test_func + return _ @@ -962,7 +1053,7 @@ def raises(func, exc=Exception): def terminate_process(proc): if proc.poll() is None: - if 
sys.platform.startswith('win'): + if sys.platform.startswith("win"): proc.send_signal(signal.CTRL_BREAK_EVENT) else: proc.send_signal(signal.SIGINT) @@ -981,18 +1072,20 @@ def terminate_process(proc): @contextmanager def popen(args, **kwargs): - kwargs['stdout'] = subprocess.PIPE - kwargs['stderr'] = subprocess.PIPE - if sys.platform.startswith('win'): + kwargs["stdout"] = subprocess.PIPE + kwargs["stderr"] = subprocess.PIPE + if sys.platform.startswith("win"): # Allow using CTRL_C_EVENT / CTRL_BREAK_EVENT - kwargs['creationflags'] = subprocess.CREATE_NEW_PROCESS_GROUP + kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP dump_stdout = False args = list(args) - if sys.platform.startswith('win'): - args[0] = os.path.join(sys.prefix, 'Scripts', args[0]) + if sys.platform.startswith("win"): + args[0] = os.path.join(sys.prefix, "Scripts", args[0]) else: - args[0] = os.path.join(os.environ.get('DESTDIR', '') + sys.prefix, 'bin', args[0]) + args[0] = os.path.join( + os.environ.get("DESTDIR", "") + sys.prefix, "bin", args[0] + ) proc = subprocess.Popen(args, **kwargs) try: yield proc @@ -1007,10 +1100,10 @@ def popen(args, **kwargs): # XXX Also dump stdout if return code != 0 ? out, err = proc.communicate() if dump_stdout: - print('\n\nPrint from stderr\n %s\n=================\n' % args[0][0]) + print("\n\nPrint from stderr\n %s\n=================\n" % args[0][0]) print(err.decode()) - print('\n\nPrint from stdout\n=================\n') + print("\n\nPrint from stdout\n=================\n") print(out.decode()) @@ -1061,7 +1154,7 @@ def has_ipv6(): serv = cli = None try: serv = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - serv.bind(('::', 0)) + serv.bind(("::", 0)) serv.listen(5) cli = socket.create_connection(serv.getsockname()[:2]) except EnvironmentError: @@ -1076,9 +1169,11 @@ def has_ipv6(): if has_ipv6(): + def requires_ipv6(test_func): return test_func + else: requires_ipv6 = pytest.mark.skip("ipv6 required") @@ -1091,13 +1186,14 @@ def assert_can_connect(addr, timeout=None, connection_args=None): """ if timeout is None: timeout = 0.5 - comm = yield connect(addr, timeout=timeout, - connection_args=connection_args) + comm = yield connect(addr, timeout=timeout, connection_args=connection_args) comm.abort() @gen.coroutine -def assert_cannot_connect(addr, timeout=None, connection_args=None, exception_class=EnvironmentError): +def assert_cannot_connect( + addr, timeout=None, connection_args=None, exception_class=EnvironmentError +): """ Check that it is impossible to connect to the distributed *addr* within the given *timeout*. @@ -1105,43 +1201,46 @@ def assert_cannot_connect(addr, timeout=None, connection_args=None, exception_cl if timeout is None: timeout = 0.5 with pytest.raises(exception_class): - comm = yield connect(addr, timeout=timeout, - connection_args=connection_args) + comm = yield connect(addr, timeout=timeout, connection_args=connection_args) comm.abort() @gen.coroutine -def assert_can_connect_from_everywhere_4_6(port, timeout=None, connection_args=None, protocol='tcp'): +def assert_can_connect_from_everywhere_4_6( + port, timeout=None, connection_args=None, protocol="tcp" +): """ Check that the local *port* is reachable from all IPv4 and IPv6 addresses. 
""" args = (timeout, connection_args) futures = [ - assert_can_connect('%s://127.0.0.1:%d' % (protocol, port), *args), - assert_can_connect('%s://%s:%d' % (protocol, get_ip(), port), *args), + assert_can_connect("%s://127.0.0.1:%d" % (protocol, port), *args), + assert_can_connect("%s://%s:%d" % (protocol, get_ip(), port), *args), ] if has_ipv6(): futures += [ - assert_can_connect('%s://[::1]:%d' % (protocol, port), *args), - assert_can_connect('%s://[%s]:%d' % (protocol, get_ipv6(), port), *args), + assert_can_connect("%s://[::1]:%d" % (protocol, port), *args), + assert_can_connect("%s://[%s]:%d" % (protocol, get_ipv6(), port), *args), ] yield futures @gen.coroutine -def assert_can_connect_from_everywhere_4(port, timeout=None, connection_args=None, protocol='tcp'): +def assert_can_connect_from_everywhere_4( + port, timeout=None, connection_args=None, protocol="tcp" +): """ Check that the local *port* is reachable from all IPv4 addresses. """ args = (timeout, connection_args) futures = [ - assert_can_connect('%s://127.0.0.1:%d' % (protocol, port), *args), - assert_can_connect('%s://%s:%d' % (protocol, get_ip(), port), *args), + assert_can_connect("%s://127.0.0.1:%d" % (protocol, port), *args), + assert_can_connect("%s://%s:%d" % (protocol, get_ip(), port), *args), ] if has_ipv6(): futures += [ - assert_cannot_connect('%s://[::1]:%d' % (protocol, port), *args), - assert_cannot_connect('%s://[%s]:%d' % (protocol, get_ipv6(), port), *args), + assert_cannot_connect("%s://[::1]:%d" % (protocol, port), *args), + assert_cannot_connect("%s://[%s]:%d" % (protocol, get_ipv6(), port), *args), ] yield futures @@ -1152,17 +1251,13 @@ def assert_can_connect_locally_4(port, timeout=None, connection_args=None): Check that the local *port* is only reachable from local IPv4 addresses. """ args = (timeout, connection_args) - futures = [ - assert_can_connect('tcp://127.0.0.1:%d' % port, *args), - ] - if get_ip() != '127.0.0.1': # No outside IPv4 connectivity? - futures += [ - assert_cannot_connect('tcp://%s:%d' % (get_ip(), port), *args), - ] + futures = [assert_can_connect("tcp://127.0.0.1:%d" % port, *args)] + if get_ip() != "127.0.0.1": # No outside IPv4 connectivity? 
+ futures += [assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), *args)] if has_ipv6(): futures += [ - assert_cannot_connect('tcp://[::1]:%d' % port, *args), - assert_cannot_connect('tcp://[%s]:%d' % (get_ipv6(), port), *args), + assert_cannot_connect("tcp://[::1]:%d" % port, *args), + assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args), ] yield futures @@ -1175,10 +1270,10 @@ def assert_can_connect_from_everywhere_6(port, timeout=None, connection_args=Non assert has_ipv6() args = (timeout, connection_args) futures = [ - assert_cannot_connect('tcp://127.0.0.1:%d' % port, *args), - assert_cannot_connect('tcp://%s:%d' % (get_ip(), port), *args), - assert_can_connect('tcp://[::1]:%d' % port, *args), - assert_can_connect('tcp://[%s]:%d' % (get_ipv6(), port), *args), + assert_cannot_connect("tcp://127.0.0.1:%d" % port, *args), + assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), *args), + assert_can_connect("tcp://[::1]:%d" % port, *args), + assert_can_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args), ] yield futures @@ -1191,14 +1286,12 @@ def assert_can_connect_locally_6(port, timeout=None, connection_args=None): assert has_ipv6() args = (timeout, connection_args) futures = [ - assert_cannot_connect('tcp://127.0.0.1:%d' % port, *args), - assert_cannot_connect('tcp://%s:%d' % (get_ip(), port), *args), - assert_can_connect('tcp://[::1]:%d' % port, *args), + assert_cannot_connect("tcp://127.0.0.1:%d" % port, *args), + assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), *args), + assert_can_connect("tcp://[::1]:%d" % port, *args), ] - if get_ipv6() != '::1': # No outside IPv6 connectivity? - futures += [ - assert_cannot_connect('tcp://[%s]:%d' % (get_ipv6(), port), *args), - ] + if get_ipv6() != "::1": # No outside IPv6 connectivity? + futures += [assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args)] yield futures @@ -1244,6 +1337,7 @@ def new_config(new_config): Temporarily change configuration dictionary. """ from .config import defaults + config = dask.config.config orig_config = config.copy() try: @@ -1275,25 +1369,25 @@ def new_config_file(c): Temporarily change configuration file to match dictionary *c*. """ import yaml - old_file = os.environ.get('DASK_CONFIG') - fd, path = tempfile.mkstemp(prefix='dask-config') + + old_file = os.environ.get("DASK_CONFIG") + fd, path = tempfile.mkstemp(prefix="dask-config") try: - with os.fdopen(fd, 'w') as f: + with os.fdopen(fd, "w") as f: f.write(yaml.dump(c)) - os.environ['DASK_CONFIG'] = path + os.environ["DASK_CONFIG"] = path try: yield finally: if old_file: - os.environ['DASK_CONFIG'] = old_file + os.environ["DASK_CONFIG"] = old_file else: - del os.environ['DASK_CONFIG'] + del os.environ["DASK_CONFIG"] finally: os.remove(path) -certs_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - 'tests')) +certs_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "tests")) def get_cert(filename): @@ -1309,22 +1403,16 @@ def tls_config(): """ A functional TLS configuration with our test certs. 
""" - ca_file = get_cert('tls-ca-cert.pem') - keycert = get_cert('tls-key-cert.pem') + ca_file = get_cert("tls-ca-cert.pem") + keycert = get_cert("tls-key-cert.pem") c = { - 'tls': { - 'ca-file': ca_file, - 'client': { - 'cert': keycert, - }, - 'scheduler': { - 'cert': keycert, - }, - 'worker': { - 'cert': keycert, - }, - }, + "tls": { + "ca-file": ca_file, + "client": {"cert": keycert}, + "scheduler": {"cert": keycert}, + "worker": {"cert": keycert}, + } } return c @@ -1335,7 +1423,7 @@ def tls_only_config(): plain TCP communications. """ c = tls_config() - c['require-encryption'] = True + c["require-encryption"] = True return c @@ -1359,20 +1447,20 @@ def tls_only_security(): return sec -def get_server_ssl_context(certfile='tls-cert.pem', keyfile='tls-key.pem', - ca_file='tls-ca-cert.pem'): - ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, - cafile=get_cert(ca_file)) +def get_server_ssl_context( + certfile="tls-cert.pem", keyfile="tls-key.pem", ca_file="tls-ca-cert.pem" +): + ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, cafile=get_cert(ca_file)) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_cert_chain(get_cert(certfile), get_cert(keyfile)) return ctx -def get_client_ssl_context(certfile='tls-cert.pem', keyfile='tls-key.pem', - ca_file='tls-ca-cert.pem'): - ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, - cafile=get_cert(ca_file)) +def get_client_ssl_context( + certfile="tls-cert.pem", keyfile="tls-key.pem", ca_file="tls-ca-cert.pem" +): + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=get_cert(ca_file)) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_cert_chain(get_cert(certfile), get_cert(keyfile)) @@ -1380,18 +1468,17 @@ def get_client_ssl_context(certfile='tls-cert.pem', keyfile='tls-key.pem', def bump_rlimit(limit, desired): - resource = pytest.importorskip('resource') + resource = pytest.importorskip("resource") try: soft, hard = resource.getrlimit(limit) if soft < desired: - resource.setrlimit(limit, - (desired, max(hard, desired))) + resource.setrlimit(limit, (desired, max(hard, desired))) except Exception as e: - pytest.skip("rlimit too low (%s) and can't be increased: %s" - % (soft, e)) + pytest.skip("rlimit too low (%s) and can't be increased: %s" % (soft, e)) def gen_tls_cluster(**kwargs): - kwargs.setdefault('ncores', [('tls://127.0.0.1', 1), ('tls://127.0.0.1', 2)]) - return gen_cluster(scheduler='tls://127.0.0.1', - security=tls_only_security(), **kwargs) + kwargs.setdefault("ncores", [("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)]) + return gen_cluster( + scheduler="tls://127.0.0.1", security=tls_only_security(), **kwargs + ) diff --git a/distributed/variable.py b/distributed/variable.py index 5d905358a9e..7b775d3327a 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -37,27 +37,28 @@ def __init__(self, scheduler): self.waiting_conditions = defaultdict(tornado.locks.Condition) self.started = tornado.locks.Condition() - self.scheduler.handlers.update({'variable_set': self.set, - 'variable_get': self.get}) + self.scheduler.handlers.update( + {"variable_set": self.set, "variable_get": self.get} + ) - self.scheduler.stream_handlers['variable-future-release'] = self.future_release - self.scheduler.stream_handlers['variable_delete'] = self.delete + self.scheduler.stream_handlers["variable-future-release"] = self.future_release + self.scheduler.stream_handlers["variable_delete"] = self.delete - self.scheduler.extensions['variables'] = self + 
self.scheduler.extensions["variables"] = self def set(self, stream=None, name=None, key=None, data=None, client=None): if key is not None: - record = {'type': 'Future', 'value': key} - self.scheduler.client_desires_keys(keys=[key], client='variable-%s' % name) + record = {"type": "Future", "value": key} + self.scheduler.client_desires_keys(keys=[key], client="variable-%s" % name) else: - record = {'type': 'msgpack', 'value': data} + record = {"type": "msgpack", "value": data} try: old = self.variables[name] except KeyError: pass else: - if old['type'] == 'Future' and old['value'] != key: - self.release(old['value'], name) + if old["type"] == "Future" and old["value"] != key: + self.release(old["value"], name) if name not in self.variables: self.started.notify_all() self.variables[name] = record @@ -67,8 +68,7 @@ def release(self, key, name): while self.waiting[key, name]: yield self.waiting_conditions[name].wait() - self.scheduler.client_releases_keys(keys=[key], - client='variable-%s' % name) + self.scheduler.client_releases_keys(keys=[key], client="variable-%s" % name) del self.waiting[key, name] def future_release(self, name=None, key=None, token=None, client=None): @@ -88,15 +88,15 @@ def get(self, stream=None, name=None, client=None, timeout=None): raise gen.TimeoutError() yield self.started.wait(timeout=left) record = self.variables[name] - if record['type'] == 'Future': - key = record['value'] + if record["type"] == "Future": + key = record["value"] token = uuid.uuid4().hex ts = self.scheduler.tasks.get(key) - state = ts.state if ts is not None else 'lost' - msg = {'token': token, 'state': state} - if state == 'erred': - msg['exception'] = ts.exception_blame.exception - msg['traceback'] = ts.exception_blame.traceback + state = ts.state if ts is not None else "lost" + msg = {"token": token, "state": state} + if state == "erred": + msg["exception"] = ts.exception_blame.exception + msg["traceback"] = ts.exception_blame.traceback record = merge(record, msg) self.waiting[key, name].add(token) raise gen.Return(record) @@ -109,8 +109,8 @@ def delete(self, stream=None, name=None, client=None): except KeyError: pass else: - if old['type'] == 'Future': - yield self.release(old['value'], name) + if old["type"] == "Future": + yield self.release(old["value"], name) del self.waiting_conditions[name] del self.variables[name] @@ -149,16 +149,16 @@ class Variable(object): def __init__(self, name=None, client=None, maxsize=0): self.client = client or _get_global_client() - self.name = name or 'variable-' + uuid.uuid4().hex + self.name = name or "variable-" + uuid.uuid4().hex @gen.coroutine def _set(self, value): if isinstance(value, Future): - yield self.client.scheduler.variable_set(key=tokey(value.key), - name=self.name) + yield self.client.scheduler.variable_set( + key=tokey(value.key), name=self.name + ) else: - yield self.client.scheduler.variable_set(data=value, - name=self.name) + yield self.client.scheduler.variable_set(data=value, name=self.name) def set(self, value, **kwargs): """ Set the value of this variable @@ -172,19 +172,23 @@ def set(self, value, **kwargs): @gen.coroutine def _get(self, timeout=None): - d = yield self.client.scheduler.variable_get(timeout=timeout, - name=self.name, - client=self.client.id) - if d['type'] == 'Future': - value = Future(d['value'], self.client, inform=True, state=d['state']) - if d['state'] == 'erred': - value._state.set_error(d['exception'], d['traceback']) - self.client._send_to_scheduler({'op': 'variable-future-release', - 'name': self.name, - 'key': 
d['value'], - 'token': d['token']}) + d = yield self.client.scheduler.variable_get( + timeout=timeout, name=self.name, client=self.client.id + ) + if d["type"] == "Future": + value = Future(d["value"], self.client, inform=True, state=d["state"]) + if d["state"] == "erred": + value._state.set_error(d["exception"], d["traceback"]) + self.client._send_to_scheduler( + { + "op": "variable-future-release", + "name": self.name, + "key": d["value"], + "token": d["token"], + } + ) else: - value = d['value'] + value = d["value"] raise gen.Return(value) def get(self, timeout=None, **kwargs): @@ -196,9 +200,8 @@ def delete(self): Caution, this affects all clients currently pointing to this variable. """ - if self.client.status == 'running': # TODO: can leave zombie futures - self.client._send_to_scheduler({'op': 'variable_delete', - 'name': self.name}) + if self.client.status == "running": # TODO: can leave zombie futures + self.client._send_to_scheduler({"op": "variable_delete", "name": self.name}) def __getstate__(self): return (self.name, self.client.scheduler.address) diff --git a/distributed/versions.py b/distributed/versions.py index fa7bbc0835a..2baa47a1d8f 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -12,19 +12,23 @@ from .utils import ignoring -required_packages = [('dask', lambda p: p.__version__), - ('distributed', lambda p: p.__version__), - ('msgpack', lambda p: '.'.join([str(v) for v in p.version])), - ('cloudpickle', lambda p: p.__version__), - ('tornado', lambda p: p.version), - ('toolz', lambda p: p.__version__)] - -optional_packages = [('numpy', lambda p: p.__version__), - ('pandas', lambda p: p.__version__), - ('bokeh', lambda p: p.__version__), - ('lz4', lambda p: p.__version__), - ('dask_ml', lambda p: p.__version__), - ('blosc', lambda p: p.__version__)] +required_packages = [ + ("dask", lambda p: p.__version__), + ("distributed", lambda p: p.__version__), + ("msgpack", lambda p: ".".join([str(v) for v in p.version])), + ("cloudpickle", lambda p: p.__version__), + ("tornado", lambda p: p.version), + ("toolz", lambda p: p.__version__), +] + +optional_packages = [ + ("numpy", lambda p: p.__version__), + ("pandas", lambda p: p.__version__), + ("bokeh", lambda p: p.__version__), + ("lz4", lambda p: p.__version__), + ("dask_ml", lambda p: p.__version__), + ("blosc", lambda p: p.__version__), +] def get_versions(packages=None): @@ -34,27 +38,30 @@ def get_versions(packages=None): if packages is None: packages = [] - d = {'host': get_system_info(), - 'packages': {'required': get_package_info(required_packages), - 'optional': get_package_info(optional_packages + list(packages))} - } + d = { + "host": get_system_info(), + "packages": { + "required": get_package_info(required_packages), + "optional": get_package_info(optional_packages + list(packages)), + }, + } return d def get_system_info(): - (sysname, nodename, release, - version, machine, processor) = platform.uname() - host = [("python", "%d.%d.%d.%s.%s" % sys.version_info[:]), - ("python-bits", struct.calcsize("P") * 8), - ("OS", "%s" % (sysname)), - ("OS-release", "%s" % (release)), - ("machine", "%s" % (machine)), - ("processor", "%s" % (processor)), - ("byteorder", "%s" % sys.byteorder), - ("LC_ALL", "%s" % os.environ.get('LC_ALL', "None")), - ("LANG", "%s" % os.environ.get('LANG', "None")), - ("LOCALE", "%s.%s" % locale.getlocale()), - ] + (sysname, nodename, release, version, machine, processor) = platform.uname() + host = [ + ("python", "%d.%d.%d.%s.%s" % sys.version_info[:]), + ("python-bits", 
struct.calcsize("P") * 8), + ("OS", "%s" % (sysname)), + ("OS-release", "%s" % (release)), + ("machine", "%s" % (machine)), + ("processor", "%s" % (processor)), + ("byteorder", "%s" % sys.byteorder), + ("LC_ALL", "%s" % os.environ.get("LC_ALL", "None")), + ("LANG", "%s" % os.environ.get("LANG", "None")), + ("LOCALE", "%s.%s" % locale.getlocale()), + ] return host @@ -66,7 +73,7 @@ def version_of_package(pkg): with ignoring(AttributeError): return str(pkg.version) with ignoring(AttributeError): - return '.'.join(map(str, pkg.version_info)) + return ".".join(map(str, pkg.version_info)) return None diff --git a/distributed/worker.py b/distributed/worker.py index c7720888f4f..b9ed6c5a59d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -16,6 +16,7 @@ import dask from dask.core import istask from dask.compatibility import apply + try: from cytoolz import pluck, partial, merge except ImportError: @@ -29,27 +30,39 @@ from .batched import BatchedSend from .comm import get_address_host, get_local_address_for, connect from .comm.utils import offload -from .compatibility import (unicode, get_thread_identity, finalize, - MutableMapping) -from .core import (error_message, CommClosedError, send_recv, - pingpong, coerce_to_address) +from .compatibility import unicode, get_thread_identity, finalize, MutableMapping +from .core import error_message, CommClosedError, send_recv, pingpong, coerce_to_address from .diskutils import WorkSpace from .metrics import time from .node import ServerNode from .preloading import preload_modules from .proctitle import setproctitle -from .protocol import (pickle, to_serialize, deserialize_bytes, - serialize_bytelist) +from .protocol import pickle, to_serialize, deserialize_bytes, serialize_bytelist from .pubsub import PubSubWorkerExtension from .security import Security from .sizeof import safe_sizeof as sizeof from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede -from .utils import (funcname, get_ip, has_arg, _maybe_complex, log_errors, - ignoring, mp_context, import_file, - silence_logging, thread_state, json_load_robust, key_split, - format_bytes, DequeHandler, PeriodicCallback, - parse_bytes, parse_timedelta, iscoroutinefunction, - warn_on_duration) +from .utils import ( + funcname, + get_ip, + has_arg, + _maybe_complex, + log_errors, + ignoring, + mp_context, + import_file, + silence_logging, + thread_state, + json_load_robust, + key_split, + format_bytes, + DequeHandler, + PeriodicCallback, + parse_bytes, + parse_timedelta, + iscoroutinefunction, + warn_on_duration, +) from .utils_comm import pack_data, gather_from_workers from .utils_perf import ThrottledGC, enable_gc_diagnosis, disable_gc_diagnosis @@ -57,12 +70,13 @@ logger = logging.getLogger(__name__) -LOG_PDB = dask.config.get('distributed.admin.pdb-on-err') +LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -no_value = '--no-value-sentinel--' +no_value = "--no-value-sentinel--" try: import psutil + TOTAL_MEMORY = psutil.virtual_memory().total except ImportError: logger.warning("Please install psutil to estimate worker memory use") @@ -70,15 +84,13 @@ psutil = None -IN_PLAY = ('waiting', 'ready', 'executing', 'long-running') -PENDING = ('waiting', 'ready', 'constrained') -PROCESSING = ('waiting', 'ready', 'constrained', 'executing', 'long-running') -READY = ('ready', 'constrained') +IN_PLAY = ("waiting", "ready", "executing", "long-running") +PENDING = ("waiting", "ready", "constrained") +PROCESSING = ("waiting", "ready", "constrained", "executing", "long-running") 
+READY = ("ready", "constrained") -DEFAULT_EXTENSIONS = [ - PubSubWorkerExtension, -] +DEFAULT_EXTENSIONS = [PubSubWorkerExtension] _global_workers = [] @@ -257,14 +269,33 @@ class Worker(ServerNode): distributed.nanny.Nanny """ - def __init__(self, scheduler_ip=None, scheduler_port=None, - scheduler_file=None, ncores=None, loop=None, local_dir='dask-worker-space', - services=None, service_ports=None, name=None, - reconnect=True, memory_limit='auto', - executor=None, resources=None, silence_logs=None, - death_timeout=None, preload=None, preload_argv=None, security=None, - contact_address=None, memory_monitor_interval='200ms', - extensions=None, metrics=None, data=None, **kwargs): + def __init__( + self, + scheduler_ip=None, + scheduler_port=None, + scheduler_file=None, + ncores=None, + loop=None, + local_dir="dask-worker-space", + services=None, + service_ports=None, + name=None, + reconnect=True, + memory_limit="auto", + executor=None, + resources=None, + silence_logs=None, + death_timeout=None, + preload=None, + preload_argv=None, + security=None, + contact_address=None, + memory_monitor_interval="200ms", + extensions=None, + metrics=None, + data=None, + **kwargs + ): self.tasks = dict() self.task_state = dict() self.dep_state = dict() @@ -280,8 +311,12 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.in_flight_tasks = dict() self.in_flight_workers = dict() - self.total_out_connections = dask.config.get('distributed.worker.connections.outgoing') - self.total_in_connections = dask.config.get('distributed.worker.connections.incoming') + self.total_out_connections = dask.config.get( + "distributed.worker.connections.outgoing" + ) + self.total_in_connections = dask.config.get( + "distributed.worker.connections.incoming" + ) self.total_comm_nbytes = 10e6 self.comm_nbytes = 0 self.suspicious_deps = defaultdict(lambda: 0) @@ -313,33 +348,35 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.long_running = set() self.batched_stream = None - self.recent_messages_log = deque(maxlen=dask.config.get('distributed.comm.recent-messages-log-length')) + self.recent_messages_log = deque( + maxlen=dask.config.get("distributed.comm.recent-messages-log-length") + ) self.target_message_size = 50e6 # 50 MB self.log = deque(maxlen=100000) - self.validate = kwargs.pop('validate', False) + self.validate = kwargs.pop("validate", False) self._transitions = { - ('waiting', 'ready'): self.transition_waiting_ready, - ('waiting', 'memory'): self.transition_waiting_done, - ('waiting', 'error'): self.transition_waiting_done, - ('ready', 'executing'): self.transition_ready_executing, - ('ready', 'memory'): self.transition_ready_memory, - ('constrained', 'executing'): self.transition_constrained_executing, - ('executing', 'memory'): self.transition_executing_done, - ('executing', 'error'): self.transition_executing_done, - ('executing', 'rescheduled'): self.transition_executing_done, - ('executing', 'long-running'): self.transition_executing_long_running, - ('long-running', 'error'): self.transition_executing_done, - ('long-running', 'memory'): self.transition_executing_done, - ('long-running', 'rescheduled'): self.transition_executing_done, + ("waiting", "ready"): self.transition_waiting_ready, + ("waiting", "memory"): self.transition_waiting_done, + ("waiting", "error"): self.transition_waiting_done, + ("ready", "executing"): self.transition_ready_executing, + ("ready", "memory"): self.transition_ready_memory, + ("constrained", "executing"): self.transition_constrained_executing, + 
("executing", "memory"): self.transition_executing_done, + ("executing", "error"): self.transition_executing_done, + ("executing", "rescheduled"): self.transition_executing_done, + ("executing", "long-running"): self.transition_executing_long_running, + ("long-running", "error"): self.transition_executing_done, + ("long-running", "memory"): self.transition_executing_done, + ("long-running", "rescheduled"): self.transition_executing_done, } self._dep_transitions = { - ('waiting', 'flight'): self.transition_dep_waiting_flight, - ('waiting', 'memory'): self.transition_dep_waiting_memory, - ('flight', 'waiting'): self.transition_dep_flight_waiting, - ('flight', 'memory'): self.transition_dep_flight_memory, + ("waiting", "flight"): self.transition_dep_waiting_flight, + ("waiting", "memory"): self.transition_dep_waiting_memory, + ("flight", "waiting"): self.transition_dep_flight_waiting, + ("flight", "memory"): self.transition_dep_flight_memory, } self.incoming_transfer_log = deque(maxlen=(100000)) @@ -350,17 +387,19 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.repetitively_busy = 0 self._client = None - profile_cycle_interval = kwargs.pop('profile_cycle_interval', - dask.config.get('distributed.worker.profile.cycle')) - profile_cycle_interval = parse_timedelta(profile_cycle_interval, default='ms') + profile_cycle_interval = kwargs.pop( + "profile_cycle_interval", + dask.config.get("distributed.worker.profile.cycle"), + ) + profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms") self._setup_logging() if scheduler_file: cfg = json_load_robust(scheduler_file) - scheduler_addr = cfg['address'] - elif scheduler_ip is None and dask.config.get('scheduler-address', None): - scheduler_addr = dask.config.get('scheduler-address') + scheduler_addr = cfg["address"] + elif scheduler_ip is None and dask.config.get("scheduler-address", None): + scheduler_addr = dask.config.get("scheduler-address") elif scheduler_port is None: scheduler_addr = coerce_to_address(scheduler_ip) else: @@ -372,48 +411,56 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.death_timeout = death_timeout self.preload = preload if self.preload is None: - self.preload = dask.config.get('distributed.worker.preload') + self.preload = dask.config.get("distributed.worker.preload") self.preload_argv = preload_argv if self.preload_argv is None: - self.preload_argv = dask.config.get('distributed.worker.preload-argv') + self.preload_argv = dask.config.get("distributed.worker.preload-argv") self.contact_address = contact_address - self.memory_monitor_interval = parse_timedelta(memory_monitor_interval, default='ms') + self.memory_monitor_interval = parse_timedelta( + memory_monitor_interval, default="ms" + ) self.extensions = dict() if silence_logs: silence_logging(level=silence_logs) with warn_on_duration( - '1s', + "1s", "Creating scratch directories is taking a surprisingly long time. " "This is often due to running workers on a network file system. " "Consider specifying a local-directory to point workers to write " - "scratch data to a local disk." 
+ "scratch data to a local disk.", ): self._workspace = WorkSpace(os.path.abspath(local_dir)) - self._workdir = self._workspace.new_work_dir(prefix='worker-') + self._workdir = self._workspace.new_work_dir(prefix="worker-") self.local_dir = self._workdir.dir_path self.security = security or Security() assert isinstance(self.security, Security) - self.connection_args = self.security.get_connection_args('worker') - self.listen_args = self.security.get_listen_args('worker') + self.connection_args = self.security.get_connection_args("worker") + self.listen_args = self.security.get_listen_args("worker") self.memory_limit = parse_memory_limit(memory_limit, self.ncores) self.paused = False - if 'memory_target_fraction' in kwargs: - self.memory_target_fraction = kwargs.pop('memory_target_fraction') + if "memory_target_fraction" in kwargs: + self.memory_target_fraction = kwargs.pop("memory_target_fraction") else: - self.memory_target_fraction = dask.config.get('distributed.worker.memory.target') - if 'memory_spill_fraction' in kwargs: - self.memory_spill_fraction = kwargs.pop('memory_spill_fraction') + self.memory_target_fraction = dask.config.get( + "distributed.worker.memory.target" + ) + if "memory_spill_fraction" in kwargs: + self.memory_spill_fraction = kwargs.pop("memory_spill_fraction") else: - self.memory_spill_fraction = dask.config.get('distributed.worker.memory.spill') - if 'memory_pause_fraction' in kwargs: - self.memory_pause_fraction = kwargs.pop('memory_pause_fraction') + self.memory_spill_fraction = dask.config.get( + "distributed.worker.memory.spill" + ) + if "memory_pause_fraction" in kwargs: + self.memory_pause_fraction = kwargs.pop("memory_pause_fraction") else: - self.memory_pause_fraction = dask.config.get('distributed.worker.memory.pause') + self.memory_pause_fraction = dask.config.get( + "distributed.worker.memory.pause" + ) if isinstance(data, MutableMapping): self.data = data @@ -421,17 +468,19 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.data = data() elif isinstance(data, tuple): self.data = data[0](**data[1]) - elif (self.memory_limit and - (self.memory_target_fraction or - self.memory_spill_fraction)): + elif self.memory_limit and ( + self.memory_target_fraction or self.memory_spill_fraction + ): try: from zict import Buffer, File, Func except ImportError: raise ImportError("Please `pip install zict` for spill-to-disk workers") - path = os.path.join(self.local_dir, 'storage') - storage = Func(partial(serialize_bytelist, on_error='raise'), - deserialize_bytes, - File(path)) + path = os.path.join(self.local_dir, "storage") + storage = Func( + partial(serialize_bytelist, on_error="raise"), + deserialize_bytes, + File(path), + ) target = int(float(self.memory_limit) * self.memory_target_fraction) self.data = Buffer({}, storage, target, weight) else: @@ -442,8 +491,12 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.status = None self._closed = Event() self.reconnect = reconnect - self.executor = executor or ThreadPoolExecutor(self.ncores, thread_name_prefix="Dask-Worker-Threads'") - self.actor_executor = ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads") + self.executor = executor or ThreadPoolExecutor( + self.ncores, thread_name_prefix="Dask-Worker-Threads'" + ) + self.actor_executor = ThreadPoolExecutor( + 1, thread_name_prefix="Dask-Actor-Threads" + ) self.name = name self.scheduler_delay = 0 self.stream_comms = dict() @@ -459,56 +512,61 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.metrics = 
dict(metrics) if metrics else {} handlers = { - 'gather': self.gather, - 'run': self.run, - 'run_coroutine': self.run_coroutine, - 'get_data': self.get_data, - 'update_data': self.update_data, - 'delete_data': self.delete_data, - 'terminate': self.terminate, - 'ping': pingpong, - 'upload_file': self.upload_file, - 'start_ipython': self.start_ipython, - 'call_stack': self.get_call_stack, - 'profile': self.get_profile, - 'profile_metadata': self.get_profile_metadata, - 'get_logs': self.get_logs, - 'keys': self.keys, - 'versions': self.versions, - 'actor_execute': self.actor_execute, - 'actor_attribute': self.actor_attribute, + "gather": self.gather, + "run": self.run, + "run_coroutine": self.run_coroutine, + "get_data": self.get_data, + "update_data": self.update_data, + "delete_data": self.delete_data, + "terminate": self.terminate, + "ping": pingpong, + "upload_file": self.upload_file, + "start_ipython": self.start_ipython, + "call_stack": self.get_call_stack, + "profile": self.get_profile, + "profile_metadata": self.get_profile_metadata, + "get_logs": self.get_logs, + "keys": self.keys, + "versions": self.versions, + "actor_execute": self.actor_execute, + "actor_attribute": self.actor_attribute, } stream_handlers = { - 'close': self._close, - 'compute-task': self.add_task, - 'release-task': partial(self.release_key, report=False), - 'delete-data': self.delete_data, - 'steal-request': self.steal_request, + "close": self._close, + "compute-task": self.add_task, + "release-task": partial(self.release_key, report=False), + "delete-data": self.delete_data, + "steal-request": self.steal_request, } super(Worker, self).__init__( - handlers=handlers, - stream_handlers=stream_handlers, - io_loop=self.loop, - connection_args=self.connection_args, - **kwargs) + handlers=handlers, + stream_handlers=stream_handlers, + io_loop=self.loop, + connection_args=self.connection_args, + **kwargs + ) self.scheduler = self.rpc(scheduler_addr) - self.execution_state = {'scheduler': self.scheduler.address, - 'ioloop': self.loop, - 'worker': self} + self.execution_state = { + "scheduler": self.scheduler.address, + "ioloop": self.loop, + "worker": self, + } pc = PeriodicCallback(self.heartbeat, 1000, io_loop=self.io_loop) - self.periodic_callbacks['heartbeat'] = pc + self.periodic_callbacks["heartbeat"] = pc self._address = contact_address if self.memory_limit: self._memory_monitoring = False - pc = PeriodicCallback(self.memory_monitor, - self.memory_monitor_interval * 1000, - io_loop=self.io_loop) - self.periodic_callbacks['memory'] = pc + pc = PeriodicCallback( + self.memory_monitor, + self.memory_monitor_interval * 1000, + io_loop=self.io_loop, + ) + self.periodic_callbacks["memory"] = pc if extensions is None: extensions = DEFAULT_EXTENSIONS @@ -520,16 +578,19 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, setproctitle("dask-worker [not started]") pc = PeriodicCallback( - self.trigger_profile, - parse_timedelta(dask.config.get('distributed.worker.profile.interval'), default='ms') * 1000, - io_loop=self.io_loop + self.trigger_profile, + parse_timedelta( + dask.config.get("distributed.worker.profile.interval"), default="ms" + ) + * 1000, + io_loop=self.io_loop, ) - self.periodic_callbacks['profile'] = pc + self.periodic_callbacks["profile"] = pc - pc = PeriodicCallback(self.cycle_profile, - profile_cycle_interval * 1000, - io_loop=self.io_loop) - self.periodic_callbacks['profile-cycle'] = pc + pc = PeriodicCallback( + self.cycle_profile, profile_cycle_interval * 1000, io_loop=self.io_loop + ) + 
self.periodic_callbacks["profile-cycle"] = pc _global_workers.append(weakref.ref(self)) @@ -538,15 +599,28 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, ################## def __repr__(self): - return "<%s: %s, %s, stored: %d, running: %d/%d, ready: %d, comm: %d, waiting: %d>" % ( - self.__class__.__name__, self.address, self.status, - len(self.data), len(self.executing), self.ncores, - len(self.ready), len(self.in_flight_tasks), - len(self.waiting_for_data)) + return ( + "<%s: %s, %s, stored: %d, running: %d/%d, ready: %d, comm: %d, waiting: %d>" + % ( + self.__class__.__name__, + self.address, + self.status, + len(self.data), + len(self.executing), + self.ncores, + len(self.ready), + len(self.in_flight_tasks), + len(self.waiting_for_data), + ) + ) def _setup_logging(self): - self._deque_handler = DequeHandler(n=dask.config.get('distributed.admin.log-length')) - self._deque_handler.setFormatter(logging.Formatter(dask.config.get('distributed.admin.log-format'))) + self._deque_handler = DequeHandler( + n=dask.config.get("distributed.admin.log-length") + ) + self._deque_handler.setFormatter( + logging.Formatter(dask.config.get("distributed.admin.log-format")) + ) logger.addHandler(self._deque_handler) finalize(self, logger.removeHandler, self._deque_handler) @@ -556,20 +630,24 @@ def worker_address(self): return self.address def get_metrics(self): - core = dict(executing=len(self.executing), - in_memory=len(self.data), - ready=len(self.ready), - in_flight=len(self.in_flight_tasks)) + core = dict( + executing=len(self.executing), + in_memory=len(self.data), + ready=len(self.ready), + in_flight=len(self.in_flight_tasks), + ) custom = {k: metric(self) for k, metric in self.metrics.items()} return merge(custom, self.monitor.recent(), core) def identity(self, comm=None): - return {'type': type(self).__name__, - 'id': self.id, - 'scheduler': self.scheduler.address, - 'ncores': self.ncores, - 'memory_limit': self.memory_limit} + return { + "type": type(self).__name__, + "id": self.id, + "scheduler": self.scheduler.address, + "ncores": self.ncores, + "memory_limit": self.memory_limit, + } ##################### # External Services # @@ -577,37 +655,42 @@ def identity(self, comm=None): @gen.coroutine def _register_with_scheduler(self): - self.periodic_callbacks['heartbeat'].stop() + self.periodic_callbacks["heartbeat"].stop() start = time() if self.contact_address is None: self.contact_address = self.address - logger.info('-' * 49) + logger.info("-" * 49) while True: if self.death_timeout and time() > start + self.death_timeout: yield self._close(timeout=1) return - if self.status in ('closed', 'closing'): + if self.status in ("closed", "closing"): raise gen.Return try: _start = time() - comm = yield connect(self.scheduler.address, - connection_args=self.connection_args) - yield comm.write(dict(op='register-worker', - reply=False, - address=self.contact_address, - keys=list(self.data), - ncores=self.ncores, - name=self.name, - nbytes=self.nbytes, - now=time(), - resources=self.total_resources, - memory_limit=self.memory_limit, - local_directory=self.local_dir, - services=self.service_ports, - pid=os.getpid(), - metrics=self.get_metrics()), - serializers=['msgpack']) - future = comm.read(deserializers=['msgpack']) + comm = yield connect( + self.scheduler.address, connection_args=self.connection_args + ) + yield comm.write( + dict( + op="register-worker", + reply=False, + address=self.contact_address, + keys=list(self.data), + ncores=self.ncores, + name=self.name, + nbytes=self.nbytes, + 
now=time(), + resources=self.total_resources, + memory_limit=self.memory_limit, + local_directory=self.local_dir, + services=self.service_ports, + pid=os.getpid(), + metrics=self.get_metrics(), + ), + serializers=["msgpack"], + ) + future = comm.read(deserializers=["msgpack"]) if self.death_timeout: diff = self.death_timeout - (time() - start) if diff < 0: @@ -616,33 +699,34 @@ def _register_with_scheduler(self): response = yield future _end = time() middle = (_start + _end) / 2 - self.scheduler_delay = response['time'] - middle - self.status = 'running' + self.scheduler_delay = response["time"] - middle + self.status = "running" break except EnvironmentError: - logger.info('Waiting to connect to: %26s', self.scheduler.address) + logger.info("Waiting to connect to: %26s", self.scheduler.address) yield gen.sleep(0.1) except gen.TimeoutError: logger.info("Timed out when connecting to scheduler") - if response['status'] != 'OK': - raise ValueError("Unexpected response from register: %r" % - (response,)) + if response["status"] != "OK": + raise ValueError("Unexpected response from register: %r" % (response,)) else: # Retrieve eventual init functions and run them - for function_bytes in response['worker-setups']: + for function_bytes in response["worker-setups"]: setup_function = pickle.loads(function_bytes) - if has_arg(setup_function, 'dask_worker'): + if has_arg(setup_function, "dask_worker"): result = setup_function(dask_worker=self) else: result = setup_function() - logger.info('Init function %s ran: output=%s' % (setup_function, result)) + logger.info( + "Init function %s ran: output=%s" % (setup_function, result) + ) - logger.info(' Registered to: %26s', self.scheduler.address) - logger.info('-' * 49) + logger.info(" Registered to: %26s", self.scheduler.address) + logger.info("-" * 49) - self.batched_stream = BatchedSend(interval='2ms', loop=self.loop) + self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) self.batched_stream.start(comm) - self.periodic_callbacks['heartbeat'].start() + self.periodic_callbacks["heartbeat"].start() self.loop.add_callback(self.handle_scheduler, comm) @gen.coroutine @@ -653,18 +737,18 @@ def heartbeat(self): try: start = time() response = yield self.scheduler.heartbeat_worker( - address=self.contact_address, - now=time(), - metrics=self.get_metrics() + address=self.contact_address, now=time(), metrics=self.get_metrics() ) end = time() middle = (start + end) / 2 - if response['status'] == 'missing': + if response["status"] == "missing": yield self._register_with_scheduler() return - self.scheduler_delay = response['time'] - middle - self.periodic_callbacks['heartbeat'].callback_time = response['heartbeat-interval'] * 1000 + self.scheduler_delay = response["time"] - middle + self.periodic_callbacks["heartbeat"].callback_time = ( + response["heartbeat-interval"] * 1000 + ) except CommClosedError: logger.warning("Heartbeat to scheduler failed") finally: @@ -675,8 +759,9 @@ def heartbeat(self): @gen.coroutine def handle_scheduler(self, comm): try: - yield self.handle_stream(comm, every_cycle=[self.ensure_communicating, - self.ensure_computing]) + yield self.handle_stream( + comm, every_cycle=[self.ensure_communicating, self.ensure_computing] + ) except Exception as e: logger.exception(e) raise @@ -693,11 +778,10 @@ def start_ipython(self, comm): Returns Jupyter connection info dictionary. 
""" from ._ipython_utils import start_ipython + if self._ipython_kernel is None: self._ipython_kernel = start_ipython( - ip=self.ip, - ns={'worker': self}, - log=logger, + ip=self.ip, ns={"worker": self}, log=logger ) return self._ipython_kernel.get_connection_info() @@ -708,7 +792,7 @@ def upload_file(self, comm, filename=None, data=None, load=True): def func(data): if isinstance(data, unicode): data = data.encode() - with open(out_filename, 'wb') as f: + with open(out_filename, "wb") as f: f.write(data) f.flush() return data @@ -723,29 +807,34 @@ def func(data): import_file(out_filename) except Exception as e: logger.exception(e) - raise gen.Return({'status': 'error', - 'exception': to_serialize(e)}) + raise gen.Return({"status": "error", "exception": to_serialize(e)}) - raise gen.Return({'status': 'OK', 'nbytes': len(data)}) + raise gen.Return({"status": "OK", "nbytes": len(data)}) def keys(self, comm=None): return list(self.data) @gen.coroutine def gather(self, comm=None, who_has=None): - who_has = {k: [coerce_to_address(addr) for addr in v] - for k, v in who_has.items() - if k not in self.data} + who_has = { + k: [coerce_to_address(addr) for addr in v] + for k, v in who_has.items() + if k not in self.data + } result, missing_keys, missing_workers = yield gather_from_workers( - who_has, rpc=self.rpc, who=self.address) + who_has, rpc=self.rpc, who=self.address + ) if missing_keys: - logger.warning("Could not find data: %s on workers: %s (who_has: %s)", - missing_keys, missing_workers, who_has) - raise Return({'status': 'missing-data', - 'keys': missing_keys}) + logger.warning( + "Could not find data: %s on workers: %s (who_has: %s)", + missing_keys, + missing_workers, + who_has, + ) + raise Return({"status": "missing-data", "keys": missing_keys}) else: self.update_data(data=result, report=False) - raise Return({'status': 'OK'}) + raise Return({"status": "OK"}) def get_logs(self, comm=None, n=None): deque_handler = self._deque_handler @@ -761,8 +850,8 @@ def get_logs(self, comm=None, n=None): ############# def start_services(self, default_listen_ip): - if default_listen_ip == '0.0.0.0': - default_listen_ip = '' # for IPV6 + if default_listen_ip == "0.0.0.0": + default_listen_ip = "" # for IPV6 for k, v in self.service_specs.items(): listen_ip = None @@ -772,7 +861,7 @@ def start_services(self, default_listen_ip): port = 0 if isinstance(port, (str, unicode)): - port = port.split(':') + port = port.split(":") if isinstance(port, (tuple, list)): listen_ip, port = (port[0], int(port[1])) @@ -783,7 +872,9 @@ def start_services(self, default_listen_ip): kwargs = {} self.services[k] = v(self, io_loop=self.loop, **kwargs) - self.services[k].listen((listen_ip if listen_ip is not None else default_listen_ip, port)) + self.services[k].listen( + (listen_ip if listen_ip is not None else default_listen_ip, port) + ) self.service_ports[k] = self.services[k].port @gen.coroutine @@ -797,16 +888,15 @@ def _start(self, addr_or_port=0): if not addr_or_port: # Default address is the required one to reach the scheduler listen_host = get_address_host(self.scheduler.address) - self.listen(get_local_address_for(self.scheduler.address), - listen_args=self.listen_args) + self.listen( + get_local_address_for(self.scheduler.address), + listen_args=self.listen_args, + ) self.ip = get_address_host(self.address) elif isinstance(addr_or_port, int): # addr_or_port is an integer => assume TCP - listen_host = self.ip = get_ip( - get_address_host(self.scheduler.address) - ) - self.listen((listen_host, addr_or_port), - 
listen_args=self.listen_args) + listen_host = self.ip = get_ip(get_address_host(self.scheduler.address)) + self.listen((listen_host, addr_or_port), listen_args=self.listen_args) else: self.listen(addr_or_port, listen_args=self.listen_args) self.ip = get_address_host(self.address) @@ -815,31 +905,40 @@ def _start(self, addr_or_port=0): except ValueError: listen_host = addr_or_port - if '://' in listen_host: - protocol, listen_host = listen_host.split('://') + if "://" in listen_host: + protocol, listen_host = listen_host.split("://") self.name = self.name or self.address - preload_modules(self.preload, parameter=self, file_dir=self.local_dir, argv=self.preload_argv) + preload_modules( + self.preload, + parameter=self, + file_dir=self.local_dir, + argv=self.preload_argv, + ) # Services listen on all addresses # Note Nanny is not a "real" service, just some metadata # passed in service_ports... self.start_services(listen_host) try: - listening_address = '%s%s:%d' % (self.listener.prefix, listen_host, self.port) + listening_address = "%s%s:%d" % ( + self.listener.prefix, + listen_host, + self.port, + ) except Exception: - listening_address = '%s%s' % (self.listener.prefix, listen_host) + listening_address = "%s%s" % (self.listener.prefix, listen_host) - logger.info(' Start worker at: %26s', self.address) - logger.info(' Listening to: %26s', listening_address) + logger.info(" Start worker at: %26s", self.address) + logger.info(" Listening to: %26s", listening_address) for k, v in self.service_ports.items(): - logger.info(' %16s at: %26s' % (k, listen_host + ':' + str(v))) - logger.info('Waiting to connect to: %26s', self.scheduler.address) - logger.info('-' * 49) - logger.info(' Threads: %26d', self.ncores) + logger.info(" %16s at: %26s" % (k, listen_host + ":" + str(v))) + logger.info("Waiting to connect to: %26s", self.scheduler.address) + logger.info("-" * 49) + logger.info(" Threads: %26d", self.ncores) if self.memory_limit: - logger.info(' Memory: %26s', format_bytes(self.memory_limit)) - logger.info(' Local Directory: %26s', self.local_dir) + logger.info(" Memory: %26s", format_bytes(self.memory_limit)) + logger.info(" Local Directory: %26s", self.local_dir) setproctitle("dask-worker [%s]" % self.address) @@ -857,7 +956,7 @@ def start(self, port=0): @gen.coroutine def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): with log_errors(): - if self.status in ('closed', 'closing'): + if self.status in ("closed", "closing"): return disable_gc_diagnosis() @@ -866,7 +965,7 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): logger.info("Stopping worker at %s", self.address) except ValueError: # address not available if already closed logger.info("Stopping worker") - self.status = 'closing' + self.status = "closing" setproctitle("dask-worker [closing]") self.stop() @@ -874,8 +973,10 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): pc.stop() with ignoring(EnvironmentError, gen.TimeoutError): if report: - yield gen.with_timeout(timedelta(seconds=timeout), - self.scheduler.unregister(address=self.contact_address)) + yield gen.with_timeout( + timedelta(seconds=timeout), + self.scheduler.unregister(address=self.contact_address), + ) self.scheduler.close_rpc() self.actor_executor._work_queue.queue.clear() if isinstance(self.executor, ThreadPoolExecutor): @@ -889,14 +990,14 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): for k, v in self.services.items(): v.stop() - self.status = 'closed' + 
self.status = "closed" - if nanny and 'nanny' in self.service_ports: - with self.rpc((self.ip, self.service_ports['nanny'])) as r: + if nanny and "nanny" in self.service_ports: + with self.rpc((self.ip, self.service_ports["nanny"])) as r: yield r.terminate() if self.batched_stream and not self.batched_stream.comm.closed(): - self.batched_stream.send({'op': 'close-stream'}) + self.batched_stream.send({"op": "close-stream"}) if self.batched_stream: self.batched_stream.close() @@ -921,12 +1022,12 @@ def _remove_from_global_workers(self): @gen.coroutine def terminate(self, comm, report=True): yield self._close(report=report) - raise Return('OK') + raise Return("OK") @gen.coroutine def wait_until_closed(self): yield self._closed.wait() - assert self.status == 'closed' + assert self.status == "closed" ################ # Worker Peers # @@ -934,14 +1035,15 @@ def wait_until_closed(self): def send_to_worker(self, address, msg): if address not in self.stream_comms: - bcomm = BatchedSend(interval='1ms', loop=self.loop) + bcomm = BatchedSend(interval="1ms", loop=self.loop) self.stream_comms[address] = bcomm @gen.coroutine def batched_send_connect(): - comm = yield connect(address, # TODO, serialization - connection_args=self.connection_args) - yield comm.write({'op': 'connection_stream'}) + comm = yield connect( + address, connection_args=self.connection_args # TODO, serialization + ) + yield comm.write({"op": "connection_stream"}) bcomm.start(comm) @@ -950,19 +1052,27 @@ def batched_send_connect(): self.stream_comms[address].send(msg) @gen.coroutine - def get_data(self, comm, keys=None, who=None, serializers=None, - max_connections=None): + def get_data( + self, comm, keys=None, who=None, serializers=None, max_connections=None + ): start = time() if max_connections is None: max_connections = self.total_in_connections # Allow same-host connections more liberally - if max_connections and comm and get_address_host(comm.peer_address) == get_address_host(self.address): + if ( + max_connections + and comm + and get_address_host(comm.peer_address) == get_address_host(self.address) + ): max_connections = max_connections * 2 - if max_connections is not False and self.outgoing_current_count > max_connections: - raise gen.Return({'status': 'busy'}) + if ( + max_connections is not False + and self.outgoing_current_count > max_connections + ): + raise gen.Return({"status": "busy"}) self.outgoing_current_count += 1 data = {k: self.data[k] for k in keys if k in self.data} @@ -971,48 +1081,51 @@ def get_data(self, comm, keys=None, who=None, serializers=None, for k in set(keys) - set(data): if k in self.actors: from .actor import Actor + data[k] = Actor(type(self.actors[k]), self.address, k) - msg = {'status': 'OK', - 'data': {k: to_serialize(v) for k, v in data.items()}} + msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}} nbytes = {k: self.nbytes.get(k) for k in data} stop = time() if self.digests is not None: - self.digests['get-data-load-duration'].add(stop - start) + self.digests["get-data-load-duration"].add(stop - start) start = time() try: compressed = yield comm.write(msg, serializers=serializers) response = yield comm.read(deserializers=serializers) - assert response == 'OK', response + assert response == "OK", response except EnvironmentError: - logger.exception('failed during get data with %s -> %s', - self.address, who, exc_info=True) + logger.exception( + "failed during get data with %s -> %s", self.address, who, exc_info=True + ) comm.abort() raise finally: 
self.outgoing_current_count -= 1 stop = time() if self.digests is not None: - self.digests['get-data-send-duration'].add(stop - start) + self.digests["get-data-send-duration"].add(stop - start) total_bytes = sum(filter(None, nbytes.values())) self.outgoing_count += 1 duration = (stop - start) or 0.5 # windows - self.outgoing_transfer_log.append({ - 'start': start + self.scheduler_delay, - 'stop': stop + self.scheduler_delay, - 'middle': (start + stop) / 2, - 'duration': duration, - 'who': who, - 'keys': nbytes, - 'total': total_bytes, - 'compressed': compressed, - 'bandwidth': total_bytes / duration - }) - - raise gen.Return('dont-reply') + self.outgoing_transfer_log.append( + { + "start": start + self.scheduler_delay, + "stop": stop + self.scheduler_delay, + "middle": (start + stop) / 2, + "duration": duration, + "who": who, + "keys": nbytes, + "total": total_bytes, + "compressed": compressed, + "bandwidth": total_bytes / duration, + } + ) + + raise gen.Return("dont-reply") ################### # Local Execution # @@ -1021,32 +1134,30 @@ def get_data(self, comm, keys=None, who=None, serializers=None, def update_data(self, comm=None, data=None, report=True, serializers=None): for key, value in data.items(): if key in self.task_state: - self.transition(key, 'memory', value=value) + self.transition(key, "memory", value=value) else: self.put_key_in_memory(key, value) - self.task_state[key] = 'memory' + self.task_state[key] = "memory" self.tasks[key] = None self.priorities[key] = None self.durations[key] = None self.dependencies[key] = set() if key in self.dep_state: - self.transition_dep(key, 'memory', value=value) + self.transition_dep(key, "memory", value=value) - self.log.append((key, 'receive-from-scatter')) + self.log.append((key, "receive-from-scatter")) if report: - self.batched_stream.send({'op': 'add-keys', - 'keys': list(data)}) - info = {'nbytes': {k: sizeof(v) for k, v in data.items()}, - 'status': 'OK'} + self.batched_stream.send({"op": "add-keys", "keys": list(data)}) + info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} return info @gen.coroutine def delete_data(self, comm=None, keys=None, report=True): if keys: for key in list(keys): - self.log.append((key, 'delete')) + self.log.append((key, "delete")) if key in self.task_state: self.release_key(key) @@ -1057,9 +1168,10 @@ def delete_data(self, comm=None, keys=None, report=True): if report: logger.debug("Reporting loss of keys to scheduler") # TODO: this route seems to not exist? 
- yield self.scheduler.remove_keys(address=self.contact_address, - keys=list(keys)) - raise Return('OK') + yield self.scheduler.remove_keys( + address=self.contact_address, keys=list(keys) + ) + raise Return("OK") @gen.coroutine def set_resources(self, **resources): @@ -1070,28 +1182,42 @@ def set_resources(self, **resources): self.available_resources[r] = quantity self.total_resources[r] = quantity - yield self.scheduler.set_resources(resources=self.total_resources, - worker=self.contact_address) + yield self.scheduler.set_resources( + resources=self.total_resources, worker=self.contact_address + ) ################### # Task Management # ################### - def add_task(self, key, function=None, args=None, kwargs=None, task=None, - who_has=None, nbytes=None, priority=None, duration=None, - resource_restrictions=None, actor=False, **kwargs2): + def add_task( + self, + key, + function=None, + args=None, + kwargs=None, + task=None, + who_has=None, + nbytes=None, + priority=None, + duration=None, + resource_restrictions=None, + actor=False, + **kwargs2 + ): try: if key in self.tasks: state = self.task_state[key] - if state == 'memory': + if state == "memory": assert key in self.data or key in self.actors - logger.debug("Asked to compute pre-existing result: %s: %s", - key, state) + logger.debug( + "Asked to compute pre-existing result: %s: %s", key, state + ) self.send_task_state_to_scheduler(key) return if state in IN_PLAY: return - if state == 'erred': + if state == "erred": del self.exceptions[key] del self.tracebacks[key] @@ -1099,16 +1225,16 @@ def add_task(self, key, function=None, args=None, kwargs=None, task=None, priority = tuple(priority) + (self.generation,) self.generation -= 1 - if self.dep_state.get(key) == 'memory': - self.task_state[key] = 'memory' + if self.dep_state.get(key) == "memory": + self.task_state[key] = "memory" self.send_task_state_to_scheduler(key) self.tasks[key] = None - self.log.append((key, 'new-task-already-in-memory')) + self.log.append((key, "new-task-already-in-memory")) self.priorities[key] = priority self.durations[key] = duration return - self.log.append((key, 'new')) + self.log.append((key, "new")) try: start = time() self.tasks[key] = _deserialize(function, args, kwargs, task) @@ -1117,21 +1243,21 @@ def add_task(self, key, function=None, args=None, kwargs=None, task=None, stop = time() if stop - start > 0.010: - self.startstops[key].append(('deserialize', start, stop)) + self.startstops[key].append(("deserialize", start, stop)) except Exception as e: logger.warning("Could not deserialize task", exc_info=True) emsg = error_message(e) - emsg['key'] = key - emsg['op'] = 'task-erred' + emsg["key"] = key + emsg["op"] = "task-erred" self.batched_stream.send(emsg) - self.log.append((key, 'deserialize-error')) + self.log.append((key, "deserialize-error")) return self.priorities[key] = priority self.durations[key] = duration if resource_restrictions: self.resource_restrictions[key] = resource_restrictions - self.task_state[key] = 'waiting' + self.task_state[key] = "waiting" if nbytes is not None: self.nbytes.update(nbytes) @@ -1146,14 +1272,14 @@ def add_task(self, key, function=None, args=None, kwargs=None, task=None, self.dependents[dep].add(key) if dep not in self.dep_state: - if self.task_state.get(dep) == 'memory': - state = 'memory' + if self.task_state.get(dep) == "memory": + state = "memory" else: - state = 'waiting' + state = "waiting" self.dep_state[dep] = state - self.log.append((dep, 'new-dep', state)) + self.log.append((dep, "new-dep", state)) 
- if self.dep_state[dep] != 'memory': + if self.dep_state[dep] != "memory": self.waiting_for_data[key].add(dep) for dep, workers in who_has.items(): @@ -1164,13 +1290,13 @@ def add_task(self, key, function=None, args=None, kwargs=None, task=None, for worker in workers: self.has_what[worker].add(dep) - if self.dep_state[dep] != 'memory': + if self.dep_state[dep] != "memory": self.pending_data_per_worker[worker].append(dep) if self.waiting_for_data[key]: self.data_needed.append(key) else: - self.transition(key, 'ready') + self.transition(key, "ready") if self.validate: if who_has: assert all(dep in self.dep_state for dep in who_has) @@ -1182,6 +1308,7 @@ def add_task(self, key, function=None, args=None, kwargs=None, task=None, logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1194,7 +1321,7 @@ def transition_dep(self, dep, finish, **kwargs): return func = self._dep_transitions[start, finish] state = func(dep, **kwargs) - self.log.append(('dep', dep, start, state or finish)) + self.log.append(("dep", dep, start, state or finish)) if dep in self.dep_state: self.dep_state[dep] = state or finish if self.validate: @@ -1211,6 +1338,7 @@ def transition_dep_waiting_flight(self, dep, worker=None): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1235,7 +1363,7 @@ def transition_dep_flight_waiting(self, dep, worker=None, remove=True): self._missing_dep_flight.add(dep) self.loop.add_callback(self.handle_missing_dep, dep) for key in self.dependents.get(dep, ()): - if self.task_state[key] == 'waiting': + if self.task_state[key] == "waiting": if remove: # try a new worker immediately self.data_needed.appendleft(key) else: # worker was probably busy, wait a while @@ -1247,6 +1375,7 @@ def transition_dep_flight_waiting(self, dep, worker=None, remove=True): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1257,10 +1386,9 @@ def transition_dep_flight_memory(self, dep, value=None): del self.in_flight_tasks[dep] if self.dependents[dep]: - self.dep_state[dep] = 'memory' + self.dep_state[dep] = "memory" self.put_key_in_memory(dep, value) - self.batched_stream.send({'op': 'add-keys', - 'keys': [dep]}) + self.batched_stream.send({"op": "add-keys", "keys": [dep]}) else: self.release_dep(dep) @@ -1268,6 +1396,7 @@ def transition_dep_flight_memory(self, dep, value=None): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1277,11 +1406,12 @@ def transition_dep_waiting_memory(self, dep, value=None): assert dep in self.data assert dep in self.nbytes assert dep in self.types - assert self.task_state[dep] == 'memory' + assert self.task_state[dep] == "memory" except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise if value is not no_value and dep not in self.data: @@ -1301,10 +1431,13 @@ def transition(self, key, finish, **kwargs): def transition_waiting_ready(self, key): try: if self.validate: - assert self.task_state[key] == 'waiting' + assert self.task_state[key] == "waiting" assert key in self.waiting_for_data assert not self.waiting_for_data[key] - assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) + assert all( + dep in self.data or dep in self.actors + for dep in self.dependencies[key] + ) assert key not in self.executing assert key not in self.ready @@ -1312,20 +1445,21 @@ def transition_waiting_ready(self, key): if key in self.resource_restrictions: self.constrained.append(key) - return 'constrained' + return "constrained" else: heapq.heappush(self.ready, 
(self.priorities[key], key)) except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise def transition_waiting_done(self, key, value=None): try: if self.validate: - assert self.task_state[key] == 'waiting' + assert self.task_state[key] == "waiting" assert key in self.waiting_for_data assert key not in self.executing assert key not in self.ready @@ -1336,6 +1470,7 @@ def transition_waiting_done(self, key, value=None): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1346,7 +1481,10 @@ def transition_ready_executing(self, key): # assert key not in self.data assert self.task_state[key] in READY assert key not in self.ready - assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) + assert all( + dep in self.data or dep in self.actors + for dep in self.dependencies[key] + ) self.executing.add(key) self.loop.add_callback(self.execute, key) @@ -1354,6 +1492,7 @@ def transition_ready_executing(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1380,26 +1519,26 @@ def transition_executing_done(self, key, value=no_value, report=True): for resource, quantity in self.resource_restrictions[key].items(): self.available_resources[resource] += quantity - if self.task_state[key] == 'executing': + if self.task_state[key] == "executing": self.executing.remove(key) self.executed_count += 1 - elif self.task_state[key] == 'long-running': + elif self.task_state[key] == "long-running": self.long_running.remove(key) if value is not no_value: try: - self.task_state[key] = 'memory' + self.task_state[key] = "memory" self.put_key_in_memory(key, value, transition=False) except Exception as e: logger.info("Failed to put key in memory", exc_info=True) msg = error_message(e) - self.exceptions[key] = msg['exception'] - self.tracebacks[key] = msg['traceback'] - self.task_state[key] = 'error' - out = 'error' + self.exceptions[key] = msg["exception"] + self.tracebacks[key] = msg["traceback"] + self.task_state[key] = "error" + out = "error" if key in self.dep_state: - self.transition_dep(key, 'memory') + self.transition_dep(key, "memory") if report and self.batched_stream: self.send_task_state_to_scheduler(key) @@ -1414,6 +1553,7 @@ def transition_executing_done(self, key, value=no_value, report=True): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1424,45 +1564,59 @@ def transition_executing_long_running(self, key, compute_duration=None): self.executing.remove(key) self.long_running.add(key) - self.batched_stream.send({'op': 'long-running', - 'key': key, - 'compute_duration': compute_duration}) + self.batched_stream.send( + {"op": "long-running", "key": key, "compute_duration": compute_duration} + ) self.ensure_computing() except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise def maybe_transition_long_running(self, key, compute_duration=None): - if self.task_state.get(key) == 'executing': - self.transition(key, 'long-running', compute_duration=compute_duration) + if self.task_state.get(key) == "executing": + self.transition(key, "long-running", compute_duration=compute_duration) def stateof(self, key): - return {'executing': key in self.executing, - 'waiting_for_data': key in self.waiting_for_data, - 'heap': key in pluck(1, self.ready), - 'data': key in self.data} + return { + "executing": key in self.executing, + "waiting_for_data": key in self.waiting_for_data, + "heap": key in pluck(1, self.ready), + "data": key in self.data, + } def story(self, 
*keys): - return [msg for msg in self.log - if any(key in msg for key in keys) - or any(key in c - for key in keys - for c in msg - if isinstance(c, (tuple, list, set)))] + return [ + msg + for msg in self.log + if any(key in msg for key in keys) + or any( + key in c + for key in keys + for c in msg + if isinstance(c, (tuple, list, set)) + ) + ] def ensure_communicating(self): changed = True try: - while changed and self.data_needed and len(self.in_flight_workers) < self.total_out_connections: + while ( + changed + and self.data_needed + and len(self.in_flight_workers) < self.total_out_connections + ): changed = False - logger.debug("Ensure communicating. Pending: %d. Connections: %d/%d", - len(self.data_needed), - len(self.in_flight_workers), - self.total_out_connections) + logger.debug( + "Ensure communicating. Pending: %d. Connections: %d/%d", + len(self.data_needed), + len(self.in_flight_workers), + self.total_out_connections, + ) key = self.data_needed[0] @@ -1471,8 +1625,8 @@ def ensure_communicating(self): changed = True continue - if self.task_state.get(key) != 'waiting': - self.log.append((key, 'communication pass')) + if self.task_state.get(key) != "waiting": + self.log.append((key, "communication pass")) self.data_needed.popleft() changed = True continue @@ -1481,33 +1635,38 @@ def ensure_communicating(self): if self.validate: assert all(dep in self.dep_state for dep in deps) - deps = [dep for dep in deps if self.dep_state[dep] == 'waiting'] + deps = [dep for dep in deps if self.dep_state[dep] == "waiting"] missing_deps = {dep for dep in deps if not self.who_has.get(dep)} if missing_deps: logger.info("Can't find dependencies for key %s", key) - missing_deps2 = {dep for dep in missing_deps - if dep not in self._missing_dep_flight} + missing_deps2 = { + dep + for dep in missing_deps + if dep not in self._missing_dep_flight + } for dep in missing_deps2: self._missing_dep_flight.add(dep) - self.loop.add_callback(self.handle_missing_dep, - *missing_deps2) + self.loop.add_callback(self.handle_missing_dep, *missing_deps2) deps = [dep for dep in deps if dep not in missing_deps] - self.log.append(('gather-dependencies', key, deps)) + self.log.append(("gather-dependencies", key, deps)) in_flight = False - while deps and (len(self.in_flight_workers) < self.total_out_connections - or self.comm_nbytes < self.total_comm_nbytes): + while deps and ( + len(self.in_flight_workers) < self.total_out_connections + or self.comm_nbytes < self.total_comm_nbytes + ): dep = deps.pop() - if self.dep_state[dep] != 'waiting': + if self.dep_state[dep] != "waiting": continue if dep not in self.who_has: continue - workers = [w for w in self.who_has[dep] - if w not in self.in_flight_workers] + workers = [ + w for w in self.who_has[dep] if w not in self.in_flight_workers + ] if not workers: in_flight = True continue @@ -1521,9 +1680,10 @@ def ensure_communicating(self): self.comm_nbytes += total_nbytes self.in_flight_workers[worker] = to_gather for d in to_gather: - self.transition_dep(d, 'flight', worker=worker) - self.loop.add_callback(self.gather_dep, worker, dep, - to_gather, total_nbytes, cause=key) + self.transition_dep(d, "flight", worker=worker) + self.loop.add_callback( + self.gather_dep, worker, dep, to_gather, total_nbytes, cause=key + ) changed = True if not deps and not in_flight: @@ -1532,6 +1692,7 @@ def ensure_communicating(self): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1550,26 +1711,31 @@ def send_task_state_to_scheduler(self, key): # Some types fail pickling 
(example: _thread.lock objects), # send their name as a best effort. typ = pickle.dumps(typ.__name__) - d = {'op': 'task-finished', - 'status': 'OK', - 'key': key, - 'nbytes': nbytes, - 'thread': self.threads.get(key), - 'type': typ} + d = { + "op": "task-finished", + "status": "OK", + "key": key, + "nbytes": nbytes, + "thread": self.threads.get(key), + "type": typ, + } elif key in self.exceptions: - d = {'op': 'task-erred', - 'status': 'error', - 'key': key, - 'thread': self.threads.get(key), - 'exception': self.exceptions[key], - 'traceback': self.tracebacks[key]} + d = { + "op": "task-erred", + "status": "error", + "key": key, + "thread": self.threads.get(key), + "exception": self.exceptions[key], + "traceback": self.tracebacks[key], + } else: - logger.error("Key not ready to send to worker, %s: %s", - key, self.task_state[key]) + logger.error( + "Key not ready to send to worker, %s: %s", key, self.task_state[key] + ) return if key in self.startstops: - d['startstops'] = self.startstops[key] + d["startstops"] = self.startstops[key] self.batched_stream.send(d) def put_key_in_memory(self, key, value, transition=True): @@ -1584,7 +1750,7 @@ def put_key_in_memory(self, key, value, transition=True): self.data[key] = value stop = time() if stop - start > 0.020: - self.startstops[key].append(('disk-write', start, stop)) + self.startstops[key].append(("disk-write", start, stop)) if key not in self.nbytes: self.nbytes[key] = sizeof(value) @@ -1596,12 +1762,12 @@ def put_key_in_memory(self, key, value, transition=True): if key in self.waiting_for_data[dep]: self.waiting_for_data[dep].remove(key) if not self.waiting_for_data[dep]: - self.transition(dep, 'ready') + self.transition(dep, "ready") if transition and key in self.task_state: - self.transition(key, 'memory') + self.transition(key, "memory") - self.log.append((key, 'put-in-memory')) + self.log.append((key, "put-in-memory")) def select_keys_for_gather(self, worker, dep): deps = {dep} @@ -1611,7 +1777,7 @@ def select_keys_for_gather(self, worker, dep): while L: d = L.popleft() - if self.dep_state.get(d) != 'waiting': + if self.dep_state.get(d) != "waiting": continue if total_bytes + self.nbytes[d] > self.target_message_size: break @@ -1622,7 +1788,7 @@ def select_keys_for_gather(self, worker, dep): @gen.coroutine def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): - if self.status != 'running': + if self.status != "running": return with log_errors(): response = {} @@ -1632,58 +1798,64 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): # dep states may have changed before gather_dep runs # if a dep is no longer in-flight then don't fetch it - deps = tuple(dep for dep in deps - if self.dep_state.get(dep) == 'flight') + deps = tuple(dep for dep in deps if self.dep_state.get(dep) == "flight") - self.log.append(('request-dep', dep, worker, deps)) + self.log.append(("request-dep", dep, worker, deps)) logger.debug("Request %d keys", len(deps)) start = time() - response = yield get_data_from_worker(self.rpc, deps, worker, - who=self.address) + response = yield get_data_from_worker( + self.rpc, deps, worker, who=self.address + ) stop = time() - if response['status'] == 'busy': - self.log.append(('busy-gather', worker, deps)) + if response["status"] == "busy": + self.log.append(("busy-gather", worker, deps)) for dep in deps: - if self.dep_state.get(dep, None) == 'flight': - self.transition_dep(dep, 'waiting') + if self.dep_state.get(dep, None) == "flight": + self.transition_dep(dep, "waiting") return if cause: - 
self.startstops[cause].append(( - 'transfer', - start + self.scheduler_delay, - stop + self.scheduler_delay - )) - - total_bytes = sum(self.nbytes.get(dep, 0) for dep in response['data']) + self.startstops[cause].append( + ( + "transfer", + start + self.scheduler_delay, + stop + self.scheduler_delay, + ) + ) + + total_bytes = sum(self.nbytes.get(dep, 0) for dep in response["data"]) duration = (stop - start) or 0.5 - self.incoming_transfer_log.append({ - 'start': start + self.scheduler_delay, - 'stop': stop + self.scheduler_delay, - 'middle': (start + stop) / 2.0 + self.scheduler_delay, - 'duration': duration, - 'keys': {dep: self.nbytes.get(dep, None) for dep in response['data']}, - 'total': total_bytes, - 'bandwidth': total_bytes / duration, - 'who': worker - }) + self.incoming_transfer_log.append( + { + "start": start + self.scheduler_delay, + "stop": stop + self.scheduler_delay, + "middle": (start + stop) / 2.0 + self.scheduler_delay, + "duration": duration, + "keys": { + dep: self.nbytes.get(dep, None) for dep in response["data"] + }, + "total": total_bytes, + "bandwidth": total_bytes / duration, + "who": worker, + } + ) if self.digests is not None: - self.digests['transfer-bandwidth'].add(total_bytes / duration) - self.digests['transfer-duration'].add(duration) - self.counters['transfer-count'].add(len(response['data'])) + self.digests["transfer-bandwidth"].add(total_bytes / duration) + self.digests["transfer-duration"].add(duration) + self.counters["transfer-count"].add(len(response["data"])) self.incoming_count += 1 - self.log.append(('receive-dep', worker, list(response['data']))) + self.log.append(("receive-dep", worker, list(response["data"]))) - if response['data']: - self.batched_stream.send({'op': 'add-keys', - 'keys': list(response['data'])}) + if response["data"]: + self.batched_stream.send( + {"op": "add-keys", "keys": list(response["data"])} + ) except EnvironmentError as e: - logger.exception("Worker stream died during communication: %s", - worker) - self.log.append(('receive-dep-failed', worker)) + logger.exception("Worker stream died during communication: %s", worker) + self.log.append(("receive-dep-failed", worker)) for d in self.has_what.pop(worker): self.who_has[d].remove(worker) if not self.who_has[d]: @@ -1693,25 +1865,27 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): logger.exception(e) if self.batched_stream and LOG_PDB: import pdb + pdb.set_trace() raise finally: self.comm_nbytes -= total_nbytes - busy = response.get('status', '') == 'busy' - data = response.get('data', {}) + busy = response.get("status", "") == "busy" + data = response.get("data", {}) for d in self.in_flight_workers.pop(worker): if not busy and d in data: - self.transition_dep(d, 'memory', value=data[d]) - elif self.dep_state.get(d) != 'memory': - self.transition_dep(d, 'waiting', worker=worker, - remove=not busy) + self.transition_dep(d, "memory", value=data[d]) + elif self.dep_state.get(d) != "memory": + self.transition_dep( + d, "waiting", worker=worker, remove=not busy + ) if not busy and d not in data and d in self.dependents: - self.log.append(('missing-dep', d)) - self.batched_stream.send({'op': 'missing-data', - 'errant_worker': worker, - 'key': d}) + self.log.append(("missing-dep", d)) + self.batched_stream.send( + {"op": "missing-data", "errant_worker": worker, "key": d} + ) if self.validate: self.validate_state() @@ -1734,15 +1908,15 @@ def bad_dep(self, dep): exc = ValueError("Could not find dependent %s. 
Check worker logs" % str(dep)) for key in self.dependents[dep]: msg = error_message(exc) - self.exceptions[key] = msg['exception'] - self.tracebacks[key] = msg['traceback'] - self.transition(key, 'error') + self.exceptions[key] = msg["exception"] + self.tracebacks[key] = msg["traceback"] + self.transition(key, "error") self.release_dep(dep) @gen.coroutine def handle_missing_dep(self, *deps, **kwargs): original_deps = list(deps) - self.log.append(('handle-missing', deps)) + self.log.append(("handle-missing", deps)) try: deps = {dep for dep in deps if dep in self.dependents} if not deps: @@ -1757,8 +1931,11 @@ def handle_missing_dep(self, *deps, **kwargs): return for dep in deps: - logger.info("Dependent not found: %s %s . Asking scheduler", - dep, self.suspicious_deps[dep]) + logger.info( + "Dependent not found: %s %s . Asking scheduler", + dep, + self.suspicious_deps[dep], + ) who_has = yield self.scheduler.who_has(keys=list(deps)) who_has = {k: v for k, v in who_has.items() if v} @@ -1767,19 +1944,18 @@ def handle_missing_dep(self, *deps, **kwargs): self.suspicious_deps[dep] += 1 if not who_has.get(dep): - self.log.append((dep, 'no workers found', - self.dependents.get(dep))) + self.log.append((dep, "no workers found", self.dependents.get(dep))) self.release_dep(dep) else: - self.log.append((dep, 'new workers found')) + self.log.append((dep, "new workers found")) for key in self.dependents.get(dep, ()): if key in self.waiting_for_data: self.data_needed.append(key) except Exception: logger.error("Handle missing dep failed, retrying", exc_info=True) - retries = kwargs.get('retries', 5) - self.log.append(('handle-missing-failed', retries, deps)) + retries = kwargs.get("retries", 5) + self.log.append(("handle-missing-failed", retries, deps)) if retries > 0: yield self.handle_missing_dep(self, *deps, retries=retries - 1) else: @@ -1816,18 +1992,17 @@ def update_who_has(self, who_has): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise def steal_request(self, key): state = self.task_state.get(key, None) - response = {'op': 'steal-response', - 'key': key, - 'state': state} + response = {"op": "steal-response", "key": key, "state": state} self.batched_stream.send(response) - if state in ('ready', 'waiting'): + if state in ("ready", "waiting"): self.release_key(key) def release_key(self, key, cause=None, reason=None, report=True): @@ -1836,16 +2011,15 @@ def release_key(self, key, cause=None, reason=None, report=True): return state = self.task_state.pop(key) if cause: - self.log.append((key, 'release-key', {'cause': cause})) + self.log.append((key, "release-key", {"cause": cause})) else: - self.log.append((key, 'release-key')) + self.log.append((key, "release-key")) del self.tasks[key] if key in self.data and key not in self.dep_state: try: del self.data[key] except FileNotFoundError: - logger.error("Tried to delete %s but no file found", - exc_info=True) + logger.error("Tried to delete %s but no file found", exc_info=True) del self.nbytes[key] del self.types[key] if key in self.actors and key not in self.dep_state: @@ -1859,7 +2033,10 @@ def release_key(self, key, cause=None, reason=None, report=True): for dep in self.dependencies.pop(key, ()): if dep in self.dependents: self.dependents[dep].discard(key) - if not self.dependents[dep] and self.dep_state[dep] in ('waiting', 'flight'): + if not self.dependents[dep] and self.dep_state[dep] in ( + "waiting", + "flight", + ): self.release_dep(dep) if key in self.threads: @@ -1879,21 +2056,20 @@ def release_key(self, key, cause=None, 
reason=None, report=True): self.executing.remove(key) if key in self.resource_restrictions: - if state == 'executing': + if state == "executing": for resource, quantity in self.resource_restrictions[key].items(): self.available_resources[resource] += quantity del self.resource_restrictions[key] if report and state in PROCESSING: # not finished - self.batched_stream.send({'op': 'release', - 'key': key, - 'cause': cause}) + self.batched_stream.send({"op": "release", "key": key, "cause": cause}) except CommClosedError: pass except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1901,7 +2077,7 @@ def release_dep(self, dep, report=False): try: if dep not in self.dep_state: return - self.log.append((dep, 'release-dep')) + self.log.append((dep, "release-dep")) state = self.dep_state.pop(dep) if dep in self.suspicious_deps: @@ -1925,16 +2101,16 @@ def release_dep(self, dep, report=False): self.in_flight_workers[worker].remove(dep) for key in self.dependents.pop(dep, ()): - if self.task_state[key] != 'memory': + if self.task_state[key] != "memory": self.release_key(key, cause=dep) - if report and state == 'memory': - self.batched_stream.send({'op': 'release-worker-data', - 'keys': [dep]}) + if report and state == "memory": + self.batched_stream.send({"op": "release-worker-data", "keys": [dep]}) except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1963,6 +2139,7 @@ def rescind_key(self, key): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -1971,8 +2148,7 @@ def rescind_key(self, key): ################ @gen.coroutine - def executor_submit(self, key, function, args=(), kwargs=None, - executor=None): + def executor_submit(self, key, function, args=(), kwargs=None, executor=None): """ Safely run function in thread pool executor We've run into issues running concurrent.future futures within @@ -1985,8 +2161,9 @@ def executor_submit(self, key, function, args=(), kwargs=None, # logger.info("%s:%d Starts job %d, %s", self.ip, self.port, i, key) kwargs = kwargs or {} future = executor.submit(function, *args, **kwargs) - pc = PeriodicCallback(lambda: logger.debug("future state: %s - %s", - key, future._state), 1000) + pc = PeriodicCallback( + lambda: logger.debug("future state: %s - %s", key, future._state), 1000 + ) pc.start() try: yield future @@ -2000,39 +2177,43 @@ def executor_submit(self, key, function, args=(), kwargs=None, def run(self, comm, function, args=(), wait=True, kwargs=None): kwargs = kwargs or {} - return run(self, comm, function=function, args=args, kwargs=kwargs, - wait=wait) + return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): - return run(self, comm, function=function, args=args, kwargs=kwargs, - wait=wait) + return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) @gen.coroutine def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={}): - separate_thread = kwargs.pop('separate_thread', True) + separate_thread = kwargs.pop("separate_thread", True) key = actor actor = self.actors[key] func = getattr(actor, function) - name = key_split(key) + '.' + function + name = key_split(key) + "." 
+ function if iscoroutinefunction(func): result = yield func(*args, **kwargs) elif separate_thread: - result = yield self.executor_submit(name, - apply_function_actor, - args=(func, args, kwargs, - self.execution_state, - name, - self.active_threads, - self.active_threads_lock), - executor=self.actor_executor) + result = yield self.executor_submit( + name, + apply_function_actor, + args=( + func, + args, + kwargs, + self.execution_state, + name, + self.active_threads, + self.active_threads_lock, + ), + executor=self.actor_executor, + ) else: result = func(*args, **kwargs) - raise gen.Return({'status': 'OK', 'result': to_serialize(result)}) + raise gen.Return({"status": "OK", "result": to_serialize(result)}) def actor_attribute(self, comm=None, actor=None, attribute=None): value = getattr(self.actors[actor], attribute) - return {'status': 'OK', 'result': to_serialize(value)} + return {"status": "OK", "result": to_serialize(value)} def meets_resource_constraints(self, key): if key not in self.resource_restrictions: @@ -2049,36 +2230,37 @@ def ensure_computing(self): try: while self.constrained and len(self.executing) < self.ncores: key = self.constrained[0] - if self.task_state.get(key) != 'constrained': + if self.task_state.get(key) != "constrained": self.constrained.popleft() continue if self.meets_resource_constraints(key): self.constrained.popleft() - self.transition(key, 'executing') + self.transition(key, "executing") else: break while self.ready and len(self.executing) < self.ncores: _, key = heapq.heappop(self.ready) if self.task_state.get(key) in READY: - self.transition(key, 'executing') + self.transition(key, "executing") except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @gen.coroutine def execute(self, key, report=False): executor_error = None - if self.status in ('closing', 'closed'): + if self.status in ("closing", "closed"): return try: if key not in self.executing or key not in self.task_state: return if self.validate: assert key not in self.waiting_for_data - assert self.task_state[key] == 'executing' + assert self.task_state[key] == "executing" function, args, kwargs = self.tasks[key] @@ -2089,64 +2271,74 @@ def execute(self, key, report=False): data[k] = self.data[k] except KeyError: from .actor import Actor # TODO: create local actor + data[k] = Actor(type(self.actors[k]), self.address, k, self) args2 = pack_data(args, data, key_types=(bytes, unicode)) kwargs2 = pack_data(kwargs, data, key_types=(bytes, unicode)) stop = time() if stop - start > 0.005: - self.startstops[key].append(('disk-read', start, stop)) + self.startstops[key].append(("disk-read", start, stop)) if self.digests is not None: - self.digests['disk-load-duration'].add(stop - start) + self.digests["disk-load-duration"].add(stop - start) - logger.debug("Execute key: %s worker: %s", key, self.address) # TODO: comment out? + logger.debug( + "Execute key: %s worker: %s", key, self.address + ) # TODO: comment out? 
try: - result = yield self.executor_submit(key, apply_function, - args=(function, args2, kwargs2, - self.execution_state, key, - self.active_threads, - self.active_threads_lock, - self.scheduler_delay)) + result = yield self.executor_submit( + key, + apply_function, + args=( + function, + args2, + kwargs2, + self.execution_state, + key, + self.active_threads, + self.active_threads_lock, + self.scheduler_delay, + ), + ) except RuntimeError as e: executor_error = e raise - if self.task_state.get(key) not in ('executing', 'long-running'): + if self.task_state.get(key) not in ("executing", "long-running"): return - result['key'] = key - value = result.pop('result', None) - self.startstops[key].append(('compute', result['start'], - result['stop'])) - self.threads[key] = result['thread'] + result["key"] = key + value = result.pop("result", None) + self.startstops[key].append(("compute", result["start"], result["stop"])) + self.threads[key] = result["thread"] - if result['op'] == 'task-finished': - self.nbytes[key] = result['nbytes'] - self.types[key] = result['type'] - self.transition(key, 'memory', value=value) + if result["op"] == "task-finished": + self.nbytes[key] = result["nbytes"] + self.types[key] = result["type"] + self.transition(key, "memory", value=value) if self.digests is not None: - self.digests['task-duration'].add(result['stop'] - - result['start']) + self.digests["task-duration"].add(result["stop"] - result["start"]) else: - if isinstance(result.pop('actual-exception'), Reschedule): - self.batched_stream.send({'op': 'reschedule', 'key': key}) - self.transition(key, 'rescheduled', report=False) + if isinstance(result.pop("actual-exception"), Reschedule): + self.batched_stream.send({"op": "reschedule", "key": key}) + self.transition(key, "rescheduled", report=False) self.release_key(key, report=False) else: - self.exceptions[key] = result['exception'] - self.tracebacks[key] = result['traceback'] - logger.warning(" Compute Failed\n" - "Function: %s\n" - "args: %s\n" - "kwargs: %s\n" - "Exception: %s\n", - str(funcname(function))[:1000], - convert_args_to_str(args2, max_len=1000), - convert_kwargs_to_str(kwargs2, max_len=1000), - repr(result['exception'].data)) - self.transition(key, 'error') - - logger.debug("Send compute response to scheduler: %s, %s", key, - result) + self.exceptions[key] = result["exception"] + self.tracebacks[key] = result["traceback"] + logger.warning( + " Compute Failed\n" + "Function: %s\n" + "args: %s\n" + "kwargs: %s\n" + "Exception: %s\n", + str(funcname(function))[:1000], + convert_args_to_str(args2, max_len=1000), + convert_kwargs_to_str(kwargs2, max_len=1000), + repr(result["exception"].data), + ) + self.transition(key, "error") + + logger.debug("Send compute response to scheduler: %s, %s", key, result) if self.validate: assert key not in self.executing @@ -2161,6 +2353,7 @@ def execute(self, key, report=False): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise finally: @@ -2193,18 +2386,22 @@ def memory_monitor(self): # Try to free some memory while in paused state self._throttled_gc.collect() if not self.paused: - logger.warning("Worker is at %d%% memory usage. Pausing worker. " - "Process memory: %s -- Worker memory limit: %s", - int(frac * 100), - format_bytes(proc.memory_info().rss), - format_bytes(self.memory_limit)) + logger.warning( + "Worker is at %d%% memory usage. Pausing worker. 
" + "Process memory: %s -- Worker memory limit: %s", + int(frac * 100), + format_bytes(proc.memory_info().rss), + format_bytes(self.memory_limit), + ) self.paused = True elif self.paused: - logger.warning("Worker is at %d%% memory usage. Resuming worker. " - "Process memory: %s -- Worker memory limit: %s", - int(frac * 100), - format_bytes(proc.memory_info().rss), - format_bytes(self.memory_limit)) + logger.warning( + "Worker is at %d%% memory usage. Resuming worker. " + "Process memory: %s -- Worker memory limit: %s", + int(frac * 100), + format_bytes(proc.memory_info().rss), + format_bytes(self.memory_limit), + ) self.paused = False self.ensure_computing() @@ -2215,12 +2412,14 @@ def memory_monitor(self): need = memory - target while memory > target: if not self.data.fast: - logger.warning("Memory use is high but worker has no data " - "to store to disk. Perhaps some other process " - "is leaking memory? Process memory: %s -- " - "Worker memory limit: %s", - format_bytes(proc.memory_info().rss), - format_bytes(self.memory_limit)) + logger.warning( + "Memory use is high but worker has no data " + "to store to disk. Perhaps some other process " + "is leaking memory? Process memory: %s -- " + "Worker memory limit: %s", + format_bytes(proc.memory_info().rss), + format_bytes(self.memory_limit), + ) break k, v, weight = self.data.fast.evict() del k, v @@ -2235,8 +2434,11 @@ def memory_monitor(self): self._throttled_gc.collect() memory = proc.memory_info().rss if count: - logger.debug("Moved %d pieces of data data and %s to disk", - count, format_bytes(total)) + logger.debug( + "Moved %d pieces of data data and %s to disk", + count, + format_bytes(total), + ) self._memory_monitoring = False raise gen.Return(total) @@ -2265,21 +2467,22 @@ def trigger_profile(self): for ident, frame in frames.items(): if frame is not None: key = key_split(active_threads[ident]) - profile.process(frame, None, self.profile_recent, - stop='distributed/worker.py') - profile.process(frame, None, self.profile_keys[key], - stop='distributed/worker.py') + profile.process( + frame, None, self.profile_recent, stop="distributed/worker.py" + ) + profile.process( + frame, None, self.profile_keys[key], stop="distributed/worker.py" + ) stop = time() if self.digests is not None: - self.digests['profile-duration'].add(stop - start) + self.digests["profile-duration"].add(stop - start) def get_profile(self, comm=None, start=None, stop=None, key=None): now = time() + self.scheduler_delay if key is None: history = self.profile_history else: - history = [(t, d[key]) for t, d in self.profile_keys_history - if key in d] + history = [(t, d[key]) for t, d in self.profile_keys_history if key in d] if start is None: istart = 0 else: @@ -2318,23 +2521,28 @@ def get_profile_metadata(self, comm=None, start=0, stop=None): now = time() + self.scheduler_delay stop = stop or now start = start or 0 - result = {'counts': [(t, d['count']) for t, d in self.profile_history - if start < t < stop], - 'keys': [(t, {k: d['count'] for k, d in v.items()}) - for t, v in self.profile_keys_history - if start < t < stop]} + result = { + "counts": [ + (t, d["count"]) for t, d in self.profile_history if start < t < stop + ], + "keys": [ + (t, {k: d["count"] for k, d in v.items()}) + for t, v in self.profile_keys_history + if start < t < stop + ], + } if add_recent: - result['counts'].append((now, self.profile_recent['count'])) - result['keys'].append((now, {k: v['count'] - for k, v in self.profile_keys.items()})) + result["counts"].append((now, 
self.profile_recent["count"])) + result["keys"].append( + (now, {k: v["count"] for k, v in self.profile_keys.items()}) + ) return result def get_call_stack(self, comm=None, keys=None): with self.active_threads_lock: frames = sys._current_frames() active_threads = self.active_threads.copy() - frames = {k: frames[ident] - for ident, k in active_threads.items()} + frames = {k: frames[ident] for ident, k in active_threads.items()} if keys is not None: frames = {k: frame for k, frame in frames.items() if k in keys} @@ -2352,20 +2560,24 @@ def validate_key_memory(self, key): assert key not in self.executing assert key not in self.ready if key in self.dep_state: - assert self.dep_state[key] == 'memory' + assert self.dep_state[key] == "memory" def validate_key_executing(self, key): assert key in self.executing assert key not in self.data assert key not in self.waiting_for_data - assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) + assert all( + dep in self.data or dep in self.actors for dep in self.dependencies[key] + ) def validate_key_ready(self, key): assert key in pluck(1, self.ready) assert key not in self.data assert key not in self.executing assert key not in self.waiting_for_data - assert all(dep in self.data or dep in self.actors for dep in self.dependencies[key]) + assert all( + dep in self.data or dep in self.actors for dep in self.dependencies[key] + ) def validate_key_waiting(self, key): assert key not in self.data @@ -2374,18 +2586,19 @@ def validate_key_waiting(self, key): def validate_key(self, key): try: state = self.task_state[key] - if state == 'memory': + if state == "memory": self.validate_key_memory(key) - elif state == 'waiting': + elif state == "waiting": self.validate_key_waiting(key) - elif state == 'ready': + elif state == "ready": self.validate_key_ready(key) - elif state == 'executing': + elif state == "executing": self.validate_key_executing(key) except Exception as e: logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -2407,16 +2620,16 @@ def validate_dep_memory(self, dep): assert dep in self.nbytes assert dep in self.types if dep in self.task_state: - assert self.task_state[dep] == 'memory' + assert self.task_state[dep] == "memory" def validate_dep(self, dep): try: state = self.dep_state[dep] - if state == 'waiting': + if state == "waiting": self.validate_dep_waiting(dep) - elif state == 'flight': + elif state == "flight": self.validate_dep_flight(dep) - elif state == 'memory': + elif state == "memory": self.validate_dep_memory(dep) else: raise ValueError("Unknown dependent state", state) @@ -2424,11 +2637,12 @@ def validate_dep(self, dep): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise def validate_state(self): - if self.status != 'running': + if self.status != "running": return try: for key, workers in self.who_has.items(): @@ -2448,12 +2662,14 @@ def validate_state(self): for key, deps in self.waiting_for_data.items(): if key not in self.data_needed: for dep in deps: - assert (dep in self.in_flight_tasks or - dep in self._missing_dep_flight or - self.who_has[dep].issubset(self.in_flight_workers)) + assert ( + dep in self.in_flight_tasks + or dep in self._missing_dep_flight + or self.who_has[dep].issubset(self.in_flight_workers) + ) for key in self.tasks: - if self.task_state[key] == 'memory': + if self.task_state[key] == "memory": assert isinstance(self.nbytes[key], int) assert key not in self.waiting_for_data assert key in self.data or key in self.actors @@ -2462,6 +2678,7 @@ def 
validate_state(self): logger.exception(e) if LOG_PDB: import pdb + pdb.set_trace() raise @@ -2488,26 +2705,34 @@ def _get_client(self, timeout=3): """ try: from .client import default_client + client = default_client() except ValueError: # no clients found, need to make a new one pass else: - if (client.scheduler and client.scheduler.address == self.scheduler.address - or client._start_arg == self.scheduler.address): + if ( + client.scheduler + and client.scheduler.address == self.scheduler.address + or client._start_arg == self.scheduler.address + ): self._client = client if not self._client: from .client import Client + asynchronous = self.loop is IOLoop.current() - self._client = Client(self.scheduler, loop=self.loop, - security=self.security, - set_as_default=True, - asynchronous=asynchronous, - direct_to_workers=True, - name='worker', - timeout=timeout) + self._client = Client( + self.scheduler, + loop=self.loop, + security=self.security, + set_as_default=True, + asynchronous=asynchronous, + direct_to_workers=True, + name="worker", + timeout=timeout, + ) if not asynchronous: - assert self._client.status == 'running' + assert self._client.status == "running" return self._client def get_current_task(self): @@ -2551,7 +2776,7 @@ def get_worker(): worker_client """ try: - return thread_state.execution_state['worker'] + return thread_state.execution_state["worker"] except AttributeError: for ref in _global_workers[::-1]: worker = ref() @@ -2608,11 +2833,13 @@ def get_client(address=None, timeout=3, resolve_address=True): return worker._get_client(timeout=timeout) from .client import _get_global_client + client = _get_global_client() # TODO: assumes the same scheduler if client and (not address or client.scheduler.address == address): return client elif address: from .client import Client + return Client(address, timeout=timeout) else: raise ValueError("No global client found and no address provided") @@ -2644,8 +2871,11 @@ def secede(): worker = get_worker() tpe_secede() # have this thread secede from the thread pool duration = time() - thread_state.start_time - worker.loop.add_callback(worker.maybe_transition_long_running, - thread_state.key, compute_duration=duration) + worker.loop.add_callback( + worker.maybe_transition_long_running, + thread_state.key, + compute_duration=duration, + ) class Reschedule(Exception): @@ -2660,13 +2890,14 @@ class Reschedule(Exception): load across the cluster has significantly changed since first scheduling the task. 
""" + pass def parse_memory_limit(memory_limit, ncores, total_cores=_ncores): if memory_limit is None: return None - if memory_limit == 'auto': + if memory_limit == "auto": memory_limit = int(TOTAL_MEMORY * min(1, ncores / total_cores)) with ignoring(ValueError, TypeError): x = float(memory_limit) @@ -2680,8 +2911,15 @@ def parse_memory_limit(memory_limit, ncores, total_cores=_ncores): @gen.coroutine -def get_data_from_worker(rpc, keys, worker, who=None, max_connections=None, - serializers=None, deserializers=None): +def get_data_from_worker( + rpc, + keys, + worker, + who=None, + max_connections=None, + serializers=None, + deserializers=None, +): """ Get keys from worker The worker has a two step handshake to acknowledge when data has been fully @@ -2700,18 +2938,22 @@ def get_data_from_worker(rpc, keys, worker, who=None, max_connections=None, comm = yield rpc.connect(worker) try: - response = yield send_recv(comm, - serializers=serializers, - deserializers=deserializers, - op='get_data', keys=keys, who=who, - max_connections=max_connections) + response = yield send_recv( + comm, + serializers=serializers, + deserializers=deserializers, + op="get_data", + keys=keys, + who=who, + max_connections=max_connections, + ) try: - status = response['status'] + status = response["status"] except KeyError: raise ValueError("Unexpected response", response) else: - if status == 'OK': - yield comm.write('OK') + if status == "OK": + yield comm.write("OK") finally: rpc.reuse(worker, comm) @@ -2794,14 +3036,12 @@ def dumps_task(task): """ if istask(task): if task[0] is apply and not any(map(_maybe_complex, task[2:])): - d = {'function': dumps_function(task[1]), - 'args': warn_dumps(task[2])} + d = {"function": dumps_function(task[1]), "args": warn_dumps(task[2])} if len(task) == 4: - d['kwargs'] = warn_dumps(task[3]) + d["kwargs"] = warn_dumps(task[3]) return d elif not any(map(_maybe_complex, task[1:])): - return {'function': dumps_function(task[0]), - 'args': warn_dumps(task[1:])} + return {"function": dumps_function(task[0]), "args": warn_dumps(task[1:])} return to_serialize(task) @@ -2815,21 +3055,31 @@ def warn_dumps(obj, dumps=pickle.dumps, limit=1e6): _warn_dumps_warned[0] = True s = str(obj) if len(s) > 70: - s = s[:50] + ' ... ' + s[-15:] - warnings.warn("Large object of size %s detected in task graph: \n" - " %s\n" - "Consider scattering large objects ahead of time\n" - "with client.scatter to reduce scheduler burden and \n" - "keep data on workers\n\n" - " future = client.submit(func, big_data) # bad\n\n" - " big_future = client.scatter(big_data) # good\n" - " future = client.submit(func, big_future) # good" - % (format_bytes(len(b)), s)) + s = s[:50] + " ... 
" + s[-15:] + warnings.warn( + "Large object of size %s detected in task graph: \n" + " %s\n" + "Consider scattering large objects ahead of time\n" + "with client.scatter to reduce scheduler burden and \n" + "keep data on workers\n\n" + " future = client.submit(func, big_data) # bad\n\n" + " big_future = client.scatter(big_data) # good\n" + " future = client.submit(func, big_future) # good" + % (format_bytes(len(b)), s) + ) return b -def apply_function(function, args, kwargs, execution_state, key, - active_threads, active_threads_lock, time_delay): +def apply_function( + function, + args, + kwargs, + execution_state, + key, + active_threads, + active_threads_lock, + time_delay, +): """ Run a function, collect information Returns @@ -2847,26 +3097,29 @@ def apply_function(function, args, kwargs, execution_state, key, result = function(*args, **kwargs) except Exception as e: msg = error_message(e) - msg['op'] = 'task-erred' - msg['actual-exception'] = e + msg["op"] = "task-erred" + msg["actual-exception"] = e else: - msg = {'op': 'task-finished', - 'status': 'OK', - 'result': result, - 'nbytes': sizeof(result), - 'type': type(result) if result is not None else None} + msg = { + "op": "task-finished", + "status": "OK", + "result": result, + "nbytes": sizeof(result), + "type": type(result) if result is not None else None, + } finally: end = time() - msg['start'] = start + time_delay - msg['stop'] = end + time_delay - msg['thread'] = ident + msg["start"] = start + time_delay + msg["stop"] = end + time_delay + msg["thread"] = ident with active_threads_lock: del active_threads[ident] return msg -def apply_function_actor(function, args, kwargs, execution_state, key, - active_threads, active_threads_lock): +def apply_function_actor( + function, args, kwargs, execution_state, key, active_threads, active_threads_lock +): """ Run a function, collect information Returns @@ -2894,6 +3147,7 @@ def get_msg_safe_str(msg): allowing for some arguments to raise exceptions during conversion and ignoring them. """ + class Repr(object): def __init__(self, f, val): self._f = f @@ -2924,7 +3178,7 @@ def convert_args_to_str(args, max_len=None): strs[i] = sarg length += len(sarg) + 2 if max_len is not None and length > max_len: - return "({}".format(", ".join(strs[:i + 1]))[:max_len] + return "({}".format(", ".join(strs[: i + 1]))[:max_len] else: return "({})".format(", ".join(strs)) @@ -2944,7 +3198,7 @@ def convert_kwargs_to_str(kwargs, max_len=None): strs[i] = skwarg length += len(skwarg) + 2 if max_len is not None and length > max_len: - return "{{{}".format(", ".join(strs[:i + 1]))[:max_len] + return "{{{}".format(", ".join(strs[: i + 1]))[:max_len] else: return "{{{}}}".format(", ".join(strs)) @@ -2959,17 +3213,19 @@ def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=True): if is_coro is None: is_coro = iscoroutinefunction(function) else: - warnings.warn("The is_coro= parameter is deprecated. " - "We now automatically detect coroutines/async functions") + warnings.warn( + "The is_coro= parameter is deprecated. 
" + "We now automatically detect coroutines/async functions" + ) assert wait or is_coro, "Combination not supported" if args: args = pickle.loads(args) if kwargs: kwargs = pickle.loads(kwargs) - if has_arg(function, 'dask_worker'): - kwargs['dask_worker'] = server - if has_arg(function, 'dask_scheduler'): - kwargs['dask_scheduler'] = server + if has_arg(function, "dask_worker"): + kwargs["dask_worker"] = server + if has_arg(function, "dask_scheduler"): + kwargs["dask_scheduler"] = server logger.info("Run out-of-band function %r", funcname(function)) try: if not is_coro: @@ -2982,18 +3238,15 @@ def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=True): result = None except Exception as e: - logger.warning(" Run Failed\n" - "Function: %s\n" - "args: %s\n" - "kwargs: %s\n", - str(funcname(function))[:1000], - convert_args_to_str(args, max_len=1000), - convert_kwargs_to_str(kwargs, max_len=1000), exc_info=True) + logger.warning( + " Run Failed\n" "Function: %s\n" "args: %s\n" "kwargs: %s\n", + str(funcname(function))[:1000], + convert_args_to_str(args, max_len=1000), + convert_kwargs_to_str(kwargs, max_len=1000), + exc_info=True, + ) response = error_message(e) else: - response = { - 'status': 'OK', - 'result': to_serialize(result), - } + response = {"status": "OK", "result": to_serialize(result)} raise Return(response) diff --git a/distributed/worker_client.py b/distributed/worker_client.py index 50cf6be25a5..ff6294430b5 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -44,7 +44,7 @@ def worker_client(timeout=3, separate_thread=True): client = get_client(timeout=timeout) if separate_thread: secede() # have this thread secede from the thread pool - worker.loop.add_callback(worker.transition, thread_state.key, 'long-running') + worker.loop.add_callback(worker.transition, thread_state.key, "long-running") yield client From cb6ed57573ef171988b372843489e12c8e5c5b6b Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 16 Apr 2019 19:32:51 +0200 Subject: [PATCH 0237/1550] Integrate stacktrace for low-level profiling (#2575) --- continuous_integration/travis/install.sh | 5 + distributed/bokeh/tests/test_components.py | 4 +- distributed/distributed.yaml | 2 + distributed/profile.py | 113 ++++++++++++++++++++- distributed/tests/test_profile.py | 46 ++++++++- distributed/worker.py | 16 ++- 6 files changed, 178 insertions(+), 8 deletions(-) diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 8e9d1ff7f12..bba69dd3ac8 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -54,6 +54,11 @@ conda install -q \ tornado=$TORNADO \ $PACKAGES +# For low-level profiler, install libunwind and stacktrace from conda-forge +# For stacktrace we use --no-deps to avoid upgrade of python +conda install -c defaults -c conda-forge libunwind +conda install --no-deps -c defaults -c numba -c conda-forge stacktrace + pip install -q pytest-repeat pytest-faulthandler pip install -q git+https://github.com/dask/dask.git --upgrade --no-deps diff --git a/distributed/bokeh/tests/test_components.py b/distributed/bokeh/tests/test_components.py index 4f4df92f6cd..028f209b41a 100644 --- a/distributed/bokeh/tests/test_components.py +++ b/distributed/bokeh/tests/test_components.py @@ -30,10 +30,10 @@ def test_basic(Component): @gen_cluster(client=True, check_new_threads=False) def test_profile_plot(c, s, a, b): p = ProfilePlot() - assert len(p.source.data["left"]) <= 1 + assert not 
p.source.data["left"] yield c.map(slowinc, range(10), delay=0.05) p.update(a.profile_recent) - assert len(p.source.data["left"]) > 1 + assert len(p.source.data["left"]) >= 1 @gen_cluster(client=True, check_new_threads=False) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 156c90a127e..d625a103fe8 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -35,6 +35,8 @@ distributed: profile: interval: 10ms # Time between statistical profiling queries cycle: 1000ms # Time between starting new profile + low-level: False # Whether or not to include low-level functions + # Requires https://github.com/numba/stacktrace # Fractions of worker memory at which we take action to avoid memory blowup # Set any of the lower three values to False to turn off the behavior entirely diff --git a/distributed/profile.py b/distributed/profile.py index 385c7449e75..e240a872fb4 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -178,6 +178,7 @@ def plot_data(state, profile_interval=0.010): See Also -------- + plot_figure distributed.bokeh.components.ProfilePlot """ starts = [] @@ -213,9 +214,14 @@ def traverse(state, start, stop, height): ident = state["identifier"] try: - colors.append(color_of(desc["filename"])) + fn = desc["filename"] except IndexError: colors.append("gray") + else: + if fn == "": + colors.append("lightgray") + else: + colors.append(color_of(fn)) delta = (stop - start) / state["count"] @@ -274,6 +280,29 @@ def watch( omit=None, stop=lambda: False, ): + """ Gather profile information on a particular thread + + This starts a new thread to watch a particular thread and returns a deque + that holds periodic profile information. + + Parameters + ---------- + thread_id: int + interval: str + Time per sample + cycle: str + Time per refreshing to a new profile state + maxlen: int + Passed onto deque, maximum number of periods + omit: str + Don't include entries that start with this filename + stop: callable + Function to call to see if we should stop + + Returns + ------- + deque + """ if thread_id is None: thread_id = get_thread_identity() @@ -298,6 +327,17 @@ def watch( def get_profile(history, recent=None, start=None, stop=None, key=None): + """ Collect profile information from a sequence of profile states + + Parameters + ---------- + history: Sequence[Tuple[time, Dict]] + A list or deque of profile states + recent: dict + The most recent accumulating state + start: time + stop: time + """ now = time() if start is None: istart = 0 @@ -329,6 +369,15 @@ def get_profile(history, recent=None, start=None, stop=None, key=None): def plot_figure(data, **kwargs): + """ Plot profile data using Bokeh + + This takes the output from the function ``plot_data`` and produces a Bokeh + figure + + See Also + -------- + plot_data + """ from bokeh.plotting import ColumnDataSource, figure from bokeh.models import HoverTool @@ -388,3 +437,65 @@ def plot_figure(data, **kwargs): fig.grid.visible = False return fig, source + + +def _remove_py_stack(frames): + for entry in frames: + if entry.is_python: + break + yield entry + + +def llprocess(frames, child, state): + """ Add counts from low level profile information onto existing state + + This uses the ``stacktrace`` module to collect low level stack trace + information and place it onto the given sttate. + + It is configured with the ``distributed.worker.profile.low-level`` config + entry. 
+ + See Also + -------- + process + ll_get_stack + """ + if not frames: + return + frame = frames.pop() + if frames: + state = llprocess(frames, frame, state) + + addr = hex(frame.addr - frame.offset) + ident = ";".join(map(str, (frame.name, "", addr))) + try: + d = state["children"][ident] + except KeyError: + d = { + "count": 0, + "description": { + "filename": "", + "name": frame.name, + "line_number": 0, + "line": str(frame), + }, + "children": {}, + "identifier": ident, + } + state["children"][ident] = d + + state["count"] += 1 + + if child is not None: + return d + else: + d["count"] += 1 + + +def ll_get_stack(tid): + """ Collect low level stack information from thread id """ + from stacktrace import get_thread_stack + + frames = get_thread_stack(tid, show_python=False) + llframes = list(_remove_py_stack(frames))[::-1] + return llframes diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index 57a7ca657e4..ee49f130027 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -1,11 +1,22 @@ +import pytest import sys import time from toolz import first import threading -from distributed.compatibility import get_thread_identity +from distributed.compatibility import get_thread_identity, WINDOWS from distributed import metrics -from distributed.profile import process, merge, create, call_stack, identifier, watch +from distributed.profile import ( + process, + merge, + create, + call_stack, + identifier, + watch, + llprocess, + ll_get_stack, + plot_data, +) def test_basic(): @@ -44,6 +55,37 @@ def test_f(): assert g["count"] < h["count"] assert 95 < g["count"] + h["count"] <= 100 + pd = plot_data(state) + assert len(set(map(len, pd.values()))) == 1 # all same length + assert len(set(pd["color"])) > 1 # different colors + + +@pytest.mark.skipif( + WINDOWS, reason="no low-level profiler support for Windows available" +) +def test_basic_low_level(): + pytest.importorskip("stacktrace") + + state = create() + + for i in range(100): + time.sleep(0.02) + frame = sys._current_frames()[threading.get_ident()] + llframes = {threading.get_ident(): ll_get_stack(threading.get_ident())} + for f in llframes.values(): + if f is not None: + llprocess(f, None, state) + + assert state["count"] == 100 + children = state.get("children") + assert children + expected = "" + for k, v in zip(children.keys(), children.values()): + desc = v.get("description") + assert desc + filename = desc.get("filename") + assert expected in k and filename == expected + def test_merge(): a1 = { diff --git a/distributed/worker.py b/distributed/worker.py index b9ed6c5a59d..9f940f02f93 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -294,6 +294,7 @@ def __init__( extensions=None, metrics=None, data=None, + low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), **kwargs ): self.tasks = dict() @@ -511,6 +512,8 @@ def __init__( self.service_specs = services or {} self.metrics = dict(metrics) if metrics else {} + self.low_level_profiler = low_level_profiler + handlers = { "gather": self.gather, "run": self.run, @@ -2464,15 +2467,22 @@ def trigger_profile(self): active_threads = self.active_threads.copy() frames = sys._current_frames() frames = {ident: frames[ident] for ident in active_threads} + llframes = {} + if self.low_level_profiler: + llframes = {ident: profile.ll_get_stack(ident) for ident in active_threads} for ident, frame in frames.items(): if frame is not None: key = key_split(active_threads[ident]) - profile.process( - frame, 
None, self.profile_recent, stop="distributed/worker.py" + llframe = llframes.get(ident) + + state = profile.process( + frame, True, self.profile_recent, stop="distributed/worker.py" ) + profile.llprocess(llframe, None, state) profile.process( - frame, None, self.profile_keys[key], stop="distributed/worker.py" + frame, True, self.profile_keys[key], stop="distributed/worker.py" ) + stop = time() if self.digests is not None: self.digests["profile-duration"].add(stop - start) From 33df62d5e6369e50f1c9e86381ea9f1f346ca931 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 18 Apr 2019 16:10:40 -0500 Subject: [PATCH 0238/1550] Allow Python 2 testing failures in Travis CI (#2615) * Allow Python 2 build to fail in travis.ci * Use the --check flag for black --- .travis.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 96331468b55..23daef096ca 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,10 +21,9 @@ matrix: #- os: osx #env: PYTHON=3.6 RUNSLOW=false # Together with fast_finish, allow build to be marked successful before the OS X job finishes - #allow_failures: - #- os: osx + allow_failures: ## This needs to be the exact same line as above - #env: PYTHON=3.6 RUNSLOW=false + env: PYTHON=2.7 TESTS=true PACKAGES="python-blosc futures faulthandler lz4" install: - if [[ $TESTS == true ]]; then source continuous_integration/travis/install.sh ; fi @@ -32,7 +31,7 @@ install: script: - if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi - if [[ $LINT == true ]]; then pip install flake8 ; flake8 distributed ; fi - - if [[ $LINT == true ]]; then pip install black; black distributed ; fi + - if [[ $LINT == true ]]; then pip install black; black distributed --check; fi after_success: - if [[ $COVERAGE == true ]]; then coverage report; pip install -q coveralls ; coveralls ; fi From c355744f988a23b7dbc388d93c0fd09de67e75cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Fri, 19 Apr 2019 15:39:01 +0200 Subject: [PATCH 0239/1550] Fix parameter name in LocalCluster docstring (#2626) --- distributed/deploy/local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 22a796bbb78..92de1b1c799 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -62,7 +62,7 @@ class LocalCluster(Cluster): asynchronous: bool (False by default) Set to True if using this cluster within async/await functions or within Tornado gen.coroutines. This should remain False for normal use. - kwargs: dict + worker_kwargs: dict Extra worker arguments, will be passed to the Worker constructor. 
blocked_handlers: List[str] A list of strings specifying a blacklist of handlers to disallow on the Scheduler, From 3b84c3e350f4b72d428c2bc44e5ce95652c63e5e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 20 Apr 2019 10:43:24 -0700 Subject: [PATCH 0240/1550] Add number of trials to diskutils test (#2630) --- distributed/tests/test_diskutils.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index 8bf4000178e..d5abf5c1dee 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -19,14 +19,20 @@ from distributed.utils_test import captured_logger, slow -def assert_directory_contents(dir_path, expected): +def assert_directory_contents(dir_path, expected, trials=2): expected = [os.path.join(dir_path, p) for p in expected] - actual = [ - os.path.join(dir_path, p) - for p in os.listdir(dir_path) - if p not in ("global.lock", "purge.lock") - ] - assert sorted(actual) == sorted(expected) + for i in range(trials): + actual = [ + os.path.join(dir_path, p) + for p in os.listdir(dir_path) + if p not in ("global.lock", "purge.lock") + ] + if sorted(actual) == sorted(expected): + break + else: + sleep(0.5) + else: + assert sorted(actual) == sorted(expected) def test_workdir_simple(tmpdir): @@ -82,10 +88,10 @@ def test_two_workspaces_in_same_directory(tmpdir): del ws del b gc.collect() - assert_contents(["aa", "aa.dirlock"]) + assert_contents(["aa", "aa.dirlock"], trials=5) del a gc.collect() - assert_contents([]) + assert_contents([], trials=5) def test_workspace_process_crash(tmpdir): From 291a140787168ac19625cd142f5189f21e179d00 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 20 Apr 2019 10:44:08 -0700 Subject: [PATCH 0241/1550] Remove Python 2.7 from testing matrix (#2631) --- .travis.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index 23daef096ca..bcc09351eff 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,6 @@ dist: trusty env: matrix: - - PYTHON=2.7 TESTS=true PACKAGES="python-blosc futures faulthandler lz4" - PYTHON=3.5.4 TESTS=true COVERAGE=true PACKAGES="python-blosc lz4" CRICK=true - PYTHON=3.6 TESTS=true PACKAGES="scikit-learn lz4" TORNADO=5 - PYTHON=3.7 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 @@ -15,15 +14,8 @@ matrix: fast_finish: true include: - os: linux - # Using Travis-CI's python makes job faster by not downloading miniconda python: 3.6 env: LINT=true - #- os: osx - #env: PYTHON=3.6 RUNSLOW=false - # Together with fast_finish, allow build to be marked successful before the OS X job finishes - allow_failures: - ## This needs to be the exact same line as above - env: PYTHON=2.7 TESTS=true PACKAGES="python-blosc futures faulthandler lz4" install: - if [[ $TESTS == true ]]; then source continuous_integration/travis/install.sh ; fi From f459af1637c1f1cb3bf75e6d710b99bcf3886190 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Sat, 20 Apr 2019 19:49:34 +0200 Subject: [PATCH 0242/1550] Add worker_class argument to LocalCluster (#2625) --- distributed/deploy/local.py | 12 +++++++---- distributed/deploy/tests/test_local.py | 30 ++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 92de1b1c799..4e63646ba27 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -74,6 +74,8 @@ class LocalCluster(Cluster): Protocol to 
use like ``tcp://``, ``tls://``, ``inproc://`` This defaults to sensible choice given other keyword arguments like ``processes`` and ``security`` + worker_class: Worker + Worker class used to instantiate workers from. Examples -------- @@ -115,6 +117,7 @@ def __init__( security=None, protocol=None, blocked_handlers=None, + worker_class=None, **worker_kwargs ): if start is not None: @@ -203,6 +206,10 @@ def __init__( if security: self.worker_kwargs["security"] = security + if not worker_class: + worker_class = Worker if not processes else Nanny + self.worker_class = worker_class + self.start(ip=ip, n_workers=n_workers) clusters_to_close.add(self) @@ -279,12 +286,9 @@ def _start_worker(self, death_timeout=60, **kwargs): return if self.processes: - W = Nanny kwargs["quiet"] = True - else: - W = Worker - w = yield W( + w = yield self.worker_class( self.scheduler.address, loop=self.loop, death_timeout=death_timeout, diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index ee2d48c2df3..ab378ba0e8f 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -742,5 +742,35 @@ def test_protocol_ip(loop): assert cluster.scheduler.address.startswith("tcp://127.0.0.2") +class MyWorker(Worker): + pass + + +def test_worker_class_worker(loop): + with LocalCluster( + n_workers=2, + loop=loop, + worker_class=MyWorker, + processes=False, + scheduler_port=0, + dashboard_address=None, + ) as cluster: + assert all(isinstance(w, MyWorker) for w in cluster.workers) + + +def test_worker_class_nanny(loop): + class MyNanny(Nanny): + pass + + with LocalCluster( + n_workers=2, + loop=loop, + worker_class=MyNanny, + scheduler_port=0, + dashboard_address=None, + ) as cluster: + assert all(isinstance(w, MyNanny) for w in cluster.workers) + + if sys.version_info >= (3, 5): from distributed.deploy.tests.py3_test_deploy import * # noqa F401 From 7de97bdbdc97599ae7c4bd8d9f3851463d7eedc1 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 21 Apr 2019 08:25:40 -0700 Subject: [PATCH 0243/1550] Add interface= keyword to LocalCluster (#2629) This is useful when you want to use a particular network interface, like infiniband. Fixes https://github.com/dask/distributed/issues/2618 --- distributed/deploy/local.py | 10 +++++++++- distributed/tests/test_client.py | 5 +++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 4e63646ba27..73ddde8bdc1 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -15,6 +15,7 @@ from ..compatibility import get_thread_identity from ..core import CommClosedError from ..utils import ( + get_ip_interface, sync, ignoring, All, @@ -74,6 +75,8 @@ class LocalCluster(Cluster): Protocol to use like ``tcp://``, ``tls://``, ``inproc://`` This defaults to sensible choice given other keyword arguments like ``processes`` and ``security`` + interface: str (optional) + Network interface to use. Defaults to lo/localhost worker_class: Worker Worker class used to instantiate workers from. 
@@ -117,6 +120,7 @@ def __init__( security=None, protocol=None, blocked_handlers=None, + interface=None, worker_class=None, **worker_kwargs ): @@ -155,6 +159,7 @@ def __init__( self.silence_logs = silence_logs self._asynchronous = asynchronous self.security = security + self.interface = interface services = services or {} worker_services = worker_services or {} if silence_logs: @@ -262,7 +267,10 @@ def _start(self, ip=None, n_workers=0): address = self.protocol else: if ip is None: - ip = "127.0.0.1" + if self.interface: + ip = get_ip_interface(self.interface) + else: + ip = "127.0.0.1" if "://" in ip: address = ip diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 634834bf671..a9d4f21fd81 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5061,8 +5061,9 @@ def test_call_stack_future(c, s, a, b): @gen_cluster([("127.0.0.1", 4)] * 2, client=True) def test_call_stack_all(c, s, a, b): - future = c.submit(slowinc, 1, delay=0.5) - yield gen.sleep(0.1) + future = c.submit(slowinc, 1, delay=0.8) + while not a.executing and not b.executing: + yield gen.sleep(0.01) result = yield c.call_stack() w = a if a.executing else b assert list(result) == [w.address] From a3d2016a4fab9ef14ab5be0b0b722b59365b59ce Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 21 Apr 2019 08:25:57 -0700 Subject: [PATCH 0244/1550] Increase GC thresholds (#2624) Fixes https://github.com/dask/distributed/issues/1653 --- distributed/cli/dask_scheduler.py | 3 +++ distributed/cli/dask_worker.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 0e8415ac132..3b0aa5b4c70 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -3,6 +3,7 @@ import atexit import dask import logging +import gc import os import shutil import sys @@ -137,6 +138,8 @@ def main( tls_key, dashboard_address, ): + g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 + gc.set_threshold(g0 * 3, g1 * 3, g2 * 3) enable_proctitle_on_current() enable_proctitle_on_children() diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 0eb5a7973fb..73cb9970924 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -2,6 +2,7 @@ import atexit import logging +import gc import os from sys import exit import warnings @@ -207,6 +208,9 @@ def main( tls_key, dashboard_address, ): + g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 + gc.set_threshold(g0 * 3, g1 * 3, g2 * 3) + enable_proctitle_on_current() enable_proctitle_on_children() From 0c8918b5e6ff63857f89c76468a5ec2d2aa005ae Mon Sep 17 00:00:00 2001 From: Michael Delgado Date: Wed, 24 Apr 2019 14:26:19 -0700 Subject: [PATCH 0245/1550] Adaptive: recommend close workers when any are idle (#2330) * adaptive: recommend close workers if idle * adaptive: check for idle workers before recommending scale up * adaptive: check for waiting tasks in should_scale_up * revert to changing needs_cpu only * performance bump in adaptive.needs_cpu by looping through workers * switch to checking number of cores, not workers * apply black * remove xfail * flake --- distributed/deploy/adaptive.py | 24 ++++++++++++++++++----- distributed/deploy/tests/test_adaptive.py | 2 -- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 890e30c027f..8c260609638 100644 --- 
a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -135,10 +135,11 @@ def needs_cpu(self): Notes ----- Returns ``True`` if the occupancy per core is some factor larger - than ``startup_cost``. + than ``startup_cost`` and the number of tasks exceeds the number of + cores """ total_occupancy = self.scheduler.total_occupancy - total_cores = sum([ws.ncores for ws in self.scheduler.workers.values()]) + total_cores = self.scheduler.total_ncores if total_occupancy / (total_cores + 1e-9) > self.startup_cost * 2: logger.info( @@ -146,9 +147,22 @@ def needs_cpu(self): total_occupancy, total_cores, ) - return True - else: - return False + + tasks_processing = 0 + + for w in self.scheduler.workers.values(): + tasks_processing += len(w.processing) + + if tasks_processing > total_cores: + logger.info( + "pending tasks exceed number of cores " "[%d tasks / %d cores]", + tasks_processing, + total_cores, + ) + + return True + + return False def needs_memory(self): """ diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 1d8a48bf7fc..50c4f0a45a3 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -2,7 +2,6 @@ from time import sleep -import pytest from toolz import frequencies, pluck from tornado import gen from tornado.ioloop import IOLoop @@ -331,7 +330,6 @@ def test_adapt_down(): yield cluster.close() -@pytest.mark.xfail(reason="we currently only judge occupancy, not ntasks") @gen_test(timeout=30) def test_no_more_workers_than_tasks(): loop = IOLoop.current() From 7461488d6ecf9226870b159314565ea2ca477d28 Mon Sep 17 00:00:00 2001 From: Brett Randall Date: Mon, 29 Apr 2019 23:20:15 +1000 Subject: [PATCH 0246/1550] Updated logging module doc links from docs.python.org/2 to docs.python.org/3. (#2635) Signed-off-by: Brett Randall --- distributed/config.py | 4 ++-- docs/source/configuration.rst | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/config.py b/distributed/config.py index 4b7b589d58f..5c71cf570c8 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -103,7 +103,7 @@ def _initialize_logging_old_style(config): def _initialize_logging_new_style(config): """ Initialize logging using logging's "Configuration dictionary schema". - (ref.: https://docs.python.org/2/library/logging.config.html#logging-config-dictschema) + (ref.: https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema) """ logging.config.dictConfig(config.get("logging")) @@ -111,7 +111,7 @@ def _initialize_logging_new_style(config): def _initialize_logging_file_config(config): """ Initialize logging using logging's "Configuration file format". - (ref.: https://docs.python.org/2/library/logging.config.html#configuration-file-format) + (ref.: https://docs.python.org/3/howto/logging.html#configuring-logging) """ logging.config.fileConfig( config.get("logging-file-config"), disable_existing_loggers=False diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index 86070de96f9..8967255f526 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -156,7 +156,7 @@ for each logger. It also sets default values for several loggers such as ``distributed`` unless explicitly configured. A more extended format is possible following the :mod:`logging` module's -`Configuration dictionary schema `_. +`Configuration dictionary schema `_. 
To enable this extended format, there must be a ``version`` sub-key as mandated by the schema. The extended format does not set any default values. @@ -173,7 +173,7 @@ mandated by the schema. The extended format does not set any default values. As an alternative to the two logging settings formats discussed above, you can specify a logging config file. Its format adheres to the :mod:`logging` module's -`Configuration file format `_. +`Configuration file format `_. .. note:: The configuration options `logging-file-config` and `logging` are mutually exclusive. \ No newline at end of file From f62d6310827c387cb004fec8ee7202c62cf46a69 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 29 Apr 2019 09:39:55 -0500 Subject: [PATCH 0247/1550] bump version to 1.27.1 --- docs/source/changelog.rst | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index affad66a759..34fe078fe18 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,10 +1,26 @@ Changelog ========= +1.27.1 - 2019-04-29 +------------------- + +- Adaptive: recommend close workers when any are idle (:pr:`2330`) `Michael Delgado`_ +- Increase GC thresholds (:pr:`2624`) `Matthew Rocklin`_ +- Add interface= keyword to LocalCluster (:pr:`2629`) `Matthew Rocklin`_ +- Add worker_class argument to LocalCluster (:pr:`2625`) `Matthew Rocklin`_ +- Remove Python 2.7 from testing matrix (:pr:`2631`) `Matthew Rocklin`_ +- Add number of trials to diskutils test (:pr:`2630`) `Matthew Rocklin`_ +- Fix parameter name in LocalCluster docstring (:pr:`2626`) `Loïc Estève`_ +- Integrate stacktrace for low-level profiling (:pr:`2575`) `Peter Andreas Entschev`_ +- Apply Black to standardize code styling (:pr:`2614`) `Matthew Rocklin`_ +- added missing whitespace to start_worker cmd (:pr:`2613`) `condoratberlin`_ +- Updated logging module doc links from docs.python.org/2 to docs.python.org/3. (:pr:`2635`) `Brett Randall`_ + + 1.27.0 - 2019-04-12 ------------------- - Add basic health endpoints to scheduler and worker bokeh. (#2607) `amerkel2`_ + Add basic health endpoints to scheduler and worker bokeh. (:pr:`2607) `amerkel2`_ - Improved description accuracy of --memory-limit option. (:pr:`2601`) `Brett Randall`_ - Check self.dependencies when looking at dependent tasks in memory (:pr:`2606`) `deepthirajagopalan7`_ - Add RabbitMQ SchedulerPlugin example (:pr:`2604`) `Matt Nicolls`_ @@ -986,3 +1002,6 @@ significantly without many new features. .. _`Brian Chu`: https://github.com/bchu .. _`James Bourbeau`: https://github.com/jrbourbeau .. _`amerkel2`: https://github.com/amerkel2 +.. _`Michael Delgado`: https://github.com/delgadom +.. _`Peter Andreas Entschev`: https://github.com/pentschev +.. 
_`condoratberlin`: https://github.com/condoratberlin From e0cf7e7300c9dcf10b7440abb1e3efc6cea3a91a Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 29 Apr 2019 21:33:36 +0200 Subject: [PATCH 0248/1550] Fix deserialization of bytes chunks larger than 64MB (#2637) --- distributed/protocol/serialize.py | 2 +- distributed/protocol/tests/test_serialize.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 3b0a45c8a6f..4ff0fb47a65 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -473,7 +473,7 @@ def _serialize_bytes(obj): @dask_deserialize.register((bytes, bytearray)) def _deserialize_bytes(header, frames): - return frames[0] + return b"".join(frames) ######################### diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index da43021d550..4f72ec9a538 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -192,7 +192,7 @@ def test_empty_loads_deep(): def test_serialize_bytes(): - for x in [1, "abc", np.arange(5)]: + for x in [1, "abc", np.arange(5), b"ab" * int(40e6)]: b = serialize_bytes(x) assert isinstance(b, bytes) y = deserialize_bytes(b) From 8af282651d701569acda55e742f2db741700de9a Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 30 Apr 2019 09:21:41 -0500 Subject: [PATCH 0249/1550] Use proper address in worker -> nanny comms (#2640) When a worker is shutdown explicitly it notifies the nanny that it should also shutdown. Previously the address it used for this assumed tcp in all cases, this changes that to use the same protocol as the worker (which is currently always the same as the nanny's). This allows `retire_workers` to properly work over TLS. 
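In rough terms the fix amounts to building the nanny address from the worker's own listener prefix instead of assuming TCP. The sketch below is a standalone illustration of that pattern only; the helper name and its parameters are made up for this note and are not part of the Worker API (the real code reads the prefix from self.listener and the port from self.service_ports["nanny"]):

    def nanny_address(listener_prefix, ip, nanny_port):
        # Reusing the worker's listener prefix keeps the protocol consistent,
        # so a TLS worker now reaches its nanny over tls:// rather than tcp://.
        return "%s%s:%d" % (listener_prefix, ip, nanny_port)

    assert nanny_address("tcp://", "127.0.0.1", 12345) == "tcp://127.0.0.1:12345"
    assert nanny_address("tls://", "127.0.0.1", 12345) == "tls://127.0.0.1:12345"
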
--- distributed/tests/test_tls_functional.py | 14 ++++++++++++++ distributed/worker.py | 7 ++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 6b71941257c..74a9cf3cbd4 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -10,6 +10,8 @@ from distributed import Nanny, worker_client, Queue from distributed.client import wait +from distributed.metrics import time +from distributed.nanny import Nanny from distributed.utils_test import gen_tls_cluster, inc, double, slowinc, slowadd @@ -157,3 +159,15 @@ def mysum(): future = c.submit(mysum) result = yield future assert result == 30 * 29 + + +@gen_tls_cluster(client=True, Worker=Nanny) +def test_retire_workers(c, s, a, b): + assert set(s.workers) == {a.worker_address, b.worker_address} + yield c.retire_workers(workers=[a.worker_address], close_workers=True) + assert set(s.workers) == {b.worker_address} + + start = time() + while a.status != "closed": + yield gen.sleep(0.01) + assert time() < start + 5 diff --git a/distributed/worker.py b/distributed/worker.py index 9f940f02f93..33010836fdb 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -996,7 +996,12 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): self.status = "closed" if nanny and "nanny" in self.service_ports: - with self.rpc((self.ip, self.service_ports["nanny"])) as r: + nanny_address = "%s%s:%d" % ( + self.listener.prefix, + self.ip, + self.service_ports["nanny"], + ) + with self.rpc(nanny_address) as r: yield r.terminate() if self.batched_stream and not self.batched_stream.comm.closed(): From 38afa51ca58aa2ea721caa95062fdb9024b0450a Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 30 Apr 2019 16:46:15 +0200 Subject: [PATCH 0250/1550] Limit test_spill_by_default memory, reenable it (#2633) --- distributed/tests/test_worker.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 05b61a997f4..388414a4448 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -25,7 +25,7 @@ from distributed.client import wait from distributed.scheduler import Scheduler from distributed.metrics import time -from distributed.worker import Worker, error_message, logger, TOTAL_MEMORY +from distributed.worker import Worker, error_message, logger from distributed.utils import tmpfile, format_bytes from distributed.utils_test import ( inc, @@ -446,11 +446,15 @@ def test_Executor(c, s): yield w._close() -@pytest.mark.skip(reason="Leaks a large amount of memory") -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)], timeout=30) +@gen_cluster( + client=True, + ncores=[("127.0.0.1", 1)], + timeout=30, + worker_kwargs={"memory_limit": 10e6}, +) def test_spill_by_default(c, s, w): da = pytest.importorskip("dask.array") - x = da.ones(int(TOTAL_MEMORY * 0.7), chunks=10000000, dtype="u1") + x = da.ones(int(10e6 * 0.7), chunks=1e6, dtype="u1") y = c.persist(x) yield wait(y) assert len(w.data.slow) # something is on disk From 0d115acada91fcfe4a685c7369a8b9736c9364ff Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 30 Apr 2019 13:11:38 -0500 Subject: [PATCH 0251/1550] Add timeout to Client._reconnect (#2639) Previously if a client lost connection to the scheduler, it would try to reconnect forever. We now use the same timeout as the initial connect. 
On failure a nice message is logged and the client is shutdown. --- distributed/client.py | 20 ++++++++++++++++---- distributed/tests/test_client.py | 13 +++++++++++++ distributed/utils_test.py | 2 +- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 96d20a7ece2..b8bb42c3115 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -954,9 +954,10 @@ def _start(self, timeout=no_default, **kwargs): raise gen.Return(self) @gen.coroutine - def _reconnect(self, timeout=0.1): + def _reconnect(self): with log_errors(): assert self.scheduler_comm.comm.closed() + self.status = "connecting" self.scheduler_comm = None @@ -964,12 +965,23 @@ def _reconnect(self, timeout=0.1): st.cancel() self.futures.clear() - while self.status == "connecting": + timeout = self._timeout + deadline = self.loop.time() + timeout + while timeout > 0 and self.status == "connecting": try: - yield self._ensure_connected() + yield self._ensure_connected(timeout=timeout) break except EnvironmentError: - yield gen.sleep(timeout) + # Wait a bit before retrying + yield gen.sleep(0.1) + timeout = deadline - self.loop.time() + else: + logger.error( + "Failed to reconnect to scheduler after %.2f " + "seconds, closing client", + self._timeout, + ) + yield self._close() @gen.coroutine def _ensure_connected(self, timeout=None): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index a9d4f21fd81..545d3af67a9 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3704,6 +3704,19 @@ def test_reconnect(loop): c.close() +@gen_cluster(client=True, ncores=[], client_kwargs={"timeout": 0.5}) +def test_reconnect_timeout(c, s): + with captured_logger(logging.getLogger("distributed.client")) as logger: + yield s.close() + start = time() + while c.status != "closed": + yield c._update_scheduler_info() + yield gen.sleep(0.05) + assert time() < start + 5, "Timeout waiting for reconnect to fail" + text = logger.getvalue() + assert "Failed to reconnect" in text + + @slow @pytest.mark.skipif( sys.platform.startswith("win"), reason="num_fds not supported on windows" diff --git a/distributed/utils_test.py b/distributed/utils_test.py index a3f76e4c477..fcda695e2ec 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -974,7 +974,7 @@ def coro(): if s.validate: s.validate_state() finally: - if client: + if client and c.status not in ("closing", "closed"): yield c._close(fast=s.status == "closed") yield end_cluster(s, workers) yield gen.with_timeout( From 2d431399ee32aef85f4ccd386623d599fa1ca50c Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 30 Apr 2019 13:36:20 -0500 Subject: [PATCH 0252/1550] Add as_completed methods to docs (#2642) Adds ``as_completed`` methods to the docs, and cleans up the existing docstrings slightly. Also adds a new method ``has_ready`` for checking if there are any completed futures ready for processing. 
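A rough usage sketch of the documented methods follows, assuming a throwaway in-process cluster; ``inc`` is just a stand-in function for this note, not part of the library:

    from distributed import Client, as_completed

    def inc(x):
        return x + 1

    client = Client(processes=False)        # small in-process cluster for illustration
    ac = as_completed(client.map(inc, range(10)))

    while not ac.is_empty():                # futures still computing or not yet consumed
        batch = ac.next_batch(block=True)   # waits until at least one future has finished
        if ac.has_ready():                  # more completed futures are already queued
            print("more results are waiting after this batch")
        for future in batch:
            print(future.result())

    client.close()
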
--- distributed/client.py | 12 ++++++++---- distributed/tests/test_client.py | 5 +++++ docs/source/api.rst | 4 +++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index b8bb42c3115..332a6681508 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4122,7 +4122,7 @@ def _notify(self): self.thread_condition.notify() @gen.coroutine - def track_future(self, future): + def _track_future(self, future): try: yield _wait(future) except CancelledError: @@ -4148,7 +4148,7 @@ def update(self, futures): if not isinstance(f, Future): raise TypeError("Input must be a future, got %s" % f) self.futures[f] += 1 - self.loop.add_callback(self.track_future, f) + self.loop.add_callback(self._track_future, f) def add(self, future): """ Add a future to the collection @@ -4158,9 +4158,13 @@ def add(self, future): self.update((future,)) def is_empty(self): - """Return True if there no waiting futures, False otherwise""" + """Returns True if there no completed or computing futures""" return not self.count() + def has_ready(self): + """Returns True if there are completed futures available.""" + return not self.queue.empty() + def count(self): """ Return the number of futures yet to be returned @@ -4207,7 +4211,7 @@ def __anext__(self): next = __next__ def next_batch(self, block=True): - """ Get next batch of futures from as_completed iterator + """ Get the next batch of completed futures. Parameters ---------- diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 545d3af67a9..d98d9039686 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3590,8 +3590,13 @@ def test_as_completed_batches(c, with_results): def test_as_completed_next_batch(c): futures = c.map(slowinc, range(2), delay=0.1) ac = as_completed(futures) + assert not ac.is_empty() assert ac.next_batch(block=False) == [] assert set(ac.next_batch(block=True)).issubset(futures) + while not ac.is_empty(): + assert set(ac.next_batch(block=True)).issubset(futures) + assert ac.is_empty() + assert not ac.has_ready() @gen_test() diff --git a/docs/source/api.rst b/docs/source/api.rst index 33634bff6c7..47933be06d4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -152,7 +152,9 @@ Future Other ----- -.. autofunction:: as_completed +.. autoclass:: as_completed + :members: + .. autofunction:: distributed.diagnostics.progress .. autofunction:: wait .. autofunction:: fire_and_forget From 6fca31a13525d42ce7b2aa786e3e8a3a8f98cae5 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 1 May 2019 09:59:16 -0400 Subject: [PATCH 0253/1550] Set working worker class for dask-ssh (#2646) Fixes #2645 Suggestions welcome on how to test --- distributed/cli/dask_ssh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index df2b1c6fe94..2d98992d969 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -96,7 +96,7 @@ ) @click.option( "--remote-dask-worker", - default=None, + default="distributed.cli.dask_worker", type=str, help="Worker to run. 
Defaults to distributed.cli.dask_worker", ) From 7b470c4cbedcd1b98d271983a2ac1c5a909e1230 Mon Sep 17 00:00:00 2001 From: plbertrand Date: Wed, 1 May 2019 12:39:46 -0400 Subject: [PATCH 0254/1550] Add last worker into KilledWorker exception to help debug (#2610) Fixes #2549 --- distributed/core.py | 2 +- distributed/scheduler.py | 23 +++++++++++++++++++++-- distributed/tests/test_client.py | 4 +++- distributed/tests/test_scheduler.py | 9 +++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index e074fa68148..bb8a47c8525 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -959,7 +959,7 @@ def clean_exception(exception, traceback, **kwargs): -------- error_message: create and serialize errors into message """ - if isinstance(exception, bytes): + if isinstance(exception, bytes) or isinstance(exception, bytearray): try: exception = protocol.pickle.loads(exception) except Exception: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5e5e2843c2c..9aebcf11b4f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -260,6 +260,20 @@ def __init__( def host(self): return get_address_host(self.address) + def clean(self): + """ Return a version of this object that is appropriate for serialization """ + ws = WorkerState( + address=self.address, + pid=self.pid, + name=self.name, + ncores=self.ncores, + memory_limit=self.memory_limit, + local_directory=self.local_directory, + services=self.services, + ) + ws.processing = {ts.key for ts in self.processing} + return ws + def __repr__(self): return "" % ( self.address, @@ -1872,7 +1886,9 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): ts.suspicious += 1 if ts.suspicious > self.allowed_failures: del recommendations[k] - e = pickle.dumps(KilledWorker(k, address)) + e = pickle.dumps( + KilledWorker(task=k, last_worker=ws.clean()), -1 + ) r = self.transition(k, "erred", exception=e, cause=k) recommendations.update(r) @@ -4827,4 +4843,7 @@ def heartbeat_interval(n): class KilledWorker(Exception): - pass + def __init__(self, task, last_worker): + super(KilledWorker, self).__init__(task, last_worker) + self.task = task + self.last_worker = last_worker diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index d98d9039686..ff94fba7787 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3479,9 +3479,11 @@ def test_get_foo_lost_keys(c, s, u, v, w): @gen_cluster(client=True, Worker=Nanny, check_new_threads=False) def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 1) - with pytest.raises(KilledWorker): + with pytest.raises(KilledWorker) as info: yield f + assert info.value.last_worker.services["nanny"] in {a.port, b.port} + def test_get_processing_sync(c, s, a, b): processing = c.processing() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 02f15e1e1a2..5a68f287a1b 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1,6 +1,7 @@ from __future__ import print_function, division, absolute_import import cloudpickle +import pickle from collections import defaultdict from datetime import timedelta import json @@ -1496,3 +1497,11 @@ def qux(x): yield gen.sleep(0.1) f = c.submit(bar, x, key="y") yield f + + +@gen_cluster() +def test_workerstate_clean(s, a, b): + ws = s.workers[a.address].clean() + assert ws.address == a.address + b = pickle.dumps(ws) + assert len(b) < 1000 From 
1082e3c91ad6576dc8956b86469268b52d4b7938 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 May 2019 11:00:12 -0700 Subject: [PATCH 0255/1550] Explain LocalCluster behavior in Client docstring (#2647) --- distributed/client.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index 332a6681508..24cf9ce5cf5 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -538,6 +538,9 @@ class resembles executors in ``concurrent.futures`` but also allows the scheduler to serve as intermediary. heartbeat_interval: int Time in milliseconds between heartbeats to scheduler + **kwargs: + If you do not pass a scheduler address, Client will create a + ``LocalCluster`` object, passing any extra keyword arguments. Examples -------- @@ -559,9 +562,19 @@ class resembles executors in ``concurrent.futures`` but also allows >>> client.gather(c) # doctest: +SKIP 33 + You can also call Client with no arguments in order to create your own + local cluster. + + >>> client = Client() # makes your own local "cluster" # doctest: +SKIP + + Extra keywords will be passed directly to LocalCluster + + >>> client = Client(processes=False, threads_per_worker=1) # doctest: +SKIP + See Also -------- distributed.scheduler.Scheduler: Internal scheduler + distributed.deploy.local.LocalCluster: """ def __init__( From fade817e361a6102b013f600e6947b8d2f7939df Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 May 2019 16:04:20 -0700 Subject: [PATCH 0256/1550] Add Comm closed bookkeeping (#2648) We now track open Comms at every test in order to identify cases where Comms may leak out and not be closed up. This involves adding weak references and names to many objects, which should hopefully help debugging in the future. 
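The bookkeeping itself is a small pattern: each Comm (and ConnectionPool) registers itself in a class-level ``weakref.WeakSet`` and carries a human-readable name, so tests can enumerate whatever is still open without the registry keeping dead objects alive. A minimal standalone sketch of the idea, not the real distributed.comm.core.Comm:

    import weakref

    class Comm(object):
        _instances = weakref.WeakSet()   # weak references only, so the registry
                                         # never keeps a comm alive by itself

        def __init__(self, name=None):
            Comm._instances.add(self)
            self.name = name             # e.g. "rpc", "Client->Scheduler"
            self._closed = False

        def close(self):
            self._closed = True

        def closed(self):
            return self._closed

    # A test harness can then flag leaked connections roughly like this:
    a, b = Comm("rpc"), Comm("Client->Scheduler")
    a.close()
    leaked = [c for c in Comm._instances if not c.closed()]
    assert leaked == [b]
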
--- distributed/client.py | 11 ++++------- distributed/comm/core.py | 10 +++++++++- distributed/comm/inproc.py | 1 + distributed/comm/tcp.py | 1 + distributed/core.py | 24 +++++++++++++++++++++++- distributed/node.py | 5 +++++ distributed/scheduler.py | 3 +++ distributed/tests/test_scheduler.py | 2 +- distributed/utils_test.py | 29 +++++++++++++++++++++++++++-- distributed/worker.py | 8 ++++++-- 10 files changed, 80 insertions(+), 14 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 24cf9ce5cf5..84728c62fb7 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -703,6 +703,7 @@ def __init__( io_loop=self.loop, serializers=serializers, deserializers=deserializers, + timeout=timeout, ) for ext in extensions: @@ -947,13 +948,7 @@ def _start(self, timeout=no_default, **kwargs): address = self.cluster.scheduler_address if self.scheduler is None: - self.scheduler = rpc( - address, - timeout=timeout, - connection_args=self.connection_args, - serializers=self._serializers, - deserializers=self._deserializers, - ) + self.scheduler = self.rpc(address) self.scheduler_comm = None yield self._ensure_connected(timeout=timeout) @@ -1014,6 +1009,7 @@ def _ensure_connected(self, timeout=None): timeout=timeout, connection_args=self.connection_args, ) + comm.name = "Client->Scheduler" if timeout is not None: yield gen.with_timeout( timedelta(seconds=timeout), self._update_scheduler_info() @@ -1238,6 +1234,7 @@ def _close(self, fast=False): if self._start_arg is None: with ignoring(AttributeError): yield self.cluster._close() + self.rpc.close() self.status = "closed" if _get_global_client() is self: _set_global_client(None) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index b66be0b6dc4..e0b236e7b96 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod, abstractproperty from datetime import timedelta import logging +import weakref import dask from six import with_metaclass @@ -37,6 +38,12 @@ class Comm(with_metaclass(ABCMeta)): depending on the underlying transport's characteristics. """ + _instances = weakref.WeakSet() + + def __init__(self): + self._instances.add(self) + self.name = None + # XXX add set_close_callback()? 
@abstractmethod @@ -116,8 +123,9 @@ def __repr__(self): if self.closed(): return "" % (clsname,) else: - return "<%s local=%s remote=%s>" % ( + return "<%s %s local=%s remote=%s>" % ( clsname, + self.name or "", self.local_address, self.peer_address, ) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 8721a3df8ac..7f267978d51 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -152,6 +152,7 @@ class InProc(Comm): def __init__( self, local_addr, peer_addr, read_q, write_q, write_loop, deserialize=True ): + Comm.__init__(self) self._local_addr = local_addr self._peer_addr = peer_addr self.deserialize = deserialize diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 6d90a7bc9c7..85dbe2ce278 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -152,6 +152,7 @@ class TCP(Comm): _iostream_has_read_into = hasattr(IOStream, "read_into") def __init__(self, stream, local_addr, peer_addr, deserialize=True): + Comm.__init__(self) self._local_addr = local_addr self._peer_addr = peer_addr self.stream = stream diff --git a/distributed/core.py b/distributed/core.py index bb8a47c8525..3cf3f9b5bb2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -593,6 +593,7 @@ def __init__( self.serializers = serializers self.deserializers = deserializers if deserializers is not None else serializers self.connection_args = connection_args + self._created = weakref.WeakSet() rpc.active.add(self) @gen.coroutine @@ -632,6 +633,7 @@ def live_comm(self): deserialize=self.deserialize, connection_args=self.connection_args, ) + comm.name = "rpc" self.comms[comm] = False # mark as taken raise gen.Return(comm) @@ -648,6 +650,9 @@ def _close_comm(comm): for comm in list(self.comms): if comm and not comm.closed(): _close_comm(comm) + for comm in list(self._created): + if comm and not comm.closed(): + _close_comm(comm) self.comms.clear() def __getattr__(self, key): @@ -659,6 +664,7 @@ def send_recv_from_rpc(**kwargs): kwargs["deserializers"] = self.deserializers try: comm = yield self.live_comm() + comm.name = "rpc." + key result = yield send_recv(comm=comm, op=key, **kwargs) except (RPCClosed, CommClosedError) as e: raise e.__class__( @@ -723,10 +729,12 @@ def send_recv_from_rpc(**kwargs): if self.deserializers is not None and kwargs.get("deserializers") is None: kwargs["deserializers"] = self.deserializers comm = yield self.pool.connect(self.addr) + name, comm.name = comm.name, "ConnectionPool." 
+ key try: result = yield send_recv(comm=comm, op=key, **kwargs) finally: self.pool.reuse(self.addr, comm) + comm.name = name raise gen.Return(result) @@ -780,6 +788,8 @@ class ConnectionPool(object): Whether or not to deserialize data by default or pass it through """ + _instances = weakref.WeakSet() + def __init__( self, limit=512, @@ -787,6 +797,8 @@ def __init__( serializers=None, deserializers=None, connection_args=None, + timeout=None, + server=None, ): self.limit = limit # Max number of open comms # Invariant: len(available) == open - active @@ -797,7 +809,11 @@ def __init__( self.serializers = serializers self.deserializers = deserializers if deserializers is not None else serializers self.connection_args = connection_args + self.timeout = timeout self.event = Event() + self.server = weakref.ref(server) if server else None + self._created = weakref.WeakSet() + self._instances.add(self) @property def active(self): @@ -838,10 +854,13 @@ def connect(self, addr, timeout=None): try: comm = yield connect( addr, - timeout=timeout, + timeout=timeout or self.timeout, deserialize=self.deserialize, connection_args=self.connection_args, ) + comm.name = "ConnectionPool" + comm._pool = weakref.ref(self) + self._created.add(comm) except Exception: raise occupied.add(comm) @@ -907,6 +926,9 @@ def close(self): for comm in comms: comm.abort() + for comm in self._created: + IOLoop.current().add_callback(comm.abort) + def coerce_to_address(o): if isinstance(o, (list, tuple)): diff --git a/distributed/node.py b/distributed/node.py index 8a0b8c12195..8134546fa0b 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -19,6 +19,7 @@ def __init__( io_loop=None, serializers=None, deserializers=None, + timeout=None, ): self.io_loop = io_loop or IOLoop.current() self.rpc = ConnectionPool( @@ -27,6 +28,8 @@ def __init__( serializers=serializers, deserializers=deserializers, connection_args=connection_args, + timeout=timeout, + server=self, ) @@ -51,6 +54,7 @@ def __init__( io_loop=None, serializers=None, deserializers=None, + timeout=None, ): Node.__init__( self, @@ -60,6 +64,7 @@ def __init__( io_loop=io_loop, serializers=serializers, deserializers=deserializers, + timeout=timeout, ) Server.__init__( self, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9aebcf11b4f..af87960f6b6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2176,6 +2176,7 @@ def add_client(self, comm, client=None): We listen to all future messages from this Comm. 
""" assert client is not None + comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) self.log_event(["all", client], {"action": "add-client", "client": client}) self.clients[client] = ClientState(client) @@ -2373,6 +2374,7 @@ def handle_worker(self, comm=None, worker=None): -------- Scheduler.handle_client: Equivalent coroutine for clients """ + comm.name = "Scheduler connection to worker" worker_comm = self.stream_comms[worker] worker_comm.start(comm) logger.info("Starting worker compute stream, %s", worker) @@ -2633,6 +2635,7 @@ def send_message(addr): comm = yield connect( addr, deserialize=self.deserialize, connection_args=self.connection_args ) + comm.name = "Scheduler Broadcast" resp = yield send_recv(comm, close=True, serializers=serializers, **msg) raise gen.Return(resp) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 5a68f287a1b..7bb114a4cbc 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -824,7 +824,7 @@ def test_file_descriptors(c, s): yield [n._close() for n in nannies] assert not s.rpc.open - assert not c.rpc.open + assert not c.rpc.active assert not s.stream_comms start = time() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index fcda695e2ec..ba5567b1a80 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -41,6 +41,7 @@ from .client import default_client, _global_clients, Client from .compatibility import PY3, Empty, WINDOWS, PY2 +from .comm import Comm from .comm.utils import offload from .config import initialize_logging from .core import connect, rpc, CommClosedError @@ -637,6 +638,7 @@ def cluster( ws = weakref.WeakSet() reset_config() + Comm._instances.clear() for name, level in logging_levels.items(): logging.getLogger(name).setLevel(level) @@ -761,6 +763,17 @@ def cluster( sleep(0.01) assert time() < start + 1, "Workers still around after one second" + for i in range(5): + if all(c.closed() for c in Comm._instances): + break + else: + sleep(0.1) + else: + L = [c for c in Comm._instances if not c.closed()] + Comm._instances.clear() + print("Unclosed Comms", L) + # raise ValueError("Unclosed Comms", L) + @gen.coroutine def disconnect(addr, timeout=3, rpc_kwargs=None): @@ -845,8 +858,8 @@ def start_cluster( ) for i, ncore in enumerate(ncores) ] - for w in workers: - w.rpc = workers[0].rpc + # for w in workers: + # w.rpc = workers[0].rpc yield [w._start(ncore[0]) for ncore, w in zip(ncores, workers)] @@ -913,6 +926,7 @@ def _(func): def test_func(): del _global_workers[:] _global_clients.clear() + Comm._instances.clear() active_threads_start = set(threading._active) reset_config() @@ -988,6 +1002,17 @@ def coro(): else: yield c._close(fast=True) + for i in range(5): + if all(c.closed() for c in Comm._instances): + break + else: + yield gen.sleep(0.05) + else: + L = [c for c in Comm._instances if not c.closed()] + Comm._instances.clear() + # raise ValueError("Unclosed Comms", L) + print("Unclosed Comms", L) + raise gen.Return(result) result = loop.run_sync( diff --git a/distributed/worker.py b/distributed/worker.py index 33010836fdb..74ad395cc4d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -674,6 +674,8 @@ def _register_with_scheduler(self): comm = yield connect( self.scheduler.address, connection_args=self.connection_args ) + comm.name = "Worker->Scheduler" + comm._server = weakref.ref(self) yield comm.write( dict( op="register-worker", @@ -993,8 +995,6 @@ def _close(self, report=True, 
timeout=10, nanny=True, executor_wait=True): for k, v in self.services.items(): v.stop() - self.status = "closed" - if nanny and "nanny" in self.service_ports: nanny_address = "%s%s:%d" % ( self.listener.prefix, @@ -1013,6 +1013,8 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): self.rpc.close() self._closed.set() self._remove_from_global_workers() + + self.status = "closed" yield self.close() setproctitle("dask-worker [closed]") @@ -1051,6 +1053,7 @@ def batched_send_connect(): comm = yield connect( address, connection_args=self.connection_args # TODO, serialization ) + comm.name = "Worker->Worker" yield comm.write({"op": "connection_stream"}) bcomm.start(comm) @@ -2952,6 +2955,7 @@ def get_data_from_worker( deserializers = rpc.deserializers comm = yield rpc.connect(worker) + comm.name = "Ephemeral Worker->Worker for gather" try: response = yield send_recv( comm, From 2783024663c136318c3eca359951ac0cd2e2529e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 May 2019 17:05:22 -0700 Subject: [PATCH 0257/1550] Rename Worker._close to Worker.close (#2650) --- distributed/deploy/local.py | 2 +- distributed/diagnostics/tests/test_plugin.py | 6 ++-- .../diagnostics/tests/test_progress.py | 4 +-- .../diagnostics/tests/test_progressbar.py | 2 +- distributed/nanny.py | 17 ++++++---- distributed/protocol/core.py | 20 ++--------- distributed/protocol/serialize.py | 17 +++++----- distributed/protocol/utils.py | 13 +++++++ distributed/tests/test_client.py | 34 +++++++++---------- distributed/tests/test_nanny.py | 24 ++++++------- distributed/tests/test_scheduler.py | 24 ++++++------- distributed/tests/test_stress.py | 4 +-- distributed/tests/test_worker.py | 24 ++++++------- distributed/utils_test.py | 10 +++--- distributed/worker.py | 16 +++++---- 15 files changed, 111 insertions(+), 106 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 73ddde8bdc1..68a47c85d48 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -338,7 +338,7 @@ def start_worker(self, **kwargs): @gen.coroutine def _stop_worker(self, w): - yield w._close() + yield w.close() if w in self.workers: self.workers.remove(w) diff --git a/distributed/diagnostics/tests/test_plugin.py b/distributed/diagnostics/tests/test_plugin.py index b1d5406e052..fa4449c74b7 100644 --- a/distributed/diagnostics/tests/test_plugin.py +++ b/distributed/diagnostics/tests/test_plugin.py @@ -55,8 +55,8 @@ def remove_worker(self, worker, scheduler): b = Worker(s.address) yield a yield b - yield a._close() - yield b._close() + yield a.close() + yield b.close() assert events == [ ("add_worker", a.address), @@ -68,5 +68,5 @@ def remove_worker(self, worker, scheduler): events[:] = [] s.remove_plugin(plugin) a = yield Worker(s.address) - yield a._close() + yield a.close() assert events == [] diff --git a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py index d8435cc7ff0..097b2670247 100644 --- a/distributed/diagnostics/tests/test_progress.py +++ b/distributed/diagnostics/tests/test_progress.py @@ -185,8 +185,8 @@ def test_AllProgress_lost_key(c, s, a, b, timeout=None): yield wait(futures) assert len(p.state["memory"]["inc"]) == 5 - yield a._close() - yield b._close() + yield a.close() + yield b.close() start = time() while len(p.state["memory"]["inc"]) > 0: diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index 8738cb60e22..d5a01410f5e 100644 --- 
a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -63,7 +63,7 @@ def f(): assert progress.status == "finished" check_bar_completed(capsys) - yield [a._close(), b._close()] + yield [a.close(), b.close()] s.close() yield done diff --git a/distributed/nanny.py b/distributed/nanny.py index 356ebc3168d..4b81bec4646 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -8,6 +8,7 @@ import shutil import threading import uuid +import warnings import dask from tornado import gen @@ -122,7 +123,7 @@ def __init__( "kill": self.kill, "restart": self.restart, # cannot call it 'close' on the rpc side for naming conflict - "terminate": self._close, + "terminate": self.close, "run": self.run, } @@ -197,7 +198,7 @@ def _start(self, addr_or_port=0): assert self.worker_address self.status = "running" else: - yield self._close() + yield self.close() self.start_periodic_callbacks() @@ -275,7 +276,7 @@ def instantiate(self, comm=None): timedelta(seconds=self.death_timeout), self.process.start() ) except gen.TimeoutError: - yield self._close(timeout=self.death_timeout) + yield self.close(timeout=self.death_timeout) raise gen.Return("timed out") else: result = yield self.process.start() @@ -332,7 +333,7 @@ def _on_exit(self, exitcode): yield self.scheduler.unregister(address=self.worker_address) except (EnvironmentError, CommClosedError): if not self.reconnect: - yield self._close() + yield self.close() return try: @@ -349,8 +350,12 @@ def _on_exit(self, exitcode): def pid(self): return self.process and self.process.pid + def _close(self, *args, **kwargs): + warnings.warn("Worker._close has moved to Worker.close") + return self.close(*args, **kwargs) + @gen.coroutine - def _close(self, comm=None, timeout=5, report=None): + def close(self, comm=None, timeout=5, report=None): """ Close the worker process, stop all comms. 
""" @@ -584,7 +589,7 @@ def _run( @gen.coroutine def do_stop(timeout=5, executor_wait=True): try: - yield worker._close( + yield worker.close( report=False, nanny=False, executor_wait=executor_wait, diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 0b5f7eb0fea..c1b62b2491e 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -11,29 +11,13 @@ from toolz import reduce from .compression import compressions, maybe_compress, decompress -from .serialize import ( - serialize, - deserialize, - Serialize, - Serialized, - extract_serialize, - msgpack_len_opts, -) -from .utils import frame_split_size, merge_frames +from .serialize import serialize, deserialize, Serialize, Serialized, extract_serialize +from .utils import frame_split_size, merge_frames, msgpack_opts from ..utils import nbytes _deserialize = deserialize -try: - msgpack.loads(msgpack.dumps(""), raw=False, **msgpack_len_opts) - msgpack_opts = {"raw": False} - msgpack_opts.update(msgpack_len_opts) -except TypeError: - # Backward compat with old msgpack (prior to 0.5.2) - msgpack_opts = {"encoding": "utf-8"} - - logger = logging.getLogger(__name__) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 4ff0fb47a65..f47ea7388af 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -16,7 +16,13 @@ from ..compatibility import PY2 from ..utils import has_keyword from .compression import maybe_compress, decompress -from .utils import unpack_frames, pack_frames_prelude, frame_split_size, ensure_bytes +from .utils import ( + unpack_frames, + pack_frames_prelude, + frame_split_size, + ensure_bytes, + msgpack_opts, +) lazy_registrations = {} @@ -58,11 +64,6 @@ def pickle_loads(header, frames): return pickle.loads(b"".join(frames)) -msgpack_len_opts = { - ("max_%s_len" % x): 2 ** 31 - 1 for x in ["str", "bin", "array", "map", "ext"] -} - - def msgpack_dumps(x): try: frame = msgpack.dumps(x, use_bin_type=True) @@ -73,9 +74,7 @@ def msgpack_dumps(x): def msgpack_loads(header, frames): - return msgpack.loads( - b"".join(frames), encoding="utf8", use_list=False, **msgpack_len_opts - ) + return msgpack.loads(b"".join(frames), use_list=False, **msgpack_opts) def serialization_error_loads(header, frames): diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 90d30342951..208caebb926 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -1,12 +1,25 @@ from __future__ import print_function, division, absolute_import import struct +import msgpack from ..utils import ensure_bytes, nbytes BIG_BYTES_SHARD_SIZE = 2 ** 26 +msgpack_opts = { + ("max_%s_len" % x): 2 ** 31 - 1 for x in ["str", "bin", "array", "map", "ext"] +} + +try: + msgpack.loads(msgpack.dumps(""), raw=False, **msgpack_opts) + msgpack_opts["raw"] = False +except TypeError: + # Backward compat with old msgpack (prior to 0.5.2) + msgpack_opts["encoding"] = "utf-8" + + def frame_split_size(frames, n=BIG_BYTES_SHARD_SIZE): """ Split a list of frames into a list of frames of maximum size diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index ff94fba7787..6f1fb5c8a03 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -528,7 +528,7 @@ def test_gather_lost(c, s, a, b): [x] = yield c.scatter([1], workers=a.address) y = c.submit(inc, 1, workers=b.address) - yield a._close() + yield a.close() with pytest.raises(Exception): res = yield c.gather([x, y]) @@ -641,7 
+641,7 @@ def g(a, b): with pytest.raises(AttributeError): yield c.gather(future_g) - yield a._close() + yield a.close() @gen_cluster(client=True) @@ -946,7 +946,7 @@ def test_remove_worker(c, s, a, b): L = c.map(inc, range(20)) yield wait(L) - yield b._close() + yield b.close() assert b.address not in s.workers @@ -2845,7 +2845,7 @@ def test_worker_aliases(): assert result == i + 1 yield c.close() - yield [a._close(), b._close(), w._close()] + yield [a.close(), b.close(), w.close()] yield s.close() @@ -3020,7 +3020,7 @@ def test_rebalance_unprepared(c, s, a, b): def test_receive_lost_key(c, s, a, b): x = c.submit(inc, 1, workers=[a.address]) result = yield x - yield a._close() + yield a.close() start = time() while x.status == "finished": @@ -3036,7 +3036,7 @@ def test_unrunnable_task_runs(c, s, a, b): x = c.submit(inc, 1, workers=[a.ip]) result = yield x - yield a._close() + yield a.close() start = time() while x.status == "finished": assert time() < start + 5 @@ -3055,7 +3055,7 @@ def test_unrunnable_task_runs(c, s, a, b): assert s.tasks[x.key] not in s.unrunnable result = yield x assert result == 2 - yield w._close() + yield w.close() @gen_cluster(client=True, ncores=[]) @@ -3067,7 +3067,7 @@ def test_add_worker_after_tasks(c, s): result = yield c.gather(futures) - yield n._close() + yield n.close() @pytest.mark.skipif( @@ -3460,8 +3460,8 @@ def test_get_foo_lost_keys(c, s, u, v, w): d = yield c.scheduler.who_has(keys=[x.key, y.key]) assert_dict_key_equal(d, {x.key: [ua], y.key: [va]}) - yield u._close() - yield v._close() + yield u.close() + yield v.close() d = yield c.scheduler.has_what() assert_dict_key_equal(d, {wa: []}) @@ -3707,7 +3707,7 @@ def test_reconnect(loop): assert time() < start + 5 sleep(0.1) - sync(loop, w._close) + sync(loop, w.close) c.close() @@ -3753,7 +3753,7 @@ def start_worker(sleep, duration, repeat=1): addr = w.worker_address running[w] = addr yield gen.sleep(duration) - yield w._close() + yield w.close() del w yield gen.moment done.release() @@ -3882,7 +3882,7 @@ def f(): def test_lose_scattered_data(c, s, a, b): [x] = yield c.scatter([1], workers=a.address) - yield a._close() + yield a.close() yield gen.sleep(0.1) assert x.status == "cancelled" @@ -3894,7 +3894,7 @@ def test_partially_lose_scattered_data(e, s, a, b, c): [x] = yield e.scatter([1], workers=a.address) yield e.replicate(x, n=2) - yield a._close() + yield a.close() yield gen.sleep(0.1) assert x.status == "finished" @@ -3909,7 +3909,7 @@ def test_scatter_compute_lose(c, s, a, b): z = c.submit(slowadd, x, y, delay=0.2) yield gen.sleep(0.1) - yield a._close() + yield a.close() with pytest.raises(CancelledError): yield wait(z) @@ -3935,7 +3935,7 @@ def test_scatter_compute_store_lose(c, s, a, b): z = c.submit(slowadd, xx, y, delay=0.2, workers=b.address) yield wait(z) - yield a._close() + yield a.close() start = time() while x.status == "finished": @@ -3980,7 +3980,7 @@ def test_scatter_compute_store_lose_processing(c, s, a, b): y = c.submit(slowinc, x, delay=0.2) z = c.submit(inc, y) yield gen.sleep(0.1) - yield a._close() + yield a.close() start = time() while x.status == "finished": diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 932419015f3..08cd49fb3c9 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -48,7 +48,7 @@ def test_nanny(s): yield nn.terminate() assert not n.is_alive() - yield n._close() + yield n.close() @gen_cluster(ncores=[]) @@ -57,7 +57,7 @@ def test_many_kills(s): assert n.is_alive() yield [n.kill() for i in 
range(5)] yield [n.kill() for i in range(5)] - yield n._close() + yield n.close() @gen_cluster(Worker=Nanny) @@ -102,7 +102,7 @@ def test_nanny_process_failure(c, s): second_dir = n.worker_dir - yield n._close() + yield n.close() assert not os.path.exists(second_dir) assert not os.path.exists(first_dir) assert first_dir != n.worker_dir @@ -124,7 +124,7 @@ def test_run(s): assert response["status"] == "OK" assert response["result"] == 1 - yield n._close() + yield n.close() @slow @@ -194,7 +194,7 @@ def test_num_fds(s): # Warm up w = yield Nanny(s.address) - yield w._close() + yield w.close() del w gc.collect() @@ -203,7 +203,7 @@ def test_num_fds(s): for i in range(3): w = yield Nanny(s.address) yield gen.sleep(0.1) - yield w._close() + yield w.close() start = time() while proc.num_fds() > before: @@ -226,7 +226,7 @@ def func(dask_worker): result = yield c.run(func) assert host in first(result.values()) - yield n._close() + yield n.close() @gen_test() @@ -236,7 +236,7 @@ def test_scheduler_file(): s.start(8008) w = yield Nanny(scheduler_file=fn) assert set(s.workers) == {w.worker_address} - yield w._close() + yield w.close() s.stop() @@ -301,7 +301,7 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): yield c.submit(inc, 2) # worker doesn't pause - yield nanny._close() + yield nanny.close() @gen_cluster(ncores=[], client=True) @@ -315,7 +315,7 @@ def test_scheduler_address_config(c, s): yield gen.sleep(0.1) assert time() < start + 10 - yield nanny._close() + yield nanny.close() @slow @@ -338,7 +338,7 @@ def test_environment_variable(c, s): yield [a, b] results = yield c.run(lambda: os.environ["FOO"]) assert results == {a.worker_address: "123", b.worker_address: "456"} - yield [a._close(), b._close()] + yield [a.close(), b.close()] @gen_cluster(ncores=[], client=True) @@ -346,4 +346,4 @@ def test_data_types(c, s): w = yield Nanny(s.address, data=dict) r = yield c.run(lambda dask_worker: type(dask_worker.data)) assert r[w.worker_address] == dict - yield w._close() + yield w.close() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7bb114a4cbc..8280c3cd120 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -275,7 +275,7 @@ def test_add_worker(s, a, b): assert w.ip in s.host_info assert s.host_info[w.ip]["addresses"] == {a.address, b.address, w.address} - yield w._close() + yield w.close() @gen_cluster(scheduler_kwargs={"blocked_handlers": ["feed"]}) @@ -541,10 +541,10 @@ def test_worker_name(): with pytest.raises(ValueError): w2 = yield Worker(s.ip, s.port, name="alice") - yield w2._close() + yield w2.close() yield s.close() - yield w._close() + yield w.close() @gen_test() @@ -585,7 +585,7 @@ def test_coerce_address(): assert s.coerce_address("zzzt:8000", resolve=False) == "tcp://zzzt:8000" yield s.close() - yield [w._close() for w in [a, b, c]] + yield [w.close() for w in [a, b, c]] @pytest.mark.skipif( @@ -598,7 +598,7 @@ def test_file_descriptors_dont_leak(s): before = proc.num_fds() w = yield Worker(s.ip, s.port) - yield w._close() + yield w.close() during = proc.num_fds() @@ -668,7 +668,7 @@ def test_scatter_no_workers(c, s): yield [c.scatter(data={"y": 2}, timeout=5), w._start()] assert w.data["y"] == 2 - yield w._close() + yield w.close() @gen_cluster(ncores=[]) @@ -676,7 +676,7 @@ def test_scheduler_sees_memory_limits(s): w = yield Worker(s.ip, s.port, ncores=3, memory_limit=12345) assert s.workers[w.address].memory_limit == 12345 - yield w._close() + yield w.close() @gen_cluster(client=True, 
timeout=1000) @@ -821,7 +821,7 @@ def test_file_descriptors(c, s): num_fds_6 = proc.num_fds() assert num_fds_6 < num_fds_5 + N - yield [n._close() for n in nannies] + yield [n.close() for n in nannies] assert not s.rpc.open assert not c.rpc.active @@ -945,7 +945,7 @@ def test_worker_arrives_with_processing_data(c, s, a, b): z.key: "processing", } - yield w._close() + yield w.close() @gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) @@ -996,7 +996,7 @@ def test_no_workers_to_memory(c, s): z.key: "processing", } - yield w._close() + yield w.close() @gen_cluster(client=True) @@ -1025,7 +1025,7 @@ def test_no_worker_to_memory_restrictions(c, s, a, b): z.key: "processing", } - yield w._close() + yield w.close() def test_run_on_scheduler_sync(loop): @@ -1333,7 +1333,7 @@ def test_mising_data_errant_worker(c, s, w1, w2, w3): y = c.submit(len, x, workers=w3.address) while not w3.tasks: yield gen.sleep(0.001) - w1._close() + w1.close() yield wait(y) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 8a36b8b3b94..f145a11b053 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -114,7 +114,7 @@ def create_and_destroy_worker(delay): yield gen.sleep(delay) - yield n._close() + yield n.close() print("Killed nanny") yield gen.with_timeout( @@ -167,7 +167,7 @@ def test_stress_scatter_death(c, s, *workers): else: raise w = random.choice(alive) - yield w._close() + yield w.close() alive.remove(w) try: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 388414a4448..b51de0b7c0e 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -193,8 +193,8 @@ def g(): result = yield future assert result == 123 - yield a._close() - yield b._close() + yield a.close() + yield b.close() aa.close_rpc() bb.close_rpc() assert not os.path.exists(os.path.join(a.local_dir, "foobar.py")) @@ -997,7 +997,7 @@ def test_start_services(s): yield w._start() assert w.services["bokeh"].server.port == 1234 - yield w._close() + yield w.close() @gen_test() @@ -1007,7 +1007,7 @@ def test_scheduler_file(): s.start(8009) w = yield Worker(scheduler_file=fn) assert set(s.workers) == {w.address} - yield w._close() + yield w.close() s.stop() @@ -1186,7 +1186,7 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): yield c.submit(inc, 2) # worker doesn't pause - yield worker._close() + yield worker.close() @gen_cluster( @@ -1225,7 +1225,7 @@ def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): worker = yield Worker(loop=s.loop) assert worker.scheduler.address == s.address - yield worker._close() + yield worker.close() @slow @@ -1321,7 +1321,7 @@ def test_startup2(): worker = yield Worker(s.address, loop=s.loop) result = yield c.run(test_import, workers=[worker.address]) assert list(result.values()) == [False] - yield worker._close() + yield worker.close() # Add a preload function response = yield c.register_worker_callbacks(setup=mystartup) @@ -1336,7 +1336,7 @@ def test_startup2(): worker = yield Worker(s.address, loop=s.loop) result = yield c.run(test_import, workers=[worker.address]) assert list(result.values()) == [True] - yield worker._close() + yield worker.close() # Register another preload function response = yield c.register_worker_callbacks(setup=mystartup2) @@ -1353,7 +1353,7 @@ def test_startup2(): assert list(result.values()) == [True] result = yield c.run(test_startup2, workers=[worker.address]) assert list(result.values()) == [True] - yield 
worker._close() + yield worker.close() # Final exception test with pytest.raises(ZeroDivisionError): @@ -1364,12 +1364,12 @@ def test_startup2(): def test_data_types(s): w = yield Worker(s.address, data=dict) assert isinstance(w.data, dict) - yield w._close() + yield w.close() data = dict() w = yield Worker(s.address, data=data) assert w.data is data - yield w._close() + yield w.close() class Data(dict): def __init__(self, x, y): @@ -1379,4 +1379,4 @@ def __init__(self, x, y): w = yield Worker(s.address, data=(Data, {"x": 123, "y": 456})) assert w.data.x == 123 assert w.data.y == 456 - yield w._close() + yield w.close() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index ba5567b1a80..b0c0d2d48cc 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -112,7 +112,7 @@ def invalid_python_script(tmpdir_factory): def cleanup_global_workers(): for w in _global_workers: w = w() - w._close(report=False, executor_wait=False) + w.close(report=False, executor_wait=False) @pytest.fixture @@ -526,7 +526,7 @@ def run_nanny(q, scheduler_q, **kwargs): try: loop.start() finally: - loop.run_sync(worker._close) + loop.run_sync(worker.close) loop.close(all_fds=True) @@ -869,7 +869,7 @@ def start_cluster( ): yield gen.sleep(0.01) if time() - start > 5: - yield [w._close(timeout=1) for w in workers] + yield [w.close(timeout=1) for w in workers] yield s.close(fast=True) raise Exception("Cluster creation timeout") raise gen.Return((s, workers)) @@ -882,7 +882,7 @@ def end_cluster(s, workers): @gen.coroutine def end_worker(w): with ignoring(TimeoutError, CommClosedError, EnvironmentError): - yield w._close(report=False) + yield w.close(report=False) yield [end_worker(w) for w in workers] yield s.close() # wait until scheduler stops completely @@ -1031,7 +1031,7 @@ def coro(): DequeHandler.clear_all_instances() for w in _global_workers: w = w() - w._close(report=False, executor_wait=False) + w.close(report=False, executor_wait=False) if w.status == "running": w.close() del _global_workers[:] diff --git a/distributed/worker.py b/distributed/worker.py index 74ad395cc4d..8d81e2781a0 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -536,7 +536,7 @@ def __init__( } stream_handlers = { - "close": self._close, + "close": self.close, "compute-task": self.add_task, "release-task": partial(self.release_key, report=False), "delete-data": self.delete_data, @@ -665,7 +665,7 @@ def _register_with_scheduler(self): logger.info("-" * 49) while True: if self.death_timeout and time() > start + self.death_timeout: - yield self._close(timeout=1) + yield self.close(timeout=1) return if self.status in ("closed", "closing"): raise gen.Return @@ -775,7 +775,7 @@ def handle_scheduler(self, comm): logger.info("Connection to scheduler broken. 
Reconnecting...") self.loop.add_callback(self._register_with_scheduler) else: - yield self._close(report=False) + yield self.close(report=False) def start_ipython(self, comm): """Start an IPython kernel @@ -958,8 +958,12 @@ def __await__(self): def start(self, port=0): self.loop.add_callback(self._start, port) + def _close(self, *args, **kwargs): + warnings.warn("Worker._close has moved to Worker.close") + return self.close(*args, **kwargs) + @gen.coroutine - def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): + def close(self, report=True, timeout=10, nanny=True, executor_wait=True): with log_errors(): if self.status in ("closed", "closing"): return @@ -1015,7 +1019,7 @@ def _close(self, report=True, timeout=10, nanny=True, executor_wait=True): self._remove_from_global_workers() self.status = "closed" - yield self.close() + yield ServerNode.close(self) setproctitle("dask-worker [closed]") @@ -1031,7 +1035,7 @@ def _remove_from_global_workers(self): @gen.coroutine def terminate(self, comm, report=True): - yield self._close(report=report) + yield self.close(report=report) raise Return("OK") @gen.coroutine From 5c43091c2f61d25652d681898b7aade96e9ee811 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 May 2019 19:28:07 -0700 Subject: [PATCH 0258/1550] Use an LRU cache for deserialized functions (#2623) Fixes https://github.com/dask/distributed/issues/2621 --- distributed/worker.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 8d81e2781a0..4cf0585b1de 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3021,7 +3021,13 @@ def execute_task(task): return task -cache = dict() +try: + # a 10 MB cache of deserialized functions and their bytes + from zict import LRU + + cache = LRU(10000000, dict(), weight=lambda k, v: len(v)) +except ImportError: + cache = dict() def dumps_function(func): From 528c59b2b660e218c116823ad30671846cfa2530 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 4 May 2019 08:33:50 -0700 Subject: [PATCH 0259/1550] Avoid deprecation warnings (#2653) --- distributed/client.py | 3 ++- distributed/deploy/local.py | 4 +++- distributed/tests/py3_test_asyncio.py | 6 +++--- distributed/tests/py3_test_pubsub.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 84728c62fb7..fef8b12f1fc 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2391,7 +2391,8 @@ def run_coroutine(self, function, *args, **kwargs): warnings.warn( "This method has been deprecated. 
" "Instead use Client.run which detects async functions " - "automatically" + "automatically", + stacklevel=2, ) return self.run(function, *args, **kwargs) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 68a47c85d48..4cf67af150e 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -290,7 +290,9 @@ def _start(self, ip=None, n_workers=0): @gen.coroutine def _start_worker(self, death_timeout=60, **kwargs): if self.status and self.status.startswith("clos"): - warnings.warn("Tried to start a worker while status=='%s'" % self.status) + warnings.warn( + "Tried to start a worker while status=='%s'" % self.status, stacklevel=2 + ) return if self.processes: diff --git a/distributed/tests/py3_test_asyncio.py b/distributed/tests/py3_test_asyncio.py index 90e20268617..3754b282813 100644 --- a/distributed/tests/py3_test_asyncio.py +++ b/distributed/tests/py3_test_asyncio.py @@ -299,15 +299,15 @@ async def aiothrows(x, delay=0.02): raise RuntimeError("hello") async with AioClient(processes=False) as c: - results = await c.run_coroutine(aioinc, 1, delay=0.05) + results = await c.run(aioinc, 1, delay=0.05) assert len(results) > 0 assert [value == 2 for value in results.values()] - results = await c.run_coroutine(aioinc, 1, workers=[]) + results = await c.run(aioinc, 1, workers=[]) assert results == {} with pytest.raises(RuntimeError) as exc_info: - await c.run_coroutine(aiothrows, 1) + await c.run(aiothrows, 1) assert "hello" in str(exc_info) diff --git a/distributed/tests/py3_test_pubsub.py b/distributed/tests/py3_test_pubsub.py index 172c8734819..0cedbb3bd31 100644 --- a/distributed/tests/py3_test_pubsub.py +++ b/distributed/tests/py3_test_pubsub.py @@ -22,7 +22,7 @@ def f(_): sub = Sub("a") return list(toolz.take(5, sub)) - c.run_coroutine(publish, workers=[a.address]) + c.run(publish, workers=[a.address]) tasks = [c.submit(f, i) for i in range(4)] results = yield c.gather(tasks) From ddaf73bea0cb18aedd2025b28db6247cc984aaba Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 4 May 2019 10:00:54 -0700 Subject: [PATCH 0260/1550] Add idle timeout to scheduler (#2652) Schedulers that haven't been touched in a while can choose to shut themselves down. This is useful as a stop-gap to clean up costly forgotten resources. 
* Avoid allocating works unnecessarily in LocalCluster tests --- distributed/deploy/tests/test_local.py | 27 ++++++++++++++++++------- distributed/distributed.yaml | 1 + distributed/scheduler.py | 28 ++++++++++++++++++++++++++ distributed/tests/test_scheduler.py | 16 +++++++++++++++ 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index ab378ba0e8f..6e1e71e83b2 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -52,7 +52,7 @@ def test_simple(loop): def test_local_cluster_supports_blocked_handlers(loop): - with LocalCluster(blocked_handlers=["run_function"], loop=loop) as c: + with LocalCluster(blocked_handlers=["run_function"], n_workers=0, loop=loop) as c: with Client(c) as client: with pytest.raises(ValueError) as exc: client.run_on_scheduler(lambda x: x, 42) @@ -309,11 +309,11 @@ def test_cleanup(): def test_repeated(): with LocalCluster( - scheduler_port=8448, silence_logs=False, dashboard_address=None + 0, scheduler_port=8448, silence_logs=False, dashboard_address=None ) as c: pass with LocalCluster( - scheduler_port=8448, silence_logs=False, dashboard_address=None + 0, scheduler_port=8448, silence_logs=False, dashboard_address=None ) as c: pass @@ -323,6 +323,7 @@ def test_bokeh(loop, processes): pytest.importorskip("bokeh") requests = pytest.importorskip("requests") with LocalCluster( + n_workers=0, scheduler_port=0, silence_logs=False, loop=loop, @@ -405,14 +406,19 @@ def test_silent_startup(): def test_only_local_access(loop): with LocalCluster( - scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop + 0, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop ) as c: sync(loop, assert_can_connect_locally_4, c.scheduler.port) def test_remote_access(loop): with LocalCluster( - scheduler_port=0, silence_logs=False, dashboard_address=None, ip="", loop=loop + 0, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ip="", + loop=loop, ) as c: sync(loop, assert_can_connect_from_everywhere_4_6, c.scheduler.port) @@ -463,6 +469,7 @@ def test_death_timeout_raises(loop): def test_bokeh_kwargs(loop): pytest.importorskip("bokeh") with LocalCluster( + n_workers=0, scheduler_port=0, silence_logs=False, loop=loop, @@ -496,6 +503,7 @@ def test_logging(): def test_ipywidgets(loop): ipywidgets = pytest.importorskip("ipywidgets") with LocalCluster( + n_workers=0, scheduler_port=0, silence_logs=False, loop=loop, @@ -607,6 +615,7 @@ def test_local_tls(loop): security = tls_only_security() with LocalCluster( + n_workers=0, scheduler_port=8786, silence_logs=False, security=security, @@ -730,7 +739,9 @@ def test_protocol_inproc(loop): def test_protocol_tcp(loop): - with LocalCluster(protocol="tcp", loop=loop, processes=False) as cluster: + with LocalCluster( + protocol="tcp", loop=loop, n_workers=0, processes=False + ) as cluster: assert cluster.scheduler.address.startswith("tcp://") @@ -738,7 +749,9 @@ def test_protocol_tcp(loop): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) def test_protocol_ip(loop): - with LocalCluster(ip="tcp://127.0.0.2", loop=loop, processes=False) as cluster: + with LocalCluster( + ip="tcp://127.0.0.2", loop=loop, n_workers=0, processes=False + ) as cluster: assert cluster.scheduler.address.startswith("tcp://127.0.0.2") diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index d625a103fe8..3ae9b7ee690 100644 --- 
a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -16,6 +16,7 @@ distributed: # Number of seconds to wait until workers or clients are removed from the events log # after they have been removed from the scheduler events-cleanup-delay: 1h + idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes" transition-log-length: 100000 work-stealing: True # workers should steal tasks from each other worker-ttl: null # like '60s'. Time to live for workers. They must heartbeat faster than this diff --git a/distributed/scheduler.py b/distributed/scheduler.py index af87960f6b6..f9cbbb8c783 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -816,6 +816,7 @@ def __init__( scheduler_file=None, security=None, worker_ttl=None, + idle_timeout=None, **kwargs ): @@ -836,6 +837,14 @@ def __init__( self.scheduler_file = scheduler_file worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None + idle_timeout = idle_timeout or dask.config.get( + "distributed.scheduler.idle-timeout" + ) + if idle_timeout: + self.idle_timeout = parse_timedelta(idle_timeout) + else: + self.idle_timeout = None + self.time_started = time() self.security = security or Security() assert isinstance(self.security, Security) @@ -1054,6 +1063,10 @@ def __init__( pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl, io_loop=loop) self.periodic_callbacks["worker-ttl"] = pc + if self.idle_timeout: + pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4, io_loop=loop) + self.periodic_callbacks["idle-timeout"] = pc + if extensions is None: extensions = DEFAULT_EXTENSIONS for ext in extensions: @@ -4651,6 +4664,21 @@ def check_worker_ttl(self): ) self.remove_worker(address=ws.address) + def check_idle(self): + if any(ws.processing for ws in self.workers.values()): + return + if self.unrunnable: + return + + if not self.transition_log: + close = time() > self.time_started + self.idle_timeout + else: + last_task = self.transition_log[-1][-1] + close = time() > last_task + self.idle_timeout + + if close: + self.loop.add_callback(self.close) + def decide_worker(ts, all_workers, valid_workers, objective): """ diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 8280c3cd120..f4c13bfc852 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1499,6 +1499,22 @@ def qux(x): yield f +@gen_cluster(client=True, config={"distributed.scheduler.idle-timeout": "200ms"}) +def test_idle_timeout(c, s, a, b): + future = c.submit(slowinc, 1) + yield future + + assert s.status != "closed" + + start = time() + while s.status != "closed": + yield gen.sleep(0.01) + assert time() < start + 3 + + assert a.status == "closed" + assert b.status == "closed" + + @gen_cluster() def test_workerstate_clean(s, a, b): ws = s.workers[a.address].clean() From d42173be416326aac4a48e73d4d169674a41c6e8 Mon Sep 17 00:00:00 2001 From: Brett Randall Date: Sun, 5 May 2019 23:18:05 +1000 Subject: [PATCH 0261/1550] Fixed comment regarding keeping existing level if less verbose (#2655) This behaviour changed in commit 6cc529979 . 
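The corrected docstring describes what the function actually does: it now sets StreamHandler levels unconditionally instead of keeping an already-less-verbose level. A short usage sketch, with illustrative level and logger names:

from distributed.utils import silence_logging

# StreamHandlers for the "distributed" loggers are set to ERROR, even if they
# were previously configured to be less verbose.
silence_logging("error")

# A different logger subtree can be targeted explicitly.
silence_logging("critical", root="tornado")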
--- distributed/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 5259e567358..28ce2364190 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -727,8 +727,7 @@ def log_errors(pdb=False): def silence_logging(level, root="distributed"): """ - Force all existing loggers below *root* to the given level at least - (or keep the existing level if less verbose). + Change all StreamHandlers for the given logger to the given level """ if isinstance(level, str): level = getattr(logging, level.upper()) From 09b959a5667a51a2dc073510c784e01d44827457 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 5 May 2019 11:42:04 -0700 Subject: [PATCH 0262/1550] Check direct_to_workers before using get_worker in Client (#2656) Otherwise we would ignore direct_to_workers=True when there wasn't a local worker --- distributed/client.py | 10 +++++----- distributed/tests/test_client.py | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index fef8b12f1fc..97728929f33 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -591,7 +591,7 @@ def __init__( serializers=None, deserializers=None, extensions=DEFAULT_EXTENSIONS, - direct_to_workers=False, + direct_to_workers=None, **kwargs ): if timeout == no_default: @@ -1607,6 +1607,8 @@ def _gather(self, futures, errors="raise", direct=None, local_worker=None): bad_data = dict() data = {} + if direct is None: + direct = self.direct_to_workers if direct is None: try: w = get_worker() @@ -1615,8 +1617,6 @@ def _gather(self, futures, errors="raise", direct=None, local_worker=None): else: if w.scheduler.address == self.scheduler.address: direct = True - if direct is None: - direct = self.direct_to_workers @gen.coroutine def wait(k): @@ -1866,6 +1866,8 @@ def _scatter( types = valmap(type, data) + if direct is None: + direct = self.direct_to_workers if direct is None: try: w = get_worker() @@ -1874,8 +1876,6 @@ def _scatter( else: if w.scheduler.address == self.scheduler.address: direct = True - if direct is None: - direct = self.direct_to_workers if local_worker: # running within task local_worker.update_data(data=data, report=False) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 6f1fb5c8a03..429b92c5194 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5707,5 +5707,13 @@ def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b): assert result.equals(df.astype("f8")) +def test_direct_to_workers(s, loop): + with Client(s["address"], loop=loop, direct_to_workers=True) as client: + future = client.scatter(1) + future.result() + resp = client.run_on_scheduler(lambda dask_scheduler: dask_scheduler.events) + assert "gather" not in str(resp) + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 From 7a76a77428e64df486f34666bf2eb8869fecbce4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 6 May 2019 09:13:19 -0500 Subject: [PATCH 0263/1550] Allow scheduler to politely close workers as part of shutdown (#2651) --- distributed/cli/dask_worker.py | 2 +- distributed/core.py | 5 +++ distributed/deploy/local.py | 14 +++---- distributed/nanny.py | 2 +- distributed/protocol/tests/test_pickle.py | 10 +++-- distributed/scheduler.py | 11 ++++- distributed/tests/test_actor.py | 2 +- distributed/tests/test_client.py | 22 ++++------ distributed/tests/test_failed_workers.py | 8 ++-- 
distributed/tests/test_resources.py | 6 +-- distributed/tests/test_scheduler.py | 9 +++- distributed/tests/test_steal.py | 6 +-- distributed/tests/test_worker.py | 50 ++++++++++++----------- distributed/tests/test_worker_client.py | 2 +- distributed/worker.py | 14 +++---- 15 files changed, 92 insertions(+), 71 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 73cb9970924..a0bc801a960 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -369,7 +369,7 @@ def del_pid_file(): def close_all(): # Unregister all workers from scheduler if nanny: - yield [n._close(timeout=2) for n in nannies] + yield [n.close(timeout=2) for n in nannies] def on_signal(signum): logger.info("Exiting on signal %d", signum) diff --git a/distributed/core.py b/distributed/core.py index 3cf3f9b5bb2..9b1d408a038 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -490,6 +490,11 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): @gen.coroutine def close(self): self.listener.stop() + for i in range(20): # let comms close naturally for a second + if not self._comms: + break + else: + yield gen.sleep(0.05) for comm in self._comms: comm.close() for cb in self._ongoing_coroutines: diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 4cf67af150e..ad8b36f6dd7 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -362,7 +362,11 @@ def _close(self, timeout="2s"): return self.status = "closing" - self.scheduler.clear_task_state() + with ignoring(gen.TimeoutError, CommClosedError, OSError): + yield gen.with_timeout( + timedelta(seconds=parse_timedelta(timeout)), + self.scheduler.close(close_workers=True), + ) with ignoring(gen.TimeoutError): yield gen.with_timeout( @@ -370,13 +374,7 @@ def _close(self, timeout="2s"): All([self._stop_worker(w) for w in self.workers]), ) del self.workers[:] - - try: - with ignoring(gen.TimeoutError, CommClosedError, OSError): - yield self.scheduler.close(fast=True) - del self.workers[:] - finally: - self.status = "closed" + self.status = "closed" def close(self, timeout=20): """ Close the cluster """ diff --git a/distributed/nanny.py b/distributed/nanny.py index 4b81bec4646..60e83e86da7 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -351,7 +351,7 @@ def pid(self): return self.process and self.process.pid def _close(self, *args, **kwargs): - warnings.warn("Worker._close has moved to Worker.close") + warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) return self.close(*args, **kwargs) @gen.coroutine diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 70d9cdaff22..0ba776e2758 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -1,10 +1,11 @@ -from distributed.protocol.pickle import dumps, loads +from functools import partial +import gc +from operator import add +import weakref import pytest -import weakref -from operator import add -from functools import partial +from distributed.protocol.pickle import dumps, loads def test_pickle_data(): @@ -42,5 +43,6 @@ def funcs(): wr2 = weakref.ref(func2) assert func2(1) == func(1) del func, func2 + gc.collect() assert wr() is None assert wr2() is None diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f9cbbb8c783..cc6f39ff396 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1234,7 +1234,7 @@ def finished(self): yield All(self.coroutines) 
@gen.coroutine - def close(self, comm=None, fast=False): + def close(self, comm=None, fast=False, close_workers=False): """ Send cleanup signal to all coroutines then wait until finished See Also @@ -1248,6 +1248,15 @@ def close(self, comm=None, fast=False): logger.info("Scheduler closing...") setproctitle("dask-scheduler [closing]") + if close_workers: + for worker in self.workers: + self.worker_send(worker, {"op": "close"}) + for i in range(20): # wait a second for send signals to clear + if self.workers: + yield gen.sleep(0.05) + else: + break + for pc in self.periodic_callbacks.values(): pc.stop() self.periodic_callbacks.clear() diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index fba0f50cbfe..ec2636ccd50 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -261,7 +261,7 @@ def test_failed_worker(c, s, a, b): yield wait(future) counter = yield future - yield a._close() + yield a.close() with pytest.raises(Exception) as info: yield counter.increment() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 429b92c5194..8582c2abc83 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3891,7 +3891,7 @@ def test_lose_scattered_data(c, s, a, b): @gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) def test_partially_lose_scattered_data(e, s, a, b, c): - [x] = yield e.scatter([1], workers=a.address) + x = yield e.scatter(1, workers=a.address) yield e.replicate(x, n=2) yield a.close() @@ -4713,20 +4713,16 @@ def test_quiet_client_close(loop): ), line +@slow def test_quiet_client_close_when_cluster_is_closed_before_client(loop): - n_attempts = 5 - # Trying a few times to reduce the flakiness of the test. Without the bug - # fix in #2477 and with 5 attempts, this test passes by chance in about 10% - # of the cases. 
- for _ in range(n_attempts): - with captured_logger(logging.getLogger("tornado.application")) as logger: - cluster = LocalCluster(loop=loop) - client = Client(cluster, loop=loop) - cluster.close() - client.close() + with captured_logger(logging.getLogger("tornado.application")) as logger: + cluster = LocalCluster(loop=loop, n_workers=1) + client = Client(cluster, loop=loop) + cluster.close() + client.close() - out = logger.getvalue() - assert "CancelledError" not in out + out = logger.getvalue() + assert "CancelledError" not in out @gen_cluster() diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 5bb1c61fb5b..0772ea52c32 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -53,14 +53,14 @@ def test_submit_after_failed_worker_async(c, s, a, b): result = yield total assert result == sum(map(inc, range(10))) - yield n._close() + yield n.close() @gen_cluster(client=True, timeout=60) def test_submit_after_failed_worker(c, s, a, b): L = c.map(inc, range(10)) yield wait(L) - yield a._close() + yield a.close() total = c.submit(sum, L) result = yield total @@ -353,7 +353,7 @@ def test_broken_worker_during_computation(c, s, a, b): assert isinstance(result, int) assert result == expected_result - yield n._close() + yield n.close() @gen_cluster(client=True, Worker=Nanny, timeout=60) @@ -403,7 +403,7 @@ def test_worker_who_has_clears_after_failed_connection(c, s, a, b): assert not a.has_what.get(n_worker_address) assert not any(n_worker_address in s for s in a.who_has.values()) - yield n._close() + yield n.close() @slow diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 35f5e160969..429bbc2bb56 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -27,12 +27,12 @@ def test_resources(c, s): assert s.resources == {"GPU": {a.address: 2, b.address: 1}, "DB": {b.address: 1}} assert s.worker_resources == {a.address: {"GPU": 2}, b.address: {"GPU": 1, "DB": 1}} - yield b._close() + yield b.close() assert s.resources == {"GPU": {a.address: 2}, "DB": {}} assert s.worker_resources == {a.address: {"GPU": 2}} - yield a._close() + yield a.close() @gen_cluster( @@ -60,7 +60,7 @@ def test_resource_submit(c, s, a, b): yield wait(z) assert z.key in d.data - yield d._close() + yield d.close() @gen_cluster( diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f4c13bfc852..6a631a13498 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -543,8 +543,8 @@ def test_worker_name(): w2 = yield Worker(s.ip, s.port, name="alice") yield w2.close() - yield s.close() yield w.close() + yield s.close() @gen_test() @@ -1521,3 +1521,10 @@ def test_workerstate_clean(s, a, b): assert ws.address == a.address b = pickle.dumps(ws) assert len(b) < 1000 + + +@gen_cluster() +def test_close_workers(s, a, b): + yield s.close(close_workers=True) + assert a.status == "closed" + assert b.status == "closed" diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index f93022e6d81..cb56fc0f263 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -182,7 +182,7 @@ def test_new_worker_steals(c, s, a): assert b.data - yield b._close() + yield b.close() @gen_cluster(client=True, timeout=20) @@ -287,7 +287,7 @@ def test_steal_resource_restrictions(c, s, a): assert len(b.task_state) > 0 assert len(a.task_state) < 101 - yield b._close() + 
yield b.close() @gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 5, timeout=20) @@ -550,7 +550,7 @@ def test_steal_twice(c, s, a, b): assert max(map(len, has_what.values())) < 30 yield c._close() - yield [w._close() for w in workers] + yield [w.close() for w in workers] @gen_cluster(client=True) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index b51de0b7c0e..07864ab4b64 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -174,12 +174,11 @@ def test_upload_file(c, s, a, b): assert not os.path.exists(os.path.join(b.local_dir, "foobar.py")) assert a.local_dir != b.local_dir - aa = rpc(a.address) - bb = rpc(b.address) - yield [ - aa.upload_file(filename="foobar.py", data=b"x = 123"), - bb.upload_file(filename="foobar.py", data="x = 123"), - ] + with rpc(a.address) as aa, rpc(b.address) as bb: + yield [ + aa.upload_file(filename="foobar.py", data=b"x = 123"), + bb.upload_file(filename="foobar.py", data="x = 123"), + ] assert os.path.exists(os.path.join(a.local_dir, "foobar.py")) assert os.path.exists(os.path.join(b.local_dir, "foobar.py")) @@ -193,10 +192,8 @@ def g(): result = yield future assert result == 123 - yield a.close() - yield b.close() - aa.close_rpc() - bb.close_rpc() + yield c.close() + yield s.close(close_workers=True) assert not os.path.exists(os.path.join(a.local_dir, "foobar.py")) @@ -251,8 +248,10 @@ def g(x): result = yield future assert result == 10 + 1 - yield a._close() - yield b._close() + yield c.close() + yield s.close() + yield a.close() + yield b.close() assert not os.path.exists(os.path.join(a.local_dir, eggname)) @@ -278,8 +277,10 @@ def g(x): result = yield future assert result == 10 + 1 - yield a._close() - yield b._close() + yield c.close() + yield s.close() + yield a.close() + yield b.close() assert not os.path.exists(os.path.join(a.local_dir, pyzname)) @@ -309,7 +310,7 @@ def test_worker_with_port_zero(): assert isinstance(w.port, int) assert w.port > 1024 - yield w._close() + yield w.close() @slow @@ -392,7 +393,7 @@ def test_spill_to_disk(c, s): yield x assert set(w.data.fast) == {x.key, z.key} assert set(w.data.slow) == {y.key} or set(w.data.slow) == {x.key, y.key} - yield w._close() + yield w.close() @gen_cluster(client=True) @@ -443,9 +444,12 @@ def test_Executor(c, s): assert e._threads # had to do some work - yield w._close() + yield w.close() +@pytest.mark.skip( + reason="Other tests leak memory, so process-level checks" "trigger immediately" +) @gen_cluster( client=True, ncores=[("127.0.0.1", 1)], @@ -932,8 +936,8 @@ def test_global_workers(s, a, b): n = len(_global_workers) w = _global_workers[-1]() assert w is a or w is b - yield a._close() - yield b._close() + yield a.close() + yield b.close() assert len(_global_workers) == n - 2 @@ -952,7 +956,7 @@ def test_worker_fds(s): yield gen.sleep(0.01) assert time() < start + 1 - yield worker._close() + yield worker.close() start = time() while psutil.Process().num_fds() > start: @@ -971,19 +975,19 @@ def test_service_hosts_match_worker(s): yield w._start("tcp://0.0.0.0") sock = first(w.services["bokeh"].server._http._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") - yield w._close() + yield w.close() w = Worker(s.address, services={("bokeh", ":0"): BokehWorker}) yield w._start("tcp://127.0.0.1") sock = first(w.services["bokeh"].server._http._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") - yield w._close() + yield w.close() w = Worker(s.address, services={("bokeh", 0): BokehWorker}) yield 
w._start("tcp://127.0.0.1") sock = first(w.services["bokeh"].server._http._sockets.values()) assert sock.getsockname()[0] == "127.0.0.1" - yield w._close() + yield w.close() @gen_cluster(ncores=[]) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index b0dd338153f..2d4632b0b54 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -234,7 +234,7 @@ def func(x): return yield wait(c.map(func, range(10))) - yield a._close() + yield a.close() assert c.status == "running" diff --git a/distributed/worker.py b/distributed/worker.py index 4cf0585b1de..10258d5c28f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -959,7 +959,7 @@ def start(self, port=0): self.loop.add_callback(self._start, port) def _close(self, *args, **kwargs): - warnings.warn("Worker._close has moved to Worker.close") + warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) return self.close(*args, **kwargs) @gen.coroutine @@ -999,6 +999,12 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): for k, v in self.services.items(): v.stop() + if self.batched_stream and not self.batched_stream.comm.closed(): + self.batched_stream.send({"op": "close-stream"}) + + if self.batched_stream: + self.batched_stream.close() + if nanny and "nanny" in self.service_ports: nanny_address = "%s%s:%d" % ( self.listener.prefix, @@ -1008,12 +1014,6 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): with self.rpc(nanny_address) as r: yield r.terminate() - if self.batched_stream and not self.batched_stream.comm.closed(): - self.batched_stream.send({"op": "close-stream"}) - - if self.batched_stream: - self.batched_stream.close() - self.rpc.close() self._closed.set() self._remove_from_global_workers() From e5d2488310ef52ccfdd2f9b70d9ae20b89874eac Mon Sep 17 00:00:00 2001 From: "K.-Michael Aye" Date: Tue, 7 May 2019 07:24:22 -0600 Subject: [PATCH 0264/1550] DOC: Clean up reference to cluster object (#2664) The current docs would overwrite the cluster objects, so it needs to get its own variable name. 
As the Client is usually named `c`, I changed the cluster object to the name `cluster` --- distributed/deploy/local.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index ad8b36f6dd7..3431210a645 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -82,19 +82,19 @@ class LocalCluster(Cluster): Examples -------- - >>> c = LocalCluster() # Create a local cluster with as many workers as cores # doctest: +SKIP - >>> c # doctest: +SKIP + >>> cluster = LocalCluster() # Create a local cluster with as many workers as cores # doctest: +SKIP + >>> cluster # doctest: +SKIP LocalCluster("127.0.0.1:8786", workers=8, ncores=8) - >>> c = Client(c) # connect to local cluster # doctest: +SKIP + >>> c = Client(cluster) # connect to local cluster # doctest: +SKIP Add a new worker to the cluster - >>> w = c.start_worker(ncores=2) # doctest: +SKIP + >>> w = cluster.start_worker(ncores=2) # doctest: +SKIP Shut down the extra worker - >>> c.stop_worker(w) # doctest: +SKIP + >>> cluster.stop_worker(w) # doctest: +SKIP Pass extra keyword arguments to Bokeh From a61df1f54a7ae38c26d8d40de4e1e944067b8dea Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 7 May 2019 09:54:56 -0500 Subject: [PATCH 0265/1550] Add waiting task count to progress title bar (#2663) --- distributed/bokeh/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index a6f07351730..2dd60f0690f 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -1049,7 +1049,7 @@ def __init__(self, scheduler, **kwargs): def update(self): with log_errors(): state = {"all": valmap(len, self.plugin.all), "nbytes": self.plugin.nbytes} - for k in ["memory", "erred", "released", "processing"]: + for k in ["memory", "erred", "released", "processing", "waiting"]: state[k] = valmap(len, self.plugin.state[k]) if not state["all"] and not len(self.source.data["all"]): return @@ -1060,7 +1060,7 @@ def update(self): totals = { k: sum(state[k].values()) - for k in ["all", "memory", "erred", "released"] + for k in ["all", "memory", "erred", "released", "waiting"] } totals["processing"] = totals["all"] - sum( v for k, v in totals.items() if k != "all" @@ -1069,6 +1069,7 @@ def update(self): self.root.title.text = ( "Progress -- total: %(all)s, " "in-memory: %(memory)s, processing: %(processing)s, " + "waiting: %(waiting)s, " "erred: %(erred)s" % totals ) From f75ceb90ff294fd26428d3c8c2f54d6b524e0c0a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 7 May 2019 18:29:08 -0500 Subject: [PATCH 0266/1550] Add Type Attribute to TaskState (#2657) --- distributed/bokeh/templates/task.html | 6 ++++ .../bokeh/tests/test_scheduler_bokeh_html.py | 25 +++++++++++++- distributed/protocol/serialize.py | 14 +------- distributed/scheduler.py | 33 ++++++++++++++++--- distributed/tests/test_collections.py | 5 +-- distributed/tests/test_scheduler.py | 8 +++++ distributed/utils.py | 15 +++++++++ distributed/worker.py | 10 ++++-- 8 files changed, 93 insertions(+), 23 deletions(-) diff --git a/distributed/bokeh/templates/task.html b/distributed/bokeh/templates/task.html index 9f3bb0f78f3..f396a4cba8f 100644 --- a/distributed/bokeh/templates/task.html +++ b/distributed/bokeh/templates/task.html @@ -19,6 +19,12 @@

        Task: {{ ts.key }}

        Call Stack {% end %} + {% if ts.type %} + + Type + {{ ts.type }} + + {% end %} {% if ts.nbytes %} Bytes diff --git a/distributed/bokeh/tests/test_scheduler_bokeh_html.py b/distributed/bokeh/tests/test_scheduler_bokeh_html.py index d5ca1ee7f05..96fe3c2f5d2 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh_html.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh_html.py @@ -11,7 +11,8 @@ from tornado.escape import url_escape from tornado.httpclient import AsyncHTTPClient -from distributed.utils_test import gen_cluster, slowinc +from dask.sizeof import sizeof +from distributed.utils_test import gen_cluster, slowinc, inc from distributed.bokeh.scheduler import BokehScheduler @@ -105,3 +106,25 @@ def test_health(c, s, a, b): txt = response.body.decode("utf8") assert txt == "ok" + + +@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) +def test_task_page(c, s, a, b): + future = c.submit(lambda x: x + 1, 1, workers=a.address) + x = c.submit(inc, 1) + yield future + http_client = AsyncHTTPClient() + + "info/task/" + url_escape(future.key) + ".html", + response = yield http_client.fetch( + "http://localhost:%d/info/task/" % s.services["bokeh"].port + + url_escape(future.key) + + ".html" + ) + assert response.code == 200 + body = response.body.decode() + + assert str(sizeof(1)) in body + assert "int" in body + assert a.address in body + assert "memory" in body diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index f47ea7388af..26129f4e1c5 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -14,7 +14,7 @@ from . import pickle from ..compatibility import PY2 -from ..utils import has_keyword +from ..utils import has_keyword, typename from .compression import maybe_compress, decompress from .utils import ( unpack_frames, @@ -445,18 +445,6 @@ def register_serialization_lazy(toplevel, func): raise Exception("Serialization registration has changed. See documentation") -def typename(typ): - """ Return name of type - - Examples - -------- - >>> from distributed import Scheduler - >>> typename(Scheduler) - 'distributed.scheduler.Scheduler' - """ - return typ.__module__ + "." + typ.__name__ - - @partial(normalize_token.register, Serialized) def normalize_Serialized(o): return [o.header] + o.frames # for dask.base.tokenize diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cc6f39ff396..8442b6ddcea 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -453,6 +453,11 @@ class TaskState(object): of a finished task. This number is used for diagnostics and to help prioritize work. + .. attribute:: type: str + + The type of the object as a string. Only present for tasks that have + been computed. + .. attribute:: exception: object If this task failed executing, the exception object is stored here. 
@@ -566,6 +571,7 @@ class TaskState(object): "suspicious", "retries", "nbytes", + "type", ) def __init__(self, key, run_spec): @@ -590,6 +596,7 @@ def __init__(self, key, run_spec): self.resource_restrictions = None self.loose_restrictions = False self.actor = None + self.type = None def get_nbytes(self): nbytes = self.nbytes @@ -1382,6 +1389,7 @@ def add_worker( name=None, resolve_address=True, nbytes=None, + types=None, now=None, resources=None, host_info=None, @@ -1460,7 +1468,11 @@ def add_worker( ts = self.tasks.get(key) if ts is not None and ts.state in ("processing", "waiting"): recommendations = self.transition( - key, "memory", worker=address, nbytes=nbytes[key] + key, + "memory", + worker=address, + nbytes=nbytes[key], + typename=types[key], ) self.transitions(recommendations) @@ -3418,7 +3430,9 @@ def _remove_from_processing(self, ts, send_worker_msg=None): if send_worker_msg: self.worker_send(w, send_worker_msg) - def _add_to_memory(self, ts, ws, recommendations, type=None, **kwargs): + def _add_to_memory( + self, ts, ws, recommendations, type=None, typename=None, **kwargs + ): """ Add *ts* to the set of in-memory tasks. """ @@ -3454,6 +3468,7 @@ def _add_to_memory(self, ts, ws, recommendations, type=None, **kwargs): self.report(msg) ts.state = "memory" + ts.type = typename cs = self.clients["fire-and-forget"] if ts in cs.wants_what: @@ -3676,7 +3691,14 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): raise def transition_processing_memory( - self, key, nbytes=None, type=None, worker=None, startstops=None, **kwargs + self, + key, + nbytes=None, + type=None, + typename=None, + worker=None, + startstops=None, + **kwargs ): try: ts = self.tasks[key] @@ -3749,7 +3771,7 @@ def transition_processing_memory( self._remove_from_processing(ts) - self._add_to_memory(ts, ws, recommendations, type=type) + self._add_to_memory(ts, ws, recommendations, type=type, typename=typename) if self.validate: assert not ts.processing_on @@ -4801,6 +4823,9 @@ def validate_task_state(ts): str(ts), str(ts.who_has), ) + if ts.run_spec: # was computed + assert ts.type + assert isinstance(ts.type, str) assert not any(ts in dts.waiting_on for dts in ts.dependents) for ws in ts.who_has: assert ts in ws.has_what, ( diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index f640d2d21e0..985b6f78fe9 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -73,10 +73,11 @@ def test_dataframes(c, s, a, b): def test__dask_array_collections(c, s, a, b): import dask.array as da + s.validate = False x_dsk = {("x", i, j): np.random.random((3, 3)) for i in range(3) for j in range(2)} y_dsk = {("y", i, j): np.random.random((3, 3)) for i in range(2) for j in range(3)} - x_futures = yield c._scatter(x_dsk) - y_futures = yield c._scatter(y_dsk) + x_futures = yield c.scatter(x_dsk) + y_futures = yield c.scatter(y_dsk) dt = np.random.random(0).dtype x_local = da.Array(x_dsk, "x", ((3, 3, 3), (3, 3)), dt) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6a631a13498..9224ed69030 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1523,6 +1523,14 @@ def test_workerstate_clean(s, a, b): assert len(b) < 1000 +@gen_cluster(client=True) +def test_result_type(c, s, a, b): + x = c.submit(lambda: 1) + yield x + + assert "int" in s.tasks[x.key].type + + @gen_cluster() def test_close_workers(s, a, b): yield s.close(close_workers=True) diff --git 
a/distributed/utils.py b/distributed/utils.py index 28ce2364190..d6cc5ba62cf 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1529,3 +1529,18 @@ def warn_on_duration(duration, msg): stop = time() if stop - start > parse_timedelta(duration): warnings.warn(msg, stacklevel=2) + + +def typename(typ): + """ Return name of type + + Examples + -------- + >>> from distributed import Scheduler + >>> typename(Scheduler) + 'distributed.scheduler.Scheduler' + """ + try: + return typ.__module__ + "." + typ.__name__ + except AttributeError: + return str(typ) diff --git a/distributed/worker.py b/distributed/worker.py index 10258d5c28f..915784edd88 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -44,6 +44,7 @@ from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede from .utils import ( funcname, + typename, get_ip, has_arg, _maybe_complex, @@ -671,6 +672,7 @@ def _register_with_scheduler(self): raise gen.Return try: _start = time() + types = {k: typename(v) for k, v in self.data.items()} comm = yield connect( self.scheduler.address, connection_args=self.connection_args ) @@ -685,6 +687,7 @@ def _register_with_scheduler(self): ncores=self.ncores, name=self.name, nbytes=self.nbytes, + types=types, now=time(), resources=self.total_resources, memory_limit=self.memory_limit, @@ -1721,18 +1724,19 @@ def send_task_state_to_scheduler(self, key): typ = self.types.get(key) or type(value) del value try: - typ = dumps_function(typ) + typ_serialized = dumps_function(typ) except PicklingError: # Some types fail pickling (example: _thread.lock objects), # send their name as a best effort. - typ = pickle.dumps(typ.__name__) + typ_serialized = pickle.dumps(typ.__name__) d = { "op": "task-finished", "status": "OK", "key": key, "nbytes": nbytes, "thread": self.threads.get(key), - "type": typ, + "type": typ_serialized, + "typename": typename(typ), } elif key in self.exceptions: d = { From a0d57710c9c836f8104b55c448eb2cec5f38e959 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 8 May 2019 14:12:01 -0500 Subject: [PATCH 0267/1550] bump version to 1.28.0 --- docs/source/changelog.rst | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 34fe078fe18..1b584c02219 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,30 @@ Changelog ========= +1.28.0 - 2019-05-08 +------------------- + +- Add Type Attribute to TaskState (:pr:`2657`) `Matthew Rocklin`_ +- Add waiting task count to progress title bar (:pr:`2663`) `James Bourbeau`_ +- DOC: Clean up reference to cluster object (:pr:`2664`) `K.-Michael Aye`_ +- Allow scheduler to politely close workers as part of shutdown (:pr:`2651`) `Matthew Rocklin`_ +- Check direct_to_workers before using get_worker in Client (:pr:`2656`) `Matthew Rocklin`_ +- Fixed comment regarding keeping existing level if less verbose (:pr:`2655`) `Brett Randall`_ +- Add idle timeout to scheduler (:pr:`2652`) `Matthew Rocklin`_ +- Avoid deprecation warnings (:pr:`2653`) `Matthew Rocklin`_ +- Use an LRU cache for deserialized functions (:pr:`2623`) `Matthew Rocklin`_ +- Rename Worker._close to Worker.close (:pr:`2650`) `Matthew Rocklin`_ +- Add Comm closed bookkeeping (:pr:`2648`) `Matthew Rocklin`_ +- Explain LocalCluster behavior in Client docstring (:pr:`2647`) `Matthew Rocklin`_ +- Add last worker into KilledWorker exception to help debug (:pr:`2610`) `@plbertrand`_ +- Set working worker class for dask-ssh (:pr:`2646`) 
`Martin Durant`_ +- Add as_completed methods to docs (:pr:`2642`) `Jim Crist`_ +- Add timeout to Client._reconnect (:pr:`2639`) `Jim Crist`_ +- Limit test_spill_by_default memory, reenable it (:pr:`2633`) `Peter Andreas Entschev`_ +- Use proper address in worker -> nanny comms (:pr:`2640`) `Jim Crist`_ +- Fix deserialization of bytes chunks larger than 64MB (:pr:`2637`) `Peter Andreas Entschev`_ + + 1.27.1 - 2019-04-29 ------------------- @@ -20,7 +44,7 @@ Changelog 1.27.0 - 2019-04-12 ------------------- - Add basic health endpoints to scheduler and worker bokeh. (:pr:`2607) `amerkel2`_ +- Add basic health endpoints to scheduler and worker bokeh. (:pr:`2607`) `amerkel2`_ - Improved description accuracy of --memory-limit option. (:pr:`2601`) `Brett Randall`_ - Check self.dependencies when looking at dependent tasks in memory (:pr:`2606`) `deepthirajagopalan7`_ - Add RabbitMQ SchedulerPlugin example (:pr:`2604`) `Matt Nicolls`_ @@ -28,7 +52,7 @@ Changelog - Use ensure_bytes in serialize_error (:pr:`2588`) `Matthew Rocklin`_ - Specify data storage explicitly from Worker constructor (:pr:`2600`) `Matthew Rocklin`_ - Change bokeh port keywords to dashboard_address (:pr:`2589`) `Matthew Rocklin`_ -- .detach_(`) pytorch tensor to serialize data as numpy array. (:pr:`2586`) `Muammar El Khatib`_ +- .detach_() pytorch tensor to serialize data as numpy array. (:pr:`2586`) `Muammar El Khatib`_ - Add warning if creating scratch directories takes a long time (:pr:`2561`) `Matthew Rocklin`_ - Fix typo in pub-sub doc. (:pr:`2599`) `Loïc Estève`_ - Allow return_when='FIRST_COMPLETED' in wait (:pr:`2598`) `Nikos Tsaousis`_ @@ -1005,3 +1029,5 @@ significantly without many new features. .. _`Michael Delgado`: https://github.com/delgadom .. _`Peter Andreas Entschev`: https://github.com/pentschev .. _`condoratberlin`: https://github.com/condoratberlin +.. _`K.-Michael Aye`: https://github.com/michaelaye +.. _`@plbertrand`: https://github.com/plbertrand \ No newline at end of file From 7b526c0f436955d58144a8a76b64b5dea3e0b174 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 8 May 2019 14:32:51 -0500 Subject: [PATCH 0268/1550] Add release procedure doc (#2672) --- docs/release-procedure.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 docs/release-procedure.md diff --git a/docs/release-procedure.md b/docs/release-procedure.md new file mode 100644 index 00000000000..f9efd6a0ab1 --- /dev/null +++ b/docs/release-procedure.md @@ -0,0 +1,3 @@ +Distributed follows a similar procedure for releasing as the core Dask project. + +See https://github.com/dask/dask/blob/master/docs/release-procedure.md for instructions. 
\ No newline at end of file From ff6d3565b761b93f06af94cfb0d999af189a50e8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 8 May 2019 17:06:47 -0500 Subject: [PATCH 0269/1550] Remove AioClient (#2668) See also https://github.com/dask/dask-examples/pull/71 Fixes https://github.com/dask/distributed/issues/1164 Fixes https://github.com/dask/distributed/issues/1762 Fixes https://github.com/dask/distributed/issues/2537 Fixes https://github.com/dask/distributed/issues/2661 --- distributed/asyncio.py | 130 +-------- distributed/tests/py3_test_asyncio.py | 365 -------------------------- distributed/tests/test_asyncio.py | 5 - docs/source/api.rst | 8 - docs/source/asynchronous.rst | 46 +--- 5 files changed, 17 insertions(+), 537 deletions(-) delete mode 100644 distributed/tests/py3_test_asyncio.py delete mode 100644 distributed/tests/test_asyncio.py diff --git a/distributed/asyncio.py b/distributed/asyncio.py index b75bf2a1130..7a7225e60fd 100644 --- a/distributed/asyncio.py +++ b/distributed/asyncio.py @@ -1,126 +1,14 @@ -"""Experimental interface for asyncio, may disappear without warning""" - -# flake8: noqa - -import asyncio -from functools import wraps - -from toolz import merge - -from tornado.platform.asyncio import BaseAsyncIOLoop -from tornado.platform.asyncio import to_asyncio_future - -from . import client -from .client import Client, Future -from .variable import Variable -from .utils import ignoring - - -def to_asyncio(fn, **default_kwargs): - """Converts Tornado gen.coroutines and futures to asyncio ones""" - - @wraps(fn) - def convert(*args, **kwargs): - if default_kwargs: - kwargs = merge(default_kwargs, kwargs) - return to_asyncio_future(fn(*args, **kwargs)) - - return convert - - -class AioClient(Client): - """ Connect to and drive computation on a distributed Dask cluster - - This class provides an asyncio compatible async/await interface for - dask.distributed. - - The Client connects users to a dask.distributed compute cluster. It - provides an asynchronous user interface around functions and futures. - This class resembles executors in ``concurrent.futures`` but also - allows ``Future`` objects within ``submit/map`` calls. - - AioClient is an **experimental** interface for distributed and may - disappear without warning! - - Parameters - ---------- - address: string, or Cluster - This can be the address of a ``Scheduler`` server like a string - ``'127.0.0.1:8786'`` or a cluster object like ``LocalCluster()`` - - Examples - -------- - Provide cluster's scheduler address on initialization:: - - client = AioClient('127.0.0.1:8786') - - Start the client:: - - async def start_the_client(): - client = await AioClient() - - # Use the client.... - - await client.close() - - An ``async with`` statement is a more convenient way to start and shut down - the client:: - - async def start_the_client(): - async with AioClient() as client: - # Use the client within this block. 
- pass - - Use the ``submit`` method to send individual computations to the cluster, - and await the returned future to retrieve the result:: - - async def add_two_numbers(): - async with AioClient() as client: - a = client.submit(add, 1, 2) - result = await a - - Continue using submit or map on results to build up larger computations, - and gather results with the ``gather`` method:: - - async def gather_some_results(): - async with AioClient() as client: - a = client.submit(add, 1, 2) - b = client.submit(add, 10, 20) - c = client.submit(add, a, b) - result = await client.gather([c]) - - See Also - -------- - distributed.client.Client: Blocking Client - distributed.scheduler.Scheduler: Internal scheduler +raise ImportError( """ - def __init__(self, *args, **kwargs): - loop = asyncio.get_event_loop() - ioloop = BaseAsyncIOLoop(loop) - super().__init__(*args, loop=ioloop, asynchronous=True, **kwargs) - - def __enter__(self): - raise RuntimeError("Use AioClient in an 'async with' block, not 'with'") - - async def __aenter__(self): - await to_asyncio_future(self._started) - return self - - async def __aexit__(self, type, value, traceback): - await to_asyncio_future(self._close()) - - def __await__(self): - return to_asyncio_future(self._started).__await__() - - get = to_asyncio(Client.get, sync=False) - sync = to_asyncio(Client.sync) - close = to_asyncio(Client.close) - shutdown = to_asyncio(Client.shutdown) - +The dask.distributed.AioClient object has been removed. +We recommend using the normal client with asynchonrous=True -class as_completed(client.as_completed): - __anext__ = to_asyncio(client.as_completed.__anext__) + client = await Client(..., asynchronous=True) +and a version of Tornado >= 5. -wait = to_asyncio(client._wait) +Documentation: https://distributed.dask.org/en/latest/asynchronous.html +Example: https://examples.dask.org/applications/async-await.html +""" +) diff --git a/distributed/tests/py3_test_asyncio.py b/distributed/tests/py3_test_asyncio.py deleted file mode 100644 index 3754b282813..00000000000 --- a/distributed/tests/py3_test_asyncio.py +++ /dev/null @@ -1,365 +0,0 @@ -# flake8: noqa -import pytest - -asyncio = pytest.importorskip("asyncio") - -import functools -from time import time -from operator import add -from toolz import isdistinct -from concurrent.futures import CancelledError -from distributed.utils_test import slow -from distributed.utils_test import slowinc - -from tornado.ioloop import IOLoop -from tornado.platform.asyncio import BaseAsyncIOLoop - -from distributed.client import Future -from distributed.variable import Variable -from distributed.asyncio import AioClient -from distributed.asyncio import as_completed, wait -from distributed.utils_test import inc, div - - -def coro_test(fn): - assert asyncio.iscoroutinefunction(fn) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - loop = None - try: - IOLoop.clear_current() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(fn(*args, **kwargs)) - finally: - if loop is not None: - loop.close() - - IOLoop.clear_current() - asyncio.set_event_loop(None) - - return wrapper - - -@coro_test -async def test_coro_test(): - assert asyncio.get_event_loop().is_running() - - -@coro_test -async def test_asyncio_start_close(): - async with AioClient(processes=False, dashboard_address=False) as c: - assert c.status == "running" - # AioClient has installed its AioLoop shim. 
- assert isinstance(IOLoop.current(instance=False), BaseAsyncIOLoop) - - result = await c.submit(inc, 10) - assert result == 11 - - await c.close() - assert c.status == "closed" - # assert IOLoop.current(instance=False) is None - - -@coro_test -async def test_asyncio_submit(): - async with AioClient(processes=False) as c: - x = c.submit(inc, 10) - assert not x.done() - - assert isinstance(x, Future) - assert x.client is c - - result = await x.result() - assert result == 11 - assert x.done() - - y = c.submit(inc, 20) - z = c.submit(add, x, y) - - result = await z.result() - assert result == 11 + 21 - - -@coro_test -async def test_asyncio_future_await(): - async with AioClient(processes=False) as c: - x = c.submit(inc, 10) - assert not x.done() - - assert isinstance(x, Future) - assert x.client is c - - result = await x - assert result == 11 - assert x.done() - - y = c.submit(inc, 20) - z = c.submit(add, x, y) - - result = await z - assert result == 11 + 21 - - -@coro_test -async def test_asyncio_map(): - async with AioClient(processes=False) as c: - L1 = c.map(inc, range(5)) - assert len(L1) == 5 - assert isdistinct(x.key for x in L1) - assert all(isinstance(x, Future) for x in L1) - - result = await L1[0] - assert result == inc(0) - - L2 = c.map(inc, L1) - - result = await L2[1] - assert result == inc(inc(1)) - - total = c.submit(sum, L2) - result = await total - assert result == sum(map(inc, map(inc, range(5)))) - - L3 = c.map(add, L1, L2) - result = await L3[1] - assert result == inc(1) + inc(inc(1)) - - L4 = c.map(add, range(3), range(4)) - results = await c.gather(L4) - assert results == list(map(add, range(3), range(4))) - - def f(x, y=10): - return x + y - - L5 = c.map(f, range(5), y=5) - results = await c.gather(L5) - assert results == list(range(5, 10)) - - y = c.submit(f, 10) - L6 = c.map(f, range(5), y=y) - results = await c.gather(L6) - assert results == list(range(20, 25)) - - -@coro_test -async def test_asyncio_gather(): - async with AioClient(processes=False) as c: - x = c.submit(inc, 10) - y = c.submit(inc, x) - - result = await c.gather(x) - assert result == 11 - result = await c.gather([x]) - assert result == [11] - result = await c.gather({"x": x, "y": [y]}) - assert result == {"x": 11, "y": [12]} - - -@coro_test -async def test_asyncio_get(): - async with AioClient(processes=False) as c: - result = await c.get({"x": (inc, 1)}, "x") - assert result == 2 - - result = await c.get({"x": (inc, 1)}, ["x"]) - assert result == [2] - - result = await c.get({}, []) - assert result == [] - - result = await c.get({("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}, ("x", 2)) - assert result == 3 - - -@coro_test -async def test_asyncio_exceptions(): - async with AioClient(processes=False) as c: - result = await c.submit(div, 1, 2) - assert result == 1 / 2 - - with pytest.raises(ZeroDivisionError): - result = await c.submit(div, 1, 0) - - result = await c.submit(div, 10, 2) # continues to operate - assert result == 10 / 2 - - -@coro_test -async def test_asyncio_exception_on_exception(): - async with AioClient(processes=False) as c: - x = c.submit(lambda: 1 / 0) - y = c.submit(inc, x) - - with pytest.raises(ZeroDivisionError): - await y - - z = c.submit(inc, y) - with pytest.raises(ZeroDivisionError): - await z - - -@coro_test -async def test_asyncio_as_completed(): - async with AioClient(processes=False) as c: - futures = c.map(inc, range(10)) - - results = [] - async for future in as_completed(futures): - results.append(await future) - - assert set(results) == set(range(1, 11)) - - 
-@coro_test -async def test_asyncio_cancel(): - async with AioClient(processes=False) as c: - s = c.cluster.scheduler - - x = c.submit(slowinc, 1) - y = c.submit(slowinc, x) - - while y.key not in s.tasks: - await asyncio.sleep(0.01) - - await c.cancel([x]) - - assert x.cancelled() - assert "cancel" in str(x) - s.validate_state() - - start = time() - while not y.cancelled(): - await asyncio.sleep(0.01) - assert time() < start + 5 - - assert not s.tasks - assert not s.who_has - s.validate_state() - - -@coro_test -async def test_asyncio_cancel_tuple_key(): - async with AioClient(processes=False) as c: - x = c.submit(inc, 1, key=("x", 0, 1)) - await x - await c.cancel(x) - with pytest.raises(CancelledError): - await x - - -@coro_test -async def test_asyncio_wait(): - async with AioClient(processes=False) as c: - x = c.submit(inc, 1) - y = c.submit(inc, 2) - z = c.submit(inc, 3) - - await wait(x) - assert x.done() is True - - await wait([y, z]) - assert y.done() is True - assert z.done() is True - - -@coro_test -async def test_asyncio_run(): - async with AioClient(processes=False) as c: - results = await c.run(inc, 1) - assert len(results) > 0 - assert [value == 2 for value in results.values()] - - results = await c.run(inc, 1, workers=[]) - assert results == {} - - -@coro_test -async def test_asyncio_run_on_scheduler(): - def f(dask_scheduler=None): - return dask_scheduler.address - - async with AioClient(processes=False) as c: - address = await c.run_on_scheduler(f) - assert address == c.cluster.scheduler.address - - with pytest.raises(ZeroDivisionError): - await c.run_on_scheduler(div, 1, 0) - - -@coro_test -async def test_asyncio_run_coroutine(): - async def aioinc(x, delay=0.02): - await asyncio.sleep(delay) - return x + 1 - - async def aiothrows(x, delay=0.02): - await asyncio.sleep(delay) - raise RuntimeError("hello") - - async with AioClient(processes=False) as c: - results = await c.run(aioinc, 1, delay=0.05) - assert len(results) > 0 - assert [value == 2 for value in results.values()] - - results = await c.run(aioinc, 1, workers=[]) - assert results == {} - - with pytest.raises(RuntimeError) as exc_info: - await c.run(aiothrows, 1) - assert "hello" in str(exc_info) - - -@slow -@coro_test -async def test_asyncio_restart(): - async with AioClient(processes=False) as c: - assert c.status == "running" - x = c.submit(inc, 1) - assert x.key in c.refcount - - await c.restart() - assert x.key not in c.refcount - - key = x.key - del x - import gc - - gc.collect() - - assert key not in c.refcount - - -@coro_test -async def test_asyncio_nanny_workers(): - async with AioClient(n_workers=2) as c: - assert await c.submit(inc, 1) == 2 - - -@coro_test -async def test_asyncio_variable(): - async with AioClient(processes=False) as c: - s = c.cluster.scheduler - - x = Variable("x") - xx = Variable("x") - assert x.client is c - - future = c.submit(inc, 1) - - await x.set(future) - future2 = await xx.get() - assert future.key == future2.key - - del future, future2 - - await asyncio.sleep(0.1) - assert s.tasks # future still present - - x.delete() - - start = time() - while s.tasks: - await asyncio.sleep(0.01) - assert time() < start + 5 diff --git a/distributed/tests/test_asyncio.py b/distributed/tests/test_asyncio.py deleted file mode 100644 index 4eab91a5a81..00000000000 --- a/distributed/tests/test_asyncio.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - - -if sys.version_info >= (3, 5): - from distributed.tests.py3_test_asyncio import * # noqa: F401, F403 diff --git a/docs/source/api.rst 
b/docs/source/api.rst index 47933be06d4..e91c4ee6ac1 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -177,14 +177,6 @@ Other :members: -Asyncio Client --------------- - -.. currentmodule:: distributed.asyncio -.. autoclass:: AioClient - :members: - - Adaptive -------- diff --git a/docs/source/asynchronous.rst b/docs/source/asynchronous.rst index 1c079035722..9d38a8b04fe 100644 --- a/docs/source/asynchronous.rst +++ b/docs/source/asynchronous.rst @@ -57,25 +57,6 @@ call. results = await client.gather(futures, asynchronous=True) return results -AsyncIO -------- - -If you prefer to use the Asyncio event loop over the Tornado event loop you -should use the ``AioClient``. - -.. code-block:: python - - from distributed.asyncio import AioClient - client = await AioClient() - -All other operations remain the same: - -.. code-block:: python - - future = client.submit(lambda x: x + 1, 10) - result = await future - # or - result = await client.gather(future) Python 2 Compatibility ---------------------- @@ -90,8 +71,8 @@ This self-contained example starts an asynchronous client, submits a trivial job, waits on the result, and then shuts down the client. You can see implementations for Python 2 and 3 and for Asyncio and Tornado. -Python 3 with Tornado -+++++++++++++++++++++ +Python 3 with Tornado or Asyncio +++++++++++++++++++++++++++++++++ .. code-block:: python @@ -104,9 +85,15 @@ Python 3 with Tornado await client.close() return result + # Either use Tornado from tornado.ioloop import IOLoop IOLoop().run_sync(f) + # Or use asyncio + import asyncio + asyncio.get_event_loop().run_until_complete(f()) + + Python 2/3 with Tornado +++++++++++++++++++++++ @@ -126,23 +113,6 @@ Python 2/3 with Tornado from tornado.ioloop import IOLoop IOLoop().run_sync(f) -Python 3 with Asyncio -+++++++++++++++++++++ - -.. code-block:: python - - from distributed.asyncio import AioClient - - async def f(): - client = await AioClient() - future = client.submit(lambda x: x + 1, 10) - result = await future - await client.close() - return result - - from asyncio import get_event_loop - get_event_loop().run_until_complete(f()) - Use Cases --------- From 8dee90c3b093538188def4513d274886cae1c842 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 8 May 2019 19:16:01 -0500 Subject: [PATCH 0270/1550] Move interface/host/port handling from CLI to classes (#2667) This should allow other systems to benefit from this logic, but does make the class constructors a bit more complicated. Overall I think it's a win though. 
* Improve debug information around test_file_descriptors * Move port and host information into constructor in tests * Pull out interface/host/port logic * Test no leaked processes --- distributed/cli/dask_scheduler.py | 31 +++++----- distributed/cli/dask_worker.py | 25 ++------ distributed/cli/tests/test_cli_utils.py | 54 ------------------ distributed/cli/tests/test_dask_worker.py | 3 +- distributed/cli/utils.py | 40 ------------- distributed/comm/addressing.py | 69 +++++++++++++++++++++++ distributed/nanny.py | 15 +++++ distributed/scheduler.py | 27 ++++++++- distributed/tests/test_client.py | 9 +-- distributed/tests/test_scheduler.py | 21 ++++--- distributed/tests/test_worker.py | 21 +++++-- distributed/utils_test.py | 38 +++++++++---- distributed/worker.py | 18 +++++- 13 files changed, 206 insertions(+), 165 deletions(-) delete mode 100644 distributed/cli/tests/test_cli_utils.py diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 3b0aa5b4c70..57a7168a3a2 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -5,6 +5,7 @@ import logging import gc import os +import re import shutil import sys import tempfile @@ -16,12 +17,7 @@ from distributed import Scheduler from distributed.security import Security -from distributed.utils import get_ip_interface -from distributed.cli.utils import ( - check_python_3, - install_signal_handlers, - uri_from_host_port, -) +from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.preloading import preload_modules, validate_preload_argv from distributed.proctitle import ( enable_proctitle_on_children, @@ -151,6 +147,9 @@ def main( ) dashboard_address = bokeh_port + if port is None and (not host or not re.search(r":\d", host)): + port = 8786 + sec = Security( tls_ca_file=tls_ca_file, tls_scheduler_cert=tls_cert, tls_scheduler_key=tls_key ) @@ -186,14 +185,6 @@ def del_pid_file(): limit = max(soft, hard // 2) resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard)) - if interface: - if host: - raise ValueError("Can not specify both interface and host") - else: - host = get_ip_interface(interface) - - addr = uri_from_host_port(host, port, 8786) - loop = IOLoop.current() logger.info("-" * 47) @@ -213,9 +204,15 @@ def del_pid_file(): logger.info("Unable to import bokeh: %s" % str(error)) scheduler = Scheduler( - loop=loop, services=services, scheduler_file=scheduler_file, security=sec + loop=loop, + services=services, + scheduler_file=scheduler_file, + security=sec, + host=host, + port=port, + interface=interface, ) - scheduler.start(addr) + scheduler.start() if not preload: preload = dask.config.get("distributed.scheduler.preload") if not preload_argv: @@ -237,7 +234,7 @@ def del_pid_file(): if local_directory_created: shutil.rmtree(local_directory) - logger.info("End scheduler at %r", addr) + logger.info("End scheduler at %r", scheduler.address) def go(): diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index a0bc801a960..6315939005d 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -10,14 +10,10 @@ import click from distributed import Nanny, Worker from distributed.config import config -from distributed.utils import get_ip_interface, parse_timedelta +from distributed.utils import parse_timedelta from distributed.worker import _ncores from distributed.security import Security -from distributed.cli.utils import ( - check_python_3, - uri_from_host_port, - install_signal_handlers, -) +from 
distributed.cli.utils import check_python_3, install_signal_handlers from distributed.comm import get_address_host_port from distributed.preloading import validate_preload_argv from distributed.proctitle import ( @@ -328,18 +324,6 @@ def del_pid_file(): "dask-worker SCHEDULER_ADDRESS:8786" ) - if interface: - if host: - raise ValueError("Can not specify both interface and host") - else: - host = get_ip_interface(interface) - - if host or port: - addr = uri_from_host_port(host, port, 0) - else: - # Choose appropriate address for scheduler - addr = None - if death_timeout is not None: death_timeout = parse_timedelta(death_timeout, "s") @@ -359,6 +343,9 @@ def del_pid_file(): preload_argv=preload_argv, security=sec, contact_address=contact_address, + interface=interface, + host=host, + port=port, name=name if nprocs == 1 or not name else name + "-" + str(i), **kwargs ) @@ -377,7 +364,7 @@ def on_signal(signum): @gen.coroutine def run(): - yield [n._start(addr) for n in nannies] + yield nannies while all(n.status != "closed" for n in nannies): yield gen.sleep(0.2) diff --git a/distributed/cli/tests/test_cli_utils.py b/distributed/cli/tests/test_cli_utils.py deleted file mode 100644 index 4f07f699de5..00000000000 --- a/distributed/cli/tests/test_cli_utils.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import print_function, division, absolute_import - -import pytest - -pytest.importorskip("requests") - -from distributed.cli.utils import uri_from_host_port -from distributed.utils import get_ip - - -external_ip = get_ip() - - -def test_uri_from_host_port(): - f = uri_from_host_port - - assert f("", 456, None) == "tcp://:456" - assert f("", 456, 123) == "tcp://:456" - assert f("", None, 123) == "tcp://:123" - assert f("", None, 0) == "tcp://" - assert f("", 0, 123) == "tcp://" - - assert f("localhost", 456, None) == "tcp://localhost:456" - assert f("localhost", 456, 123) == "tcp://localhost:456" - assert f("localhost", None, 123) == "tcp://localhost:123" - assert f("localhost", None, 0) == "tcp://localhost" - - assert f("192.168.1.2", 456, None) == "tcp://192.168.1.2:456" - assert f("192.168.1.2", 456, 123) == "tcp://192.168.1.2:456" - assert f("192.168.1.2", None, 123) == "tcp://192.168.1.2:123" - assert f("192.168.1.2", None, 0) == "tcp://192.168.1.2" - - assert f("tcp://192.168.1.2", 456, None) == "tcp://192.168.1.2:456" - assert f("tcp://192.168.1.2", 456, 123) == "tcp://192.168.1.2:456" - assert f("tcp://192.168.1.2", None, 123) == "tcp://192.168.1.2:123" - assert f("tcp://192.168.1.2", None, 0) == "tcp://192.168.1.2" - - assert f("tcp://192.168.1.2:456", None, None) == "tcp://192.168.1.2:456" - assert f("tcp://192.168.1.2:456", 0, 0) == "tcp://192.168.1.2:456" - assert f("tcp://192.168.1.2:456", 0, 123) == "tcp://192.168.1.2:456" - assert f("tcp://192.168.1.2:456", 456, 123) == "tcp://192.168.1.2:456" - - with pytest.raises(ValueError): - # Two incompatible port values - f("tcp://192.168.1.2:456", 123, None) - - assert f("tls://192.168.1.2:456", None, None) == "tls://192.168.1.2:456" - assert f("tls://192.168.1.2:456", 0, 0) == "tls://192.168.1.2:456" - assert f("tls://192.168.1.2:456", 0, 123) == "tls://192.168.1.2:456" - assert f("tls://192.168.1.2:456", 456, 123) == "tls://192.168.1.2:456" - - assert f("tcp://[::1]:456", None, None) == "tcp://[::1]:456" - - assert f("tls://[::1]:456", None, None) == "tls://[::1]:456" diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 72084e53141..2fa3779d9b4 100644 --- 
a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -7,7 +7,6 @@ import requests import sys from time import sleep -from toolz import first from distributed import Client from distributed.metrics import time @@ -52,7 +51,7 @@ def test_memory_limit(loop): while not c.ncores(): sleep(0.1) info = c.scheduler_info() - d = first(info["workers"].values()) + [d] = info["workers"].values() assert isinstance(d["memory_limit"], int) assert d["memory_limit"] == 2e9 diff --git a/distributed/cli/utils.py b/distributed/cli/utils.py index 4ce1d845821..2c2088a7556 100644 --- a/distributed/cli/utils.py +++ b/distributed/cli/utils.py @@ -3,13 +3,6 @@ from tornado import gen from tornado.ioloop import IOLoop -from distributed.comm import ( - parse_address, - unparse_address, - parse_host_port, - unparse_host_port, -) - py3_err_msg = """ Warning: Your terminal does not set locales. @@ -75,36 +68,3 @@ def cleanup_and_stop(): for sig in [signal.SIGINT, signal.SIGTERM]: old_handlers[sig] = signal.signal(sig, handle_signal) - - -def uri_from_host_port(host_arg, port_arg, default_port): - """ - Process the *host* and *port* CLI options. - Return a URI. - """ - # Much of distributed depends on a well-known IP being assigned to - # each entity (Worker, Scheduler, etc.), so avoid "universal" addresses - # like '' which would listen on all registered IPs and interfaces. - scheme, loc = parse_address(host_arg or "") - - host, port = parse_host_port( - loc, port_arg if port_arg is not None else default_port - ) - - if port is None and port_arg is None: - port_arg = default_port - - if port and port_arg and port != port_arg: - raise ValueError( - "port number given twice in options: " - "host %r and port %r" % (host_arg, port_arg) - ) - if port is None and port_arg is not None: - port = port_arg - # Note `port = 0` means "choose a random port" - if port is None: - port = default_port - loc = unparse_host_port(host, port) - addr = unparse_address(scheme, loc) - - return addr diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 20ddb2c863f..3d79befe0f1 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -5,6 +5,7 @@ import dask from . import registry +from ..utils import get_ip_interface DEFAULT_SCHEME = dask.config.get("distributed.comm.default-scheme") @@ -172,3 +173,71 @@ def resolve_address(addr): scheme, loc = parse_address(addr) backend = registry.get_backend(scheme) return unparse_address(scheme, backend.resolve_address(loc)) + + +def uri_from_host_port(host_arg, port_arg, default_port): + """ + Process the *host* and *port* CLI options. + Return a URI. + """ + # Much of distributed depends on a well-known IP being assigned to + # each entity (Worker, Scheduler, etc.), so avoid "universal" addresses + # like '' which would listen on all registered IPs and interfaces. 
+ scheme, loc = parse_address(host_arg or "") + + host, port = parse_host_port( + loc, port_arg if port_arg is not None else default_port + ) + + if port is None and port_arg is None: + port_arg = default_port + + if port and port_arg and port != port_arg: + raise ValueError( + "port number given twice in options: " + "host %r and port %r" % (host_arg, port_arg) + ) + if port is None and port_arg is not None: + port = port_arg + # Note `port = 0` means "choose a random port" + if port is None: + port = default_port + loc = unparse_host_port(host, port) + addr = unparse_address(scheme, loc) + + return addr + + +def address_from_user_args( + host=None, port=None, interface=None, protocol=None, peer=None, security=None +): + """ Get an address to listen on from common user provided arguments """ + if security and security.require_encryption and not protocol: + protocol = "tls" + + if protocol and protocol.rstrip("://") == "inplace": + if host or port or interface: + raise ValueError( + "Can not specify inproc protocol and host or port or interface" + ) + else: + return "inproc://" + + if interface: + if host: + raise ValueError("Can not specify both interface and host", interface, host) + else: + host = get_ip_interface(interface) + + if protocol and host and "://" not in host: + host = protocol.rstrip("://") + "://" + host + + if host or port: + addr = uri_from_host_port(host, port, 0) + else: + addr = "" + + if protocol and "://" not in addr: + addr = protocol.rstrip("://") + "://" + addr + + return addr diff --git a/distributed/nanny.py b/distributed/nanny.py index 60e83e86da7..ef0e0a38f0e 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -16,6 +16,7 @@ from tornado.locks import Event from .comm import get_address_host, get_local_address_for, unparse_host_port +from .comm.addressing import address_from_user_args from .core import rpc, RPCClosed, CommClosedError, coerce_to_address from .metrics import time from .node import ServerNode @@ -69,6 +70,10 @@ def __init__( listen_address=None, worker_class=None, env=None, + interface=None, + host=None, + port=None, + protocol=None, **worker_kwargs ): @@ -135,6 +140,14 @@ def __init__( pc = PeriodicCallback(self.memory_monitor, 100, io_loop=self.loop) self.periodic_callbacks["memory"] = pc + self._start_address = address_from_user_args( + host=host, + port=port, + interface=interface, + protocol=protocol, + security=security, + ) + self._listen_address = listen_address self.status = "init" @@ -175,6 +188,7 @@ def worker_dir(self): @gen.coroutine def _start(self, addr_or_port=0): """ Start nanny, start local process, start watching """ + addr_or_port = addr_or_port or self._start_address # XXX Factor this out if not addr_or_port: @@ -419,6 +433,7 @@ def start(self): self.process = AsyncProcess( target=self._run, + name="Dask Worker process (from Nanny)", kwargs=dict( worker_args=self.worker_args, worker_kwargs=self.worker_kwargs, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8442b6ddcea..8ba4cedf468 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -35,6 +35,7 @@ get_address_host, unparse_host_port, ) +from .comm.addressing import address_from_user_args from .compatibility import finalize, unicode, Mapping, Set from .core import rpc, connect, send_recv, clean_exception, CommClosedError from . 
import profile @@ -824,9 +825,12 @@ def __init__( security=None, worker_ttl=None, idle_timeout=None, + interface=None, + host=None, + port=8786, + protocol=None, **kwargs ): - self._setup_logging() # Attributes @@ -1056,6 +1060,14 @@ def __init__( connection_limit = get_fileno_limit() / 2 + self._start_address = address_from_user_args( + host=host, + port=port, + interface=interface, + protocol=protocol, + security=security, + ) + super(Scheduler, self).__init__( handlers=self.handlers, stream_handlers=merge(worker_handlers, client_handlers), @@ -1172,10 +1184,12 @@ def stop_services(self): for service in self.services.values(): service.stop() - def start(self, addr_or_port=8786, start_queues=True): + def start(self, addr_or_port=None, start_queues=True): """ Clear out old state and restart all running coroutines """ enable_gc_diagnosis() + addr_or_port = addr_or_port or self._start_address + self.clear_task_state() with ignoring(AttributeError): @@ -1234,6 +1248,15 @@ def del_scheduler_file(): return self.finished() + def __await__(self): + self.start() + + @gen.coroutine + def _(): + return self + + return _().__await__() + @gen.coroutine def finished(self): """ Wait until all coroutines have ceased """ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8582c2abc83..3515be9ebcb 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2824,8 +2824,7 @@ def test_diagnostic_nbytes(c, s, a, b): @gen_test() def test_worker_aliases(): - s = Scheduler(validate=True) - s.start(0) + s = yield Scheduler(validate=True, port=0) a = Worker(s.ip, s.port, name="alice") b = Worker(s.ip, s.port, name="bob") w = Worker(s.ip, s.port, name=3) @@ -3062,8 +3061,7 @@ def test_unrunnable_task_runs(c, s, a, b): def test_add_worker_after_tasks(c, s): futures = c.map(inc, range(10)) - n = Nanny(s.ip, s.port, ncores=2, loop=s.loop) - n.start(0) + n = yield Nanny(s.ip, s.port, ncores=2, loop=s.loop, port=0) result = yield c.gather(futures) @@ -3603,8 +3601,7 @@ def test_as_completed_next_batch(c): @gen_test() def test_status(): - s = Scheduler() - s.start(0) + s = yield Scheduler(port=0) c = yield Client((s.ip, s.port), asynchronous=True) assert c.status == "running" diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 9224ed69030..805b0e06ed0 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -533,8 +533,7 @@ def test_broadcast_nanny(s, a, b): @gen_test() def test_worker_name(): - s = Scheduler(validate=True) - s.start(0) + s = yield Scheduler(validate=True, port=0) w = yield Worker(s.ip, s.port, name="alice") assert s.workers[w.address].name == "alice" assert s.aliases["alice"] == w.address @@ -550,8 +549,7 @@ def test_worker_name(): @gen_test() def test_coerce_address(): with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - s = Scheduler(validate=True) - s.start(0) + s = yield Scheduler(validate=True, port=0) print("scheduler:", s.address, s.listen_address) a = Worker(s.ip, s.port, name="alice") b = Worker(s.ip, s.port, name=123) @@ -824,7 +822,7 @@ def test_file_descriptors(c, s): yield [n.close() for n in nannies] assert not s.rpc.open - assert not c.rpc.active + assert not c.rpc.active, list(c.rpc._created) assert not s.stream_comms start = time() @@ -1133,8 +1131,7 @@ def test_fifo_submission(c, s, w): @gen_test() def test_scheduler_file(): with tmpfile() as fn: - s = Scheduler(scheduler_file=fn) - s.start(0) + s = yield 
Scheduler(scheduler_file=fn, port=0) with open(fn) as f: data = json.load(f) assert data["address"] == s.address @@ -1536,3 +1533,13 @@ def test_close_workers(s, a, b): yield s.close(close_workers=True) assert a.status == "closed" assert b.status == "closed" + + +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_test() +def test_host_address(): + s = yield Scheduler(host="127.0.0.2") + assert "127.0.0.2" in s.address + yield s.close() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 07864ab4b64..8ca3c5d9682 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -304,8 +304,7 @@ def test_broadcast(s, a, b): @gen_test() def test_worker_with_port_zero(): - s = Scheduler() - s.start(8007) + s = yield Scheduler(port=8007) w = yield Worker(s.address) assert isinstance(w.port, int) assert w.port > 1024 @@ -1007,8 +1006,7 @@ def test_start_services(s): @gen_test() def test_scheduler_file(): with tmpfile() as fn: - s = Scheduler(scheduler_file=fn) - s.start(8009) + s = yield Scheduler(scheduler_file=fn, port=8009) w = yield Worker(scheduler_file=fn) assert set(s.workers) == {w.address} yield w.close() @@ -1384,3 +1382,18 @@ def __init__(self, x, y): assert w.data.x == 123 assert w.data.y == 456 yield w.close() + + +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster(ncores=[], client=True) +def test_host_address(c, s): + w = yield Worker(s.address, host="127.0.0.2") + assert "127.0.0.2" in w.address + yield w.close() + + n = yield Nanny(s.address, host="127.0.0.3") + assert "127.0.0.3" in n.address + assert "127.0.0.3" in n.worker_address + yield n.close() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index b0c0d2d48cc..7aaa5b1ed0d 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -11,7 +11,6 @@ import logging import logging.config import os -import psutil import re import shutil import signal @@ -40,7 +39,7 @@ from tornado.ioloop import IOLoop from .client import default_client, _global_clients, Client -from .compatibility import PY3, Empty, WINDOWS, PY2 +from .compatibility import PY3, Empty, WINDOWS from .comm import Comm from .comm.utils import offload from .config import initialize_logging @@ -156,10 +155,7 @@ def start(): _cleanup_dangling() - if PY2: # no forkserver, so no extra procs - for child in psutil.Process().children(recursive=True): - with ignoring(psutil.NoSuchProcess): - child.terminate() + assert_no_leaked_processes() _global_clients.clear() @@ -482,8 +478,8 @@ def run_scheduler(q, nputs, **kwargs): # On Python 2.7 and Unix, fork() is used to spawn child processes, # so avoid inheriting the parent's IO loop. 
with pristine_loop() as loop: - scheduler = Scheduler(validate=True, **kwargs) - done = scheduler.start("127.0.0.1") + scheduler = Scheduler(validate=True, host="127.0.0.1", **kwargs) + done = scheduler.start() for i in range(nputs): q.put(scheduler.address) @@ -501,7 +497,7 @@ def run_worker(q, scheduler_q, **kwargs): with pristine_loop() as loop: scheduler_addr = scheduler_q.get() worker = Worker(scheduler_addr, validate=True, **kwargs) - loop.run_sync(lambda: worker._start(0)) + loop.run_sync(lambda: worker._start()) q.put(worker.address) try: @@ -521,7 +517,7 @@ def run_nanny(q, scheduler_q, **kwargs): with pristine_loop() as loop: scheduler_addr = scheduler_q.get() worker = Nanny(scheduler_addr, validate=True, **kwargs) - loop.run_sync(lambda: worker._start(0)) + loop.run_sync(lambda: worker._start()) q.put(worker.address) try: loop.start() @@ -657,6 +653,7 @@ def cluster( # Launch scheduler scheduler = mp_context.Process( + name="Dask cluster test: Scheduler", target=run_scheduler, args=(scheduler_q, nworkers + 1), kwargs=scheduler_kwargs, @@ -675,7 +672,10 @@ def cluster( worker_kwargs, ) proc = mp_context.Process( - target=_run_worker, args=(q, scheduler_q), kwargs=kwargs + name="Dask cluster test: Worker", + target=_run_worker, + args=(q, scheduler_q), + kwargs=kwargs, ) ws.add(proc) workers.append({"proc": proc, "queue": q, "dir": fn}) @@ -774,6 +774,16 @@ def cluster( print("Unclosed Comms", L) # raise ValueError("Unclosed Comms", L) + assert_no_leaked_processes() + + +def assert_no_leaked_processes(): + for i in range(20): + if mp_context.active_children(): + sleep(0.1) + else: + assert not mp_context.active_children() + @gen.coroutine def disconnect(addr, timeout=3, rpc_kwargs=None): @@ -854,6 +864,7 @@ def start_cluster( security=security, loop=loop, validate=True, + host=ncore[0], **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs) ) for i, ncore in enumerate(ncores) @@ -861,7 +872,7 @@ def start_cluster( # for w in workers: # w.rpc = workers[0].rpc - yield [w._start(ncore[0]) for ncore, w in zip(ncores, workers)] + yield workers start = time() while len(s.workers) < len(ncores) or any( @@ -1061,6 +1072,9 @@ def coro(): _cleanup_dangling() with ignoring(AttributeError): del thread_state.on_event_loop_thread + + assert_no_leaked_processes() + return result return test_func diff --git a/distributed/worker.py b/distributed/worker.py index 915784edd88..abbd2376c42 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -30,6 +30,7 @@ from .batched import BatchedSend from .comm import get_address_host, get_local_address_for, connect from .comm.utils import offload +from .comm.addressing import address_from_user_args from .compatibility import unicode, get_thread_identity, finalize, MutableMapping from .core import error_message, CommClosedError, send_recv, pingpong, coerce_to_address from .diskutils import WorkSpace @@ -295,6 +296,10 @@ def __init__( extensions=None, metrics=None, data=None, + interface=None, + host=None, + port=None, + protocol=None, low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), **kwargs ): @@ -406,7 +411,16 @@ def __init__( scheduler_addr = coerce_to_address(scheduler_ip) else: scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port)) - self._port = 0 + self.contact_address = contact_address + + self._start_address = address_from_user_args( + host=host, + port=port, + interface=interface, + protocol=protocol, + security=security, + ) + self.ncores = ncores or _ncores self.total_resources = 
resources or {} self.available_resources = (resources or {}).copy() @@ -417,7 +431,6 @@ def __init__( self.preload_argv = preload_argv if self.preload_argv is None: self.preload_argv = dask.config.get("distributed.worker.preload-argv") - self.contact_address = contact_address self.memory_monitor_interval = parse_timedelta( memory_monitor_interval, default="ms" ) @@ -888,6 +901,7 @@ def start_services(self, default_listen_ip): @gen.coroutine def _start(self, addr_or_port=0): assert self.status is None + addr_or_port = addr_or_port or self._start_address enable_gc_diagnosis() thread_state.on_event_loop_thread = True From ffe08384d4eb98c7cd8e9891943b662c8301ea32 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 9 May 2019 02:55:52 +0200 Subject: [PATCH 0271/1550] Add memory and disk aliases to Worker.data (#2670) --- distributed/tests/test_worker.py | 19 ++++++++++++------- distributed/worker.py | 13 ++++++++++++- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 8ca3c5d9682..0d8169fd6e5 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -381,17 +381,22 @@ def test_spill_to_disk(c, s): yield wait(y) assert set(w.data) == {x.key, y.key} - assert set(w.data.fast) == {x.key, y.key} + assert set(w.data.memory) == {x.key, y.key} + assert set(w.data.fast) == set(w.data.memory) z = c.submit(np.random.randint, 0, 255, size=500, dtype="u1", key="z") yield wait(z) assert set(w.data) == {x.key, y.key, z.key} - assert set(w.data.fast) == {y.key, z.key} - assert set(w.data.slow) == {x.key} or set(w.data.slow) == {x.key, y.key} + assert set(w.data.memory) == {y.key, z.key} + assert set(w.data.disk) == {x.key} or set(w.data.slow) == {x.key, y.key} + assert set(w.data.fast) == set(w.data.memory) + assert set(w.data.slow) == set(w.data.disk) yield x - assert set(w.data.fast) == {x.key, z.key} - assert set(w.data.slow) == {y.key} or set(w.data.slow) == {x.key, y.key} + assert set(w.data.memory) == {x.key, z.key} + assert set(w.data.disk) == {y.key} or set(w.data.slow) == {x.key, y.key} + assert set(w.data.fast) == set(w.data.memory) + assert set(w.data.slow) == set(w.data.disk) yield w.close() @@ -460,7 +465,7 @@ def test_spill_by_default(c, s, w): x = da.ones(int(10e6 * 0.7), chunks=1e6, dtype="u1") y = c.persist(x) yield wait(y) - assert len(w.data.slow) # something is on disk + assert len(w.data.disk) # something is on disk del x, y @@ -1069,7 +1074,7 @@ def f(n): futures = c.map(f, [100e6] * 8, pure=False) start = time() - while not a.data.slow: + while not a.data.disk: yield gen.sleep(0.1) assert time() < start + 5 diff --git a/distributed/worker.py b/distributed/worker.py index abbd2376c42..a1846d85539 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -156,7 +156,16 @@ class Worker(ServerNode): that we want to collect from others. * **data:** ``{key: object}``: - Dictionary mapping keys to actual values + Prefer using the **host** attribute instead of this, unless + memory_limit and at least one of memory_target_fraction or + memory_spill_fraction values are defined, in that case, this attribute + is a zict.Buffer, from which information on LRU cache can be queried. + * **data.memory:** ``{key: object}``: + Dictionary mapping keys to actual values stored in memory. Only + available if condition for **data** being a zict.Buffer is met. + * **data.disk:** ``{key: object}``: + Dictionary mapping keys to actual values stored on disk. 
Only + available if condition for **data** being a zict.Buffer is met. * **task_state**: ``{key: string}``: The state of all tasks that the scheduler has asked us to compute. Valid states include waiting, constrained, executing, memory, erred @@ -498,6 +507,8 @@ def __init__( ) target = int(float(self.memory_limit) * self.memory_target_fraction) self.data = Buffer({}, storage, target, weight) + self.data.memory = self.data.fast + self.data.disk = self.data.slow else: self.data = dict() From 0be60cccfb6b8739b119f7126eef94dd3deff609 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 9 May 2019 08:01:35 -0500 Subject: [PATCH 0272/1550] Use config accessor method for "scheduler-address" (#2676) --- distributed/cli/dask_worker.py | 8 ++++++-- distributed/cli/tests/test_dask_worker.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 6315939005d..1448395d109 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -8,8 +8,8 @@ import warnings import click +import dask from distributed import Nanny, Worker -from distributed.config import config from distributed.utils import parse_timedelta from distributed.worker import _ncores from distributed.security import Security @@ -318,7 +318,11 @@ def del_pid_file(): kwargs["service_ports"] = {"nanny": nanny_port} t = Worker - if not scheduler and not scheduler_file and "scheduler-address" not in config: + if ( + not scheduler + and not scheduler_file + and dask.config.get("scheduler-address", None) is None + ): raise ValueError( "Need to provide scheduler address like\n" "dask-worker SCHEDULER_ADDRESS:8786" diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 2fa3779d9b4..72f8327375a 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -6,6 +6,7 @@ import requests import sys +import os from time import sleep from distributed import Client @@ -141,6 +142,17 @@ def test_scheduler_file(loop, nanny): assert time() < start + 10 +def test_scheduler_address_env(loop, monkeypatch): + monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", "tcp://127.0.0.1:8786") + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-worker", "--no-bokeh"]): + with Client(os.environ["DASK_SCHEDULER_ADDRESS"], loop=loop) as c: + start = time() + while not c.scheduler_info()["workers"]: + sleep(0.1) + assert time() < start + 10 + + def test_nprocs_requires_nanny(loop): with popen(["dask-scheduler", "--no-bokeh"]) as sched: with popen( From 22c733ef852af72bb281af8f1534e6a942c96f8b Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 9 May 2019 14:44:43 -0500 Subject: [PATCH 0273/1550] Fix pytest.config deprecation warning (#2677) --- conftest.py | 16 +++++++++--- distributed/cli/tests/test_dask_worker.py | 4 +-- distributed/protocol/tests/test_numpy.py | 4 +-- distributed/protocol/tests/test_protocol.py | 3 +-- distributed/tests/test_batched.py | 6 ++--- distributed/tests/test_client.py | 27 ++++++++++----------- distributed/tests/test_core.py | 3 +-- distributed/tests/test_diskutils.py | 4 +-- distributed/tests/test_failed_workers.py | 3 +-- distributed/tests/test_nanny.py | 8 +++--- distributed/tests/test_queues.py | 4 +-- distributed/tests/test_scheduler.py | 9 +++---- distributed/tests/test_stress.py | 5 ++-- distributed/tests/test_variable.py | 4 +-- distributed/tests/test_worker.py | 9 +++---- distributed/utils_test.py | 11 
--------- 16 files changed, 56 insertions(+), 64 deletions(-) diff --git a/conftest.py b/conftest.py index cba68bddec1..b5db36f59d8 100644 --- a/conftest.py +++ b/conftest.py @@ -1,12 +1,11 @@ # https://pytest.org/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option -import os import pytest # Uncomment to enable more logging and checks # (https://docs.python.org/3/library/asyncio-dev.html) # Note this makes things slower and might consume much memory. -#os.environ["PYTHONASYNCIODEBUG"] = "1" +# os.environ["PYTHONASYNCIODEBUG"] = "1" try: import faulthandler @@ -19,4 +18,15 @@ def pytest_addoption(parser): parser.addoption("--runslow", action="store_true", help="run slow tests") -pytest_plugins = ['distributed.pytest_resourceleaks'] + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) + + +pytest_plugins = ["distributed.pytest_resourceleaks"] diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 72f8327375a..eec038ba9d6 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -12,7 +12,7 @@ from distributed import Client from distributed.metrics import time from distributed.utils import sync, tmpfile -from distributed.utils_test import popen, slow, terminate_process, wait_for_port +from distributed.utils_test import popen, terminate_process, wait_for_port from distributed.utils_test import loop # noqa: F401 @@ -65,7 +65,7 @@ def test_no_nanny(loop): assert any(b"Registered" in worker.stderr.readline() for i in range(15)) -@slow +@pytest.mark.slow @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_no_reconnect(nanny, loop): with popen(["dask-scheduler", "--no-bokeh"]) as sched: diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 849c2964fd6..ede0eded3cf 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -18,7 +18,7 @@ ) from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE from distributed.utils import tmpfile, nbytes -from distributed.utils_test import slow, gen_cluster +from distributed.utils_test import gen_cluster from distributed.protocol.numpy import itemsize from distributed.protocol.compression import maybe_compress @@ -152,7 +152,7 @@ def test_memmap(): np.testing.assert_equal(x, y) -@slow +@pytest.mark.slow def test_dumps_serialize_numpy_large(): psutil = pytest.importorskip("psutil") if psutil.virtual_memory().total < 2e9: diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index f0dc1dc6c2f..2415e01b5f1 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -9,7 +9,6 @@ from distributed.protocol.compression import compressions from distributed.protocol.serialize import Serialize, Serialized, serialize, deserialize from distributed.utils import nbytes -from distributed.utils_test import slow def test_protocol(): @@ -110,7 +109,7 @@ def test_large_bytes(): assert loads(frames, deserialize=False) == msg -@slow +@pytest.mark.slow def test_large_messages(): np = pytest.importorskip("numpy") psutil = pytest.importorskip("psutil") diff --git a/distributed/tests/test_batched.py 
b/distributed/tests/test_batched.py index 2f22134f7ae..23d8e677774 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -10,7 +10,7 @@ from distributed.core import listen, connect, CommClosedError from distributed.metrics import time from distributed.utils import All -from distributed.utils_test import gen_test, slow, captured_logger +from distributed.utils_test import gen_test, captured_logger from distributed.protocol import to_serialize @@ -158,7 +158,7 @@ def test_close_twice(): yield b.close() -@slow +@pytest.mark.slow @gen_test(timeout=50) def test_stress(): with echo_server() as e: @@ -231,7 +231,7 @@ def test_sending_traffic_jam(): yield run_traffic_jam(50, 300000) -@slow +@pytest.mark.slow @gen_test() def test_large_traffic_jam(): yield run_traffic_jam(500, 1500000) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 3515be9ebcb..b77efc9f51d 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -60,7 +60,6 @@ from distributed.utils import ignoring, mp_context, sync, tmp_text, tokey, tmpfile from distributed.utils_test import ( cluster, - slow, slowinc, slowadd, slowdec, @@ -771,7 +770,7 @@ def test_recompute_released_key(c, s, a, b): assert result1 == result2 -@slow +@pytest.mark.slow @gen_cluster(client=True) def test_long_tasks_dont_trigger_timeout(c, s, a, b): from time import sleep @@ -3473,7 +3472,7 @@ def test_get_foo_lost_keys(c, s, u, v, w): assert_dict_key_equal(d, {x.key: [], y.key: []}) -@slow +@pytest.mark.slow @gen_cluster(client=True, Worker=Nanny, check_new_threads=False) def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 1) @@ -3528,7 +3527,7 @@ def test_get_returns_early(c): assert x.key in c.futures -@slow +@pytest.mark.slow @gen_cluster(Worker=Nanny, client=True) def test_Client_clears_references_after_restart(c, s, a, b): x = c.submit(inc, 1) @@ -3644,7 +3643,7 @@ def test_scatter_raises_if_no_workers(c, s): yield c.scatter(1, timeout=0.5) -@slow +@pytest.mark.slow def test_reconnect(loop): w = Worker("127.0.0.1", 9393, loop=loop) w.start() @@ -3721,7 +3720,7 @@ def test_reconnect_timeout(c, s): assert "Failed to reconnect" in text -@slow +@pytest.mark.slow @pytest.mark.skipif( sys.platform.startswith("win"), reason="num_fds not supported on windows" ) @@ -4251,7 +4250,7 @@ def test_normalize_collection_dask_array(c, s, a, b): assert result1 == result2 -@slow +@pytest.mark.slow def test_normalize_collection_with_released_futures(c): da = pytest.importorskip("dask.array") @@ -4382,7 +4381,7 @@ def test_scatter_dict_workers(c, s, a, b): assert "a" in a.data or "a" in b.data -@slow +@pytest.mark.slow @gen_test() def test_client_timeout(): loop = IOLoop.current() @@ -4710,7 +4709,7 @@ def test_quiet_client_close(loop): ), line -@slow +@pytest.mark.slow def test_quiet_client_close_when_cluster_is_closed_before_client(loop): with captured_logger(logging.getLogger("tornado.application")) as logger: cluster = LocalCluster(loop=loop, n_workers=1) @@ -4755,7 +4754,7 @@ def f(_): del results -@slow +@pytest.mark.slow def test_threadsafe_get(c): da = pytest.importorskip("dask.array") x = da.arange(100, chunks=(10,)) @@ -4774,7 +4773,7 @@ def f(_): assert results and all(results) -@slow +@pytest.mark.slow def test_threadsafe_compute(c): da = pytest.importorskip("dask.array") x = da.arange(100, chunks=(10,)) @@ -4864,7 +4863,7 @@ def f(): assert result == 2 -@slow +@pytest.mark.slow @gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2, timeout=60) def 
test_secede_balances(c, s, a, b): count = threading.active_count() @@ -4970,7 +4969,7 @@ def test_dynamic_workloads_sync(c): _test_dynamic_workloads_sync(c, delay=0.02) -@slow +@pytest.mark.slow def test_dynamic_workloads_sync_random(c): _test_dynamic_workloads_sync(c, delay="random") @@ -5190,7 +5189,7 @@ def test_client_async_before_loop_starts(): client.close() -@slow +@pytest.mark.slow @gen_cluster( client=True, Worker=Nanny if PY3 else Worker, diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 2c8f63de6ff..4b3c0ac0ade 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -25,7 +25,6 @@ from distributed.protocol import to_serialize from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import ( - slow, gen_test, gen_cluster, has_ipv6, @@ -409,7 +408,7 @@ def check_large_packets(listen_arg): server.stop() -@slow +@pytest.mark.slow @gen_test() def test_large_packets_tcp(): yield check_large_packets("tcp://") diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index d5abf5c1dee..1bededf84ab 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -16,7 +16,7 @@ from distributed.diskutils import WorkSpace from distributed.metrics import time from distributed.utils import mp_context -from distributed.utils_test import captured_logger, slow +from distributed.utils_test import captured_logger def assert_directory_contents(dir_path, expected, trials=2): @@ -279,7 +279,7 @@ def test_workspace_concurrency(tmpdir): _test_workspace_concurrency(tmpdir, 2.0, 6) -@slow +@pytest.mark.slow def test_workspace_concurrency_intense(tmpdir): n_created, n_purged = _test_workspace_concurrency(tmpdir, 8.0, 16) assert n_created >= 100 diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 0772ea52c32..bae2e141ee2 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -19,7 +19,6 @@ gen_cluster, cluster, inc, - slow, div, slowinc, slowadd, @@ -406,7 +405,7 @@ def test_worker_who_has_clears_after_failed_connection(c, s, a, b): yield n.close() -@slow +@pytest.mark.slow @gen_cluster(client=True, timeout=60, Worker=Nanny, ncores=[("127.0.0.1", 1)]) def test_restart_timeout_on_long_running_task(c, s, a): with captured_logger("distributed.scheduler") as sio: diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 08cd49fb3c9..60de12dce4b 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -18,7 +18,7 @@ from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.utils import ignoring, tmpfile -from distributed.utils_test import gen_cluster, gen_test, slow, inc, captured_logger +from distributed.utils_test import gen_cluster, gen_test, inc, captured_logger @gen_cluster(ncores=[]) @@ -127,7 +127,7 @@ def test_run(s): yield n.close() -@slow +@pytest.mark.slow @gen_cluster( Worker=Nanny, ncores=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False} ) @@ -159,7 +159,7 @@ def test_nanny_alt_worker_class(c, s, w1, w2): assert w1.Worker is Something -@slow +@pytest.mark.slow @gen_cluster(client=False, ncores=[]) def test_nanny_death_timeout(s): yield s.close() @@ -318,7 +318,7 @@ def test_scheduler_address_config(c, s): yield nanny.close() -@slow +@pytest.mark.slow @gen_test() def test_wait_for_scheduler(): with captured_logger("distributed") as log: diff --git 
a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index e82b893989b..e40d3cd492c 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -9,7 +9,7 @@ from distributed import Client, Queue, Nanny, worker_client, wait from distributed.metrics import time -from distributed.utils_test import gen_cluster, inc, slow, div +from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -114,7 +114,7 @@ def f(x): @pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") -@slow +@pytest.mark.slow @gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): def f(i): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 805b0e06ed0..07caff09869 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -35,7 +35,6 @@ cluster, div, varying, - slow, ) from distributed.utils_test import loop, nodebug # noqa: F401 from dask.compatibility import apply @@ -775,7 +774,7 @@ def test_retire_workers_no_suspicious_tasks(c, s, a, b): assert all(ts.suspicious == 0 for ts in s.tasks.values()) -@slow +@pytest.mark.slow @pytest.mark.skipif( sys.platform.startswith("win"), reason="file descriptors not really a thing" ) @@ -831,7 +830,7 @@ def test_file_descriptors(c, s): assert time() < start + 3 -@slow +@pytest.mark.slow @nodebug @gen_cluster(client=True) def test_learn_occupancy(c, s, a, b): @@ -844,7 +843,7 @@ def test_learn_occupancy(c, s, a, b): assert 50 < s.workers[w.address].occupancy < 700 -@slow +@pytest.mark.slow @nodebug @gen_cluster(client=True) def test_learn_occupancy_2(c, s, a, b): @@ -1062,7 +1061,7 @@ def test_close_worker(c, s, a, b): assert len(s.workers) == 1 -@slow +@pytest.mark.slow @gen_cluster(client=True, Worker=Nanny, timeout=20) def test_close_nanny(c, s, a, b): assert len(s.workers) == 2 diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index f145a11b053..8c37f5a82fb 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -21,7 +21,6 @@ inc, slowinc, slowadd, - slow, slowsum, bump_rlimit, ) @@ -198,7 +197,7 @@ def vsum(*args): @pytest.mark.avoid_travis -@slow +@pytest.mark.slow @gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 80, timeout=1000) def test_stress_communication(c, s, *workers): s.validate = False # very slow otherwise @@ -244,7 +243,7 @@ def test_stress_steal(c, s, *workers): break -@slow +@pytest.mark.slow @gen_cluster(ncores=[("127.0.0.1", 1)] * 10, client=True, timeout=120) def test_close_connections(c, s, *workers): da = pytest.importorskip("dask.array") diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 5ae94d037c5..4d8851668f9 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -9,7 +9,7 @@ from distributed import Client, Variable, worker_client, Nanny, wait from distributed.metrics import time -from distributed.utils_test import gen_cluster, inc, slow, div +from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -147,7 +147,7 @@ def test_timeout_get(c, s, a, b): @pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") -@slow +@pytest.mark.slow @gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): 
NITERS = 50 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0d8169fd6e5..5fad86b2665 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -33,7 +33,6 @@ gen_cluster, div, dec, - slow, slowinc, gen_test, captured_logger, @@ -147,7 +146,7 @@ def reset(self): assert tuple(results) == (3, 7) -@slow +@pytest.mark.slow @gen_cluster() def dont_test_delete_data_with_missing_worker(c, a, b): bad = "127.0.0.1:9001" # this worker doesn't exist @@ -312,7 +311,7 @@ def test_worker_with_port_zero(): yield w.close() -@slow +@pytest.mark.slow def test_worker_waits_for_center_to_come_up(loop): @gen.coroutine def f(): @@ -726,7 +725,7 @@ def test_hold_onto_dependents(c, s, a, b): assert x.key in b.data -@slow +@pytest.mark.slow @gen_cluster(client=False, ncores=[]) def test_worker_death_timeout(s): with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): @@ -1235,7 +1234,7 @@ def test_scheduler_address_config(c, s): yield worker.close() -@slow +@pytest.mark.slow @gen_cluster(client=True) def test_wait_for_outgoing(c, s, a, b): np = pytest.importorskip("numpy") diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 7aaa5b1ed0d..0a1cf447cfd 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -804,17 +804,6 @@ def disconnect_all(addresses, timeout=3, rpc_kwargs=None): yield [disconnect(addr, timeout, rpc_kwargs) for addr in addresses] -def slow(func): - try: - if not pytest.config.getoption("--runslow"): - func = pytest.mark.skip("need --runslow option to run")(func) - except AttributeError: - # AttributeError: module 'pytest' has no attribute 'config' - pass - - return nodebug(func) - - def gen_test(timeout=10): """ Coroutine test From 9ea2dc3fdf0484339d13be3fb4485a782c6f4696 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 9 May 2019 14:53:57 -0500 Subject: [PATCH 0274/1550] Move dashboard_address logic into Scheduler/Worker (#2678) This removes repetitive logic from the LocalCluster and dask-scheduler/dask-worker CLI and moves it into the classes. This also makes it easier to make other Cluster objects without depending on LocalCluster * fix test_file_descriptors --- distributed/cli/dask_scheduler.py | 18 ++---------------- distributed/cli/dask_worker.py | 16 +++------------- distributed/deploy/local.py | 25 ++++++++++--------------- distributed/scheduler.py | 13 +++++++++++++ distributed/tests/test_scheduler.py | 2 +- distributed/worker.py | 14 ++++++++++++++ 6 files changed, 43 insertions(+), 45 deletions(-) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 57a7168a3a2..3668be684d0 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -188,29 +188,15 @@ def del_pid_file(): loop = IOLoop.current() logger.info("-" * 47) - services = {} - if _bokeh: - try: - from distributed.bokeh.scheduler import BokehScheduler - - services[("bokeh", dashboard_address)] = ( - BokehScheduler, - {"prefix": bokeh_prefix}, - ) - except ImportError as error: - if str(error).startswith("No module named"): - logger.info("Web dashboard not loaded. 
Unable to import bokeh") - else: - logger.info("Unable to import bokeh: %s" % str(error)) - scheduler = Scheduler( loop=loop, - services=services, scheduler_file=scheduler_file, security=sec, host=host, port=port, interface=interface, + dashboard_address=dashboard_address if _bokeh else None, + service_kwargs={"bokeh": {"prefix": bokeh_prefix}}, ) scheduler.start() if not preload: diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 1448395d109..e383095b382 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -163,7 +163,7 @@ default=None, help="Seconds to wait for a scheduler before closing", ) -@click.option("--bokeh-prefix", type=str, default=None, help="Prefix for the bokeh app") +@click.option("--bokeh-prefix", type=str, default="", help="Prefix for the bokeh app") @click.option( "--preload", type=str, @@ -288,18 +288,6 @@ def del_pid_file(): services = {} - if bokeh: - try: - from distributed.bokeh.worker import BokehWorker - except ImportError: - pass - else: - if bokeh_prefix: - result = (BokehWorker, {"prefix": bokeh_prefix}) - else: - result = BokehWorker - services[("bokeh", dashboard_address)] = result - if resources: resources = resources.replace(",", " ").split() resources = dict(pair.split("=") for pair in resources) @@ -350,6 +338,8 @@ def del_pid_file(): interface=interface, host=host, port=port, + dashboard_address=dashboard_address if bokeh else None, + service_kwargs={"bokhe": {"prefix": bokeh_prefix}}, name=name if nprocs == 1 or not name else name + "-" + str(i), **kwargs ) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 3431210a645..fb8793d0840 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -57,7 +57,7 @@ class LocalCluster(Cluster): Address on which to listen for the Bokeh diagnostics server like 'localhost:8787' or '0.0.0.0:8787'. Defaults to ':8787'. Set to ``None`` to disable the dashboard. - Use port 0 for a random port. + Use ':0' for a random port. diagnostics_port: int Deprecated. See dashboard_address. 
asynchronous: bool (False by default) @@ -112,6 +112,7 @@ def __init__( scheduler_port=0, silence_logs=logging.WARN, dashboard_address=":8787", + worker_dashboard_address=None, diagnostics_port=None, services=None, worker_services=None, @@ -179,29 +180,23 @@ def __init__( worker_kwargs["memory_limit"] = parse_memory_limit("auto", 1, n_workers) worker_kwargs.update( - {"ncores": threads_per_worker, "services": worker_services} + { + "ncores": threads_per_worker, + "services": worker_services, + "dashboard_address": worker_dashboard_address, + } ) self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop - if dashboard_address is not False and dashboard_address is not None: - try: - from distributed.bokeh.scheduler import BokehScheduler - from distributed.bokeh.worker import BokehWorker - except ImportError: - logger.debug("To start diagnostics web server please install Bokeh") - else: - services[("bokeh", dashboard_address)] = ( - BokehScheduler, - (service_kwargs or {}).get("bokeh", {}), - ) - worker_services[("bokeh", 0)] = BokehWorker - self.scheduler = Scheduler( loop=self.loop, services=services, + service_kwargs=service_kwargs, security=security, + interface=interface, + dashboard_address=dashboard_address, blocked_handlers=blocked_handlers, ) self.scheduler_port = scheduler_port diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8ba4cedf468..8500150204e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -818,6 +818,7 @@ def __init__( delete_interval="500ms", synchronize_worker_interval="60s", services=None, + service_kwargs=None, allowed_failures=ALLOWED_FAILURES, extensions=None, validate=False, @@ -829,6 +830,7 @@ def __init__( host=None, port=8786, protocol=None, + dashboard_address=None, **kwargs ): self._setup_logging() @@ -862,6 +864,17 @@ def __init__( self.connection_args = self.security.get_connection_args("scheduler") self.listen_args = self.security.get_listen_args("scheduler") + if dashboard_address is not None: + try: + from distributed.bokeh.scheduler import BokehScheduler + except ImportError: + logger.debug("To start diagnostics web server please install Bokeh") + else: + self.service_specs[("bokeh", dashboard_address)] = ( + BokehScheduler, + (service_kwargs or {}).get("bokeh", {}), + ) + # Communication state self.loop = loop or IOLoop.current() self.client_comms = dict() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 07caff09869..3cc6579ed29 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -821,7 +821,7 @@ def test_file_descriptors(c, s): yield [n.close() for n in nannies] assert not s.rpc.open - assert not c.rpc.active, list(c.rpc._created) + assert not any(occ for addr, occ in c.rpc.occupied.items() if occ != s.address) assert not s.stream_comms start = time() diff --git a/distributed/worker.py b/distributed/worker.py index a1846d85539..1b103fe144a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -290,6 +290,7 @@ def __init__( local_dir="dask-worker-space", services=None, service_ports=None, + service_kwargs=None, name=None, reconnect=True, memory_limit="auto", @@ -309,6 +310,7 @@ def __init__( host=None, port=None, protocol=None, + dashboard_address=None, low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), **kwargs ): @@ -535,6 +537,18 @@ def __init__( self.services = {} self.service_ports = service_ports or {} self.service_specs = services or {} 
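
As a hedged illustration of the new keyword (not part of the diff; the address below is made up and bokeh must be installed): because dashboard_address is now handled inside Scheduler itself, a bare Scheduler started from a coroutine builds its own Bokeh service, with no CLI or LocalCluster wiring.

from tornado import gen
from tornado.ioloop import IOLoop

from distributed import Scheduler


@gen.coroutine
def main():
    # ':0' in the dashboard address requests a random port, matching the
    # LocalCluster docstring change above.
    s = yield Scheduler(port=0, dashboard_address="127.0.0.1:0")
    print("dashboard served on port", s.services["bokeh"].port)
    yield s.close()


IOLoop.current().run_sync(main)
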
+ + if dashboard_address is not None: + try: + from distributed.bokeh.worker import BokehWorker + except ImportError: + logger.debug("To start diagnostics web server please install Bokeh") + else: + self.service_specs[("bokeh", dashboard_address)] = ( + BokehWorker, + (service_kwargs or {}).get("bokeh", {}), + ) + self.metrics = dict(metrics) if metrics else {} self.low_level_profiler = low_level_profiler From 94dd92ebc2345a326550a6b2e3f2de776727712f Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 10 May 2019 08:29:31 -0500 Subject: [PATCH 0275/1550] Fix uri_from_host_port import in dask-mpi (#2683) --- distributed/cli/dask_mpi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/cli/dask_mpi.py b/distributed/cli/dask_mpi.py index ef7dd0c59fa..398596508a3 100644 --- a/distributed/cli/dask_mpi.py +++ b/distributed/cli/dask_mpi.py @@ -8,7 +8,8 @@ from distributed import Scheduler, Nanny, Worker from distributed.bokeh.worker import BokehWorker -from distributed.cli.utils import check_python_3, uri_from_host_port +from distributed.cli.utils import check_python_3 +from distributed.comm.addressing import uri_from_host_port from distributed.utils import get_ip_interface From 14998926603310416eed91d0f6181bf73a6b1fb8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 10 May 2019 08:43:21 -0500 Subject: [PATCH 0276/1550] Consolidate logic around services (#2679) We move shared logic for services down to the Server class. Additionally we remove special casing of the nanny in service_ports, and instead tack on a nanny attribute to the Worker and WorkerState directly. --- .../bokeh/tests/test_scheduler_bokeh.py | 10 +-- distributed/cli/tests/test_dask_worker.py | 5 +- distributed/nanny.py | 2 +- distributed/node.py | 52 +++++++++++++++ distributed/scheduler.py | 63 +++++-------------- distributed/tests/test_client.py | 2 +- distributed/tests/test_nanny.py | 4 +- distributed/tests/test_scheduler.py | 12 ++++ distributed/worker.py | 43 +++---------- 9 files changed, 99 insertions(+), 94 deletions(-) diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index 380dff104e2..f3a57586c72 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -565,10 +565,12 @@ def test_GraphPlot_order(c, s, a, b): ) def test_profile_server(c, s, a, b): ptp = ProfileServer(s) - ptp.trigger_update() - yield gen.sleep(0.200) - ptp.trigger_update() - assert 2 < len(ptp.ts_source.data["time"]) < 20 + start = time() + yield gen.sleep(0.100) + while len(ptp.ts_source.data["time"]) < 2: + yield gen.sleep(0.100) + ptp.trigger_update() + assert time() < start + 2 @gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index eec038ba9d6..c26c99f2350 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -40,7 +40,10 @@ def test_nanny_worker_ports(loop): else: assert time() - start < 5 sleep(0.1) - assert d["workers"]["tcp://127.0.0.1:9684"]["services"]["nanny"] == 5273 + assert ( + d["workers"]["tcp://127.0.0.1:9684"]["nanny"] + == "tcp://127.0.0.1:5273" + ) def test_memory_limit(loop): diff --git a/distributed/nanny.py b/distributed/nanny.py index ef0e0a38f0e..4bc0eeef6f9 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -259,7 +259,7 @@ def instantiate(self, comm=None): 
ncores=self.ncores, local_dir=self.local_dir, services=self.services, - service_ports={"nanny": self.port}, + nanny=self.address, name=self.name, memory_limit=self.memory_limit, reconnect=self.reconnect, diff --git a/distributed/node.py b/distributed/node.py index 8134546fa0b..ff95a621877 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -1,7 +1,10 @@ from __future__ import print_function, division, absolute_import +import warnings + from tornado.ioloop import IOLoop +from .compatibility import unicode from .core import Server, ConnectionPool from .versions import get_versions @@ -78,3 +81,52 @@ def __init__( def versions(self, comm=None, packages=None): return get_versions(packages=packages) + + def start_services(self, default_listen_ip): + if default_listen_ip == "0.0.0.0": + default_listen_ip = "" # for IPV6 + + for k, v in self.service_specs.items(): + listen_ip = None + if isinstance(k, tuple): + k, port = k + else: + port = 0 + + if isinstance(port, (str, unicode)): + port = port.split(":") + + if isinstance(port, (tuple, list)): + if len(port) == 2: + listen_ip, port = (port[0], int(port[1])) + elif len(port) == 1: + [listen_ip], port = port, 0 + else: + raise ValueError(port) + + if isinstance(v, tuple): + v, kwargs = v + else: + kwargs = {} + + try: + service = v(self, io_loop=self.loop, **kwargs) + service.listen( + (listen_ip if listen_ip is not None else default_listen_ip, port) + ) + self.services[k] = service + except Exception as e: + warnings.warn( + "\nCould not launch service '%s' on port %s. " % (k, port) + + "Got the following message:\n\n" + + str(e), + stacklevel=3, + ) + + def stop_services(self): + for service in self.services.values(): + service.stop() + + @property + def service_ports(self): + return {k: v.port for k, v in self.services.items()} diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8500150204e..87cc4fda8a4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -12,7 +12,6 @@ import pickle import random import six -import warnings import psutil import sortedcontainers @@ -190,6 +189,10 @@ class WorkerState(object): The current status of the worker, either ``'running'`` or ``'closed'`` + .. attribute:: nanny: str + + Address of the associated Nanny, if present + .. 
attribute:: last_seen: Number The last time we received a heartbeat from this worker, in local @@ -214,6 +217,7 @@ class WorkerState(object): "memory_limit", "metrics", "name", + "nanny", "nbytes", "ncores", "occupancy", @@ -235,6 +239,7 @@ def __init__( memory_limit=0, local_directory=None, services=None, + nanny=None, ): self.address = address self.pid = pid @@ -243,6 +248,7 @@ def __init__( self.memory_limit = memory_limit self.local_directory = local_directory self.services = services or {} + self.nanny = nanny self.status = "running" self.nbytes = 0 @@ -271,6 +277,7 @@ def clean(self): memory_limit=self.memory_limit, local_directory=self.local_directory, services=self.services, + nanny=self.nanny, ) ws.processing = {ts.key for ts in self.processing} return ws @@ -298,6 +305,7 @@ def identity(self): "last_seen": self.last_seen, "services": self.services, "metrics": self.metrics, + "nanny": self.nanny, } @@ -1157,46 +1165,6 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): else: return ws.host, port - def start_services(self, default_listen_ip): - if default_listen_ip == "0.0.0.0": - default_listen_ip = "" # for IPV6 - - for k, v in self.service_specs.items(): - listen_ip = None - if isinstance(k, tuple): - k, port = k - else: - port = 0 - - if isinstance(port, (str, unicode)): - port = port.split(":") - - if isinstance(port, (tuple, list)): - listen_ip, port = (port[0], int(port[1])) - - if isinstance(v, tuple): - v, kwargs = v - else: - kwargs = {} - - try: - service = v(self, io_loop=self.loop, **kwargs) - service.listen( - (listen_ip if listen_ip is not None else default_listen_ip, port) - ) - self.services[k] = service - except Exception as e: - warnings.warn( - "\nCould not launch service '%s' on port %s. " % (k, port) - + "Got the following message:\n\n" - + str(e), - stacklevel=3, - ) - - def stop_services(self): - for service in self.services.values(): - service.stop() - def start(self, addr_or_port=None, start_queues=True): """ Clear out old state and restart all running coroutines """ enable_gc_diagnosis() @@ -1347,7 +1315,7 @@ def close_worker(self, stream=None, worker=None, safe=None): logger.info("Closing worker %s", worker) with log_errors(): self.log_event(worker, {"action": "close-worker"}) - nanny_addr = self.get_worker_service_addr(worker, "nanny", protocol=True) + nanny_addr = self.workers[worker].nanny address = nanny_addr or worker self.worker_send(worker, {"op": "close", "report": False}) @@ -1434,6 +1402,7 @@ def add_worker( pid=0, services=None, local_directory=None, + nanny=None, ): """ Add a new worker to the cluster """ with log_errors(): @@ -1453,6 +1422,7 @@ def add_worker( name=name, local_directory=local_directory, services=services, + nanny=nanny, ) if name in self.aliases: @@ -2608,10 +2578,7 @@ def restart(self, client=None, timeout=3): keys=[ts.key for ts in cs.wants_what], client=cs.client_key ) - nannies = { - addr: self.get_worker_service_addr(addr, "nanny", protocol=True) - for addr in self.workers - } + nannies = {addr: ws.nanny for addr, ws in self.workers.items()} for addr in list(self.workers): try: @@ -2694,9 +2661,7 @@ def broadcast( # TODO replace with worker_list if nanny: - addresses = [ - self.get_worker_service_addr(w, "nanny", protocol=True) for w in workers - ] + addresses = [self.workers[w].nanny for w in workers] else: addresses = workers diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index b77efc9f51d..cf18e6b10e8 100644 --- a/distributed/tests/test_client.py +++ 
b/distributed/tests/test_client.py @@ -3479,7 +3479,7 @@ def test_bad_tasks_fail(c, s, a, b): with pytest.raises(KilledWorker) as info: yield f - assert info.value.last_worker.services["nanny"] in {a.port, b.port} + assert info.value.last_worker.nanny in {a.address, b.address} def test_get_processing_sync(c, s, a, b): diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 60de12dce4b..bf9f91b6371 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -28,7 +28,7 @@ def test_nanny(s): with rpc(n.address) as nn: assert n.is_alive() assert s.ncores[n.worker_address] == 2 - assert s.workers[n.worker_address].services["nanny"] > 1024 + assert s.workers[n.worker_address].nanny == n.address yield nn.kill() assert not n.is_alive() @@ -43,7 +43,7 @@ def test_nanny(s): yield nn.instantiate() assert n.is_alive() assert s.ncores[n.worker_address] == 2 - assert s.workers[n.worker_address].services["nanny"] > 1024 + assert s.workers[n.worker_address].nanny == n.address yield nn.terminate() assert not n.is_alive() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 3cc6579ed29..73d10ab8a55 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1542,3 +1542,15 @@ def test_host_address(): s = yield Scheduler(host="127.0.0.2") assert "127.0.0.2" in s.address yield s.close() + + +@gen_test() +def test_dashboard_address(): + pytest.importorskip("bokeh") + s = yield Scheduler(dashboard_address="127.0.0.1:8901") + assert s.services["bokeh"].port == 8901 + yield s.close() + + s = yield Scheduler(dashboard_address="127.0.0.1") + assert s.services["bokeh"].port + yield s.close() diff --git a/distributed/worker.py b/distributed/worker.py index 1b103fe144a..64eca0cc770 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -261,6 +261,8 @@ class Worker(ServerNode): executor: concurrent.futures.Executor resources: dict Resources that this worker has like ``{'GPU': 2}`` + nanny: str + Address on which to contact nanny, if it exists Examples -------- @@ -311,6 +313,7 @@ def __init__( port=None, protocol=None, dashboard_address=None, + nanny=None, low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), **kwargs ): @@ -323,6 +326,7 @@ def __init__( self.who_has = dict() self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) + self.nanny = nanny self._lock = threading.Lock() self.data_needed = deque() # TODO: replace with heap? 
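
A small sketch of what the consolidation means in practice (illustrative, not part of the patch): the scheduler now reaches a worker's nanny through the plain WorkerState.nanny address rather than reconstructing it from a "nanny" entry in service_ports.

from tornado import gen
from tornado.ioloop import IOLoop

from distributed import Nanny, Scheduler


@gen.coroutine
def main():
    s = yield Scheduler(port=0)
    n = yield Nanny(s.address, ncores=1)
    ws = s.workers[n.worker_address]
    # The full nanny address now lives directly on the WorkerState ...
    assert ws.nanny == n.address
    # ... so restart() and broadcast(nanny=True) no longer consult service_ports.
    yield n.close()
    yield s.close()


IOLoop.current().run_sync(main)
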
@@ -535,7 +539,6 @@ def __init__( sys.path.insert(0, self.local_dir) self.services = {} - self.service_ports = service_ports or {} self.service_specs = services or {} if dashboard_address is not None: @@ -731,6 +734,7 @@ def _register_with_scheduler(self): memory_limit=self.memory_limit, local_directory=self.local_dir, services=self.service_ports, + nanny=self.nanny, pid=os.getpid(), metrics=self.get_metrics(), ), @@ -895,34 +899,6 @@ def get_logs(self, comm=None, n=None): # Lifecycle # ############# - def start_services(self, default_listen_ip): - if default_listen_ip == "0.0.0.0": - default_listen_ip = "" # for IPV6 - - for k, v in self.service_specs.items(): - listen_ip = None - if isinstance(k, tuple): - k, port = k - else: - port = 0 - - if isinstance(port, (str, unicode)): - port = port.split(":") - - if isinstance(port, (tuple, list)): - listen_ip, port = (port[0], int(port[1])) - - if isinstance(v, tuple): - v, kwargs = v - else: - kwargs = {} - - self.services[k] = v(self, io_loop=self.loop, **kwargs) - self.services[k].listen( - (listen_ip if listen_ip is not None else default_listen_ip, port) - ) - self.service_ports[k] = self.services[k].port - @gen.coroutine def _start(self, addr_or_port=0): assert self.status is None @@ -1047,13 +1023,8 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): if self.batched_stream: self.batched_stream.close() - if nanny and "nanny" in self.service_ports: - nanny_address = "%s%s:%d" % ( - self.listener.prefix, - self.ip, - self.service_ports["nanny"], - ) - with self.rpc(nanny_address) as r: + if nanny and self.nanny: + with self.rpc(self.nanny) as r: yield r.terminate() self.rpc.close() From 3e87f34f6fe110188e03c18c14bc27893ad0da59 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Fri, 10 May 2019 16:20:40 +0000 Subject: [PATCH 0277/1550] Add CONTRIBUTING.md (#2680) --- CONTRIBUTING.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 CONTRIBUTING.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..cd35ad7c572 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,29 @@ +For more information, see https://docs.dask.org/en/latest/develop.html#contributing-to-code + + +## Style +Distributed conforms with the [flake8] and [black] styles. To make sure your +code conforms with these styles, run + +``` shell +$ pip install black flake8 +$ cd path/to/distributed +$ black distributed +$ flake8 distributed +``` + +[flake8]:http://flake8.pycqa.org/en/latest/ +[black]:https://github.com/python/black + +## Docstrings + +Dask Distributed roughly follows the [numpydoc] standard. More information is +available at https://docs.dask.org/en/latest/develop.html#docstrings. + +[numpydoc]:https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt + +## Tests + +Dask employs extensive unit tests to ensure correctness of code both for today +and for the future. Test coverage is expected for all code contributions. More +detail is at https://docs.dask.org/en/latest/develop.html#test From a8fa4c19f6da43d8605bb8fcd8d83b9ff8b214ee Mon Sep 17 00:00:00 2001 From: Muammar El Khatib Date: Fri, 10 May 2019 09:28:45 -0700 Subject: [PATCH 0278/1550] Catch RuntimeError to avoid serialization fail when using pytorch (#2619) * Fix `Failed to Serialize` error with pytorch tensors. - When a tensor requires_grad then we have to t.detach().numpy() otherwise a .numpy() is used. This fixes the failed to serialized problem present in latest distributed version. 
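
The behavior behind this fix, as a standalone sketch (requires torch and numpy; not part of the patch): calling .numpy() on a tensor that tracks gradients raises a RuntimeError, so the serializer has to detach first.

import numpy as np
import torch

t = torch.ones(3, requires_grad=True)

arr = None
try:
    arr = t.numpy()            # RuntimeError: can't call numpy() on a tensor that requires grad
except RuntimeError:
    arr = t.detach().numpy()   # detach() gives a grad-free view that converts cleanly

assert np.allclose(arr, 1.0)
assert t.requires_grad         # detaching for serialization leaves the original flag untouched
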
- Improved test_grad() test as suggested by @stsievert. - The whole PR is included in a single commit. * More improvements to test_torch - Verify that t.requires_grad is not modified by serialization. - Use `np.allclose()` instead of `==`. --- distributed/protocol/tests/test_torch.py | 18 +++++++++++++----- distributed/protocol/torch.py | 7 ++++++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py index 6cc8bb20986..efb5fa6610a 100644 --- a/distributed/protocol/tests/test_torch.py +++ b/distributed/protocol/tests/test_torch.py @@ -14,14 +14,22 @@ def test_tensor(): assert (x == t2.numpy()).all() -def test_grad(): +@pytest.mark.parametrize("requires_grad", [True, False]) +def test_grad(requires_grad): x = np.arange(10) - t = torch.Tensor(x) - t.grad = torch.zeros_like(t) + 1 + t = torch.tensor(x, dtype=torch.float, requires_grad=requires_grad) + + if requires_grad: + t.grad = torch.zeros_like(t) + 1 t2 = deserialize(*serialize(t)) - assert (t2.numpy() == x).all() - assert (t2.grad.numpy() == 1).all() + + assert t2.requires_grad is requires_grad + assert t.requires_grad is requires_grad + assert np.allclose(t2.detach().numpy(), x) + + if requires_grad: + assert np.allclose(t2.grad.numpy(), 1) def test_resnet(): diff --git a/distributed/protocol/torch.py b/distributed/protocol/torch.py index e69be68b0c1..3b4c6d19c8d 100644 --- a/distributed/protocol/torch.py +++ b/distributed/protocol/torch.py @@ -7,7 +7,12 @@ @dask_serialize.register(torch.Tensor) def serialize_torch_Tensor(t): requires_grad_ = t.requires_grad - header, frames = serialize(t.detach_().numpy()) + + if requires_grad_: + header, frames = serialize(t.detach().numpy()) + else: + header, frames = serialize(t.numpy()) + if t.grad is not None: grad_header, grad_frames = serialize(t.grad.numpy()) header["grad"] = {"header": grad_header, "start": len(frames)} From edc001441841c748258a6f875f414f671e2dcf2a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 10 May 2019 13:23:48 -0500 Subject: [PATCH 0279/1550] Switch from (ip, port) to address in tests (#2684) This helps resolve issues on Windows CI --- distributed/batched.py | 2 +- .../diagnostics/tests/test_progressbar.py | 33 ++++-------- distributed/tests/test_client.py | 52 +++++++++---------- distributed/tests/test_failed_workers.py | 10 ++-- distributed/tests/test_nanny.py | 6 +-- distributed/tests/test_publish.py | 16 +++--- distributed/tests/test_resources.py | 6 +-- distributed/tests/test_scheduler.py | 24 ++++----- distributed/tests/test_steal.py | 6 +-- distributed/tests/test_worker.py | 2 +- 10 files changed, 73 insertions(+), 84 deletions(-) diff --git a/distributed/batched.py b/distributed/batched.py index bc77cc7fda2..e17d7b1f1bd 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -26,7 +26,7 @@ class BatchedSend(object): Example ------- - >>> stream = yield connect(ip, port) + >>> stream = yield connect(address) >>> bstream = BatchedSend(interval='10 ms') >>> bstream.start(stream) >>> bstream.send('Hello,') diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index d5a01410f5e..ac21f1637bc 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -2,12 +2,10 @@ from time import sleep -from tornado import gen - from distributed import Scheduler, Worker from distributed.diagnostics.progressbar import TextProgressBar, progress from 
distributed.metrics import time -from distributed.utils_test import inc, div, gen_cluster +from distributed.utils_test import inc, div, gen_cluster, gen_test from distributed.utils_test import client, loop, cluster_fixture # noqa: F401 @@ -30,34 +28,25 @@ def test_text_progressbar(capsys, client): def test_TextProgressBar_error(c, s, a, b): x = c.submit(div, 1, 0) - progress = TextProgressBar( - [x.key], scheduler=(s.ip, s.port), start=False, interval=0.01 - ) + progress = TextProgressBar([x.key], scheduler=s.address, start=False, interval=0.01) yield progress.listen() assert progress.status == "error" assert progress.comm.closed() - progress = TextProgressBar( - [x.key], scheduler=(s.ip, s.port), start=False, interval=0.01 - ) + progress = TextProgressBar([x.key], scheduler=s.address, start=False, interval=0.01) yield progress.listen() assert progress.status == "error" assert progress.comm.closed() -def test_TextProgressBar_empty(loop, capsys): - @gen.coroutine +def test_TextProgressBar_empty(capsys): + @gen_test() def f(): - s = Scheduler(loop=loop) - done = s.start(0) - a = Worker(s.ip, s.port, loop=loop, ncores=1) - b = Worker(s.ip, s.port, loop=loop, ncores=1) - yield [a._start(0), b._start(0)] - - progress = TextProgressBar( - [], scheduler=(s.ip, s.port), start=False, interval=0.01 - ) + s = yield Scheduler(port=0) + a, b = yield [Worker(s.address, ncores=1), Worker(s.address, ncores=1)] + + progress = TextProgressBar([], scheduler=s.address, start=False, interval=0.01) yield progress.listen() assert progress.status == "finished" @@ -65,9 +54,9 @@ def f(): yield [a.close(), b.close()] s.close() - yield done + yield s.finished() - loop.run_sync(f) + f() def check_bar_completed(capsys, width=40): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index cf18e6b10e8..bce1066c0d2 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -471,7 +471,7 @@ def test_exceptions(c, s, a, b): @gen_cluster() def test_gc(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) x = c.submit(inc, 10) yield x @@ -1006,12 +1006,12 @@ def assert_list(x, z=[]): @gen_cluster() def test_two_consecutive_clients_share_results(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) x = c.submit(random.randint, 0, 1000, pure=True) xx = yield x - f = yield Client((s.ip, s.port), asynchronous=True) + f = yield Client(s.address, asynchronous=True) y = f.submit(random.randint, 0, 1000, pure=True) yy = yield y @@ -1680,8 +1680,8 @@ def test_upload_file_exception_sync(c): @pytest.mark.skip @gen_cluster() def test_multiple_clients(s, a, b): - a = yield Client((s.ip, s.port), asynchronous=True) - b = yield Client((s.ip, s.port), asynchronous=True) + a = yield Client(s.address, asynchronous=True) + b = yield Client(s.address, asynchronous=True) x = a.submit(inc, 1) y = b.submit(inc, 2) @@ -2102,8 +2102,8 @@ def test_waiting_data(c, s, a, b): @gen_cluster() def test_multi_client(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) assert set(s.client_comms) == {c.id, f.id} @@ -2170,9 +2170,9 @@ def test_cleanup_after_broken_client_connection(s, a, b): @gen_cluster() def test_multi_garbage_collection(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) + c = yield 
Client(s.address, asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + f = yield Client(s.address, asynchronous=True) x = c.submit(inc, 1) y = f.submit(inc, 2) @@ -2294,8 +2294,8 @@ def test__cancel_tuple_key(c, s, a, b): @gen_cluster() def test__cancel_multi_client(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) x = c.submit(slowinc, 1) y = f.submit(slowinc, 1) @@ -2824,12 +2824,12 @@ def test_diagnostic_nbytes(c, s, a, b): @gen_test() def test_worker_aliases(): s = yield Scheduler(validate=True, port=0) - a = Worker(s.ip, s.port, name="alice") - b = Worker(s.ip, s.port, name="bob") - w = Worker(s.ip, s.port, name=3) + a = Worker(s.address, name="alice") + b = Worker(s.address, name="bob") + w = Worker(s.address, name=3) yield [a, b, w] - c = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) L = c.map(inc, range(10), workers="alice") future = yield c.scatter(123, workers=3) @@ -2905,10 +2905,10 @@ def test_client_num_fds(loop): @gen_cluster() def test_startup_close_startup(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) yield c.close() - c = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) yield c.close() @@ -3043,7 +3043,7 @@ def test_unrunnable_task_runs(c, s, a, b): assert s.tasks[x.key] in s.unrunnable assert s.get_task_status(keys=[x.key]) == {x.key: "no-worker"} - w = yield Worker(s.ip, s.port, loop=s.loop) + w = yield Worker(s.address, loop=s.loop) start = time() while x.status != "finished": @@ -3060,7 +3060,7 @@ def test_unrunnable_task_runs(c, s, a, b): def test_add_worker_after_tasks(c, s): futures = c.map(inc, range(10)) - n = yield Nanny(s.ip, s.port, ncores=2, loop=s.loop, port=0) + n = yield Nanny(s.address, ncores=2, loop=s.loop, port=0) result = yield c.gather(futures) @@ -3602,7 +3602,7 @@ def test_as_completed_next_batch(c): def test_status(): s = yield Scheduler(port=0) - c = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) assert c.status == "running" x = c.submit(inc, 1) @@ -3782,8 +3782,8 @@ def start_worker(sleep, duration, repeat=1): @gen_cluster(client=False, timeout=None) def test_idempotence(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) # Submit x = c.submit(inc, 1) @@ -3989,8 +3989,8 @@ def test_scatter_compute_store_lose_processing(c, s, a, b): @gen_cluster(client=False) def test_serialize_future(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) future = c.submit(lambda: 1) result = yield future @@ -4008,8 +4008,8 @@ def test_serialize_future(s, a, b): @gen_cluster(client=False) def test_temp_client(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) with temp_default_client(c): assert default_client() is c diff --git a/distributed/tests/test_failed_workers.py 
b/distributed/tests/test_failed_workers.py index bae2e141ee2..dde92c6d24c 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -39,7 +39,7 @@ def test_submit_after_failed_worker_sync(loop): @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) def test_submit_after_failed_worker_async(c, s, a, b): - n = Nanny(s.ip, s.port, ncores=2, loop=s.loop) + n = Nanny(s.address, ncores=2, loop=s.loop) n.start(0) while len(s.workers) < 3: yield gen.sleep(0.1) @@ -267,8 +267,8 @@ def test_fast_kill(c, s, a, b): @gen_cluster(Worker=Nanny, timeout=60) def test_multiple_clients_restart(s, a, b): - e1 = yield Client((s.ip, s.port), asynchronous=True) - e2 = yield Client((s.ip, s.port), asynchronous=True) + e1 = yield Client(s.address, asynchronous=True) + e2 = yield Client(s.address, asynchronous=True) x = e1.submit(inc, 1) y = e2.submit(inc, 2) @@ -315,7 +315,7 @@ def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) def test_broken_worker_during_computation(c, s, a, b): s.allowed_failures = 100 - n = Nanny(s.ip, s.port, ncores=2, loop=s.loop) + n = Nanny(s.address, ncores=2, loop=s.loop) n.start(0) start = time() @@ -374,7 +374,7 @@ def test_restart_during_computation(c, s, a, b): @gen_cluster(client=True, timeout=60) def test_worker_who_has_clears_after_failed_connection(c, s, a, b): - n = Nanny(s.ip, s.port, ncores=2, loop=s.loop) + n = Nanny(s.address, ncores=2, loop=s.loop) n.start(0) start = time() diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index bf9f91b6371..4c18b5242a3 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -23,7 +23,7 @@ @gen_cluster(ncores=[]) def test_nanny(s): - n = yield Nanny(s.ip, s.port, ncores=2, loop=s.loop) + n = yield Nanny(s.address, ncores=2, loop=s.loop) with rpc(n.address) as nn: assert n.is_alive() @@ -70,7 +70,7 @@ def test_str(s, a, b): @gen_cluster(ncores=[], timeout=20, client=True) def test_nanny_process_failure(c, s): - n = yield Nanny(s.ip, s.port, ncores=2, loop=s.loop) + n = yield Nanny(s.address, ncores=2, loop=s.loop) first_dir = n.worker_dir assert os.path.exists(first_dir) @@ -117,7 +117,7 @@ def test_nanny_no_port(): @gen_cluster(ncores=[]) def test_run(s): pytest.importorskip("psutil") - n = yield Nanny(s.ip, s.port, ncores=2, loop=s.loop) + n = yield Nanny(s.address, ncores=2, loop=s.loop) with rpc(n.address) as nn: response = yield nn.run(function=dumps(lambda: 1)) diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index e4789589c48..7c0fd0db6d2 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -12,8 +12,8 @@ @gen_cluster(client=False) def test_publish_simple(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) data = yield c.scatter(range(3)) out = yield c.publish_dataset(data=data) @@ -38,8 +38,8 @@ def test_publish_simple(s, a, b): @gen_cluster(client=False) def test_publish_non_string_key(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) try: for name in [("a", "b"), 9.0, 8]: @@ -60,8 +60,8 @@ def test_publish_non_string_key(s, a, b): 
@gen_cluster(client=False) def test_publish_roundtrip(s, a, b): - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) data = yield c.scatter([0, 1, 2]) yield c.publish_dataset(data=data) @@ -156,8 +156,8 @@ def test_unpublish_multiple_datasets_sync(client): @gen_cluster(client=False) def test_publish_bag(s, a, b): db = pytest.importorskip("dask.bag") - c = yield Client((s.ip, s.port), asynchronous=True) - f = yield Client((s.ip, s.port), asynchronous=True) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) bag = db.from_sequence([0, 1, 2]) bagp = c.persist(bag) diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 429bbc2bb56..d7102ef5301 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -19,8 +19,8 @@ def test_resources(c, s): assert not s.worker_resources assert not s.resources - a = Worker(s.ip, s.port, loop=s.loop, resources={"GPU": 2}) - b = Worker(s.ip, s.port, loop=s.loop, resources={"GPU": 1, "DB": 1}) + a = Worker(s.address, loop=s.loop, resources={"GPU": 2}) + b = Worker(s.address, loop=s.loop, resources={"GPU": 1, "DB": 1}) yield [a, b] @@ -55,7 +55,7 @@ def test_resource_submit(c, s, a, b): assert s.get_task_status(keys=[z.key]) == {z.key: "no-worker"} - d = yield Worker(s.ip, s.port, loop=s.loop, resources={"C": 10}) + d = yield Worker(s.address, loop=s.loop, resources={"C": 10}) yield wait(z) assert z.key in d.data diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 73d10ab8a55..da750dd9196 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -253,7 +253,7 @@ def test_clear_events_client_removal(c, s, a, b): @gen_cluster() def test_add_worker(s, a, b): - w = Worker(s.ip, s.port, ncores=3) + w = Worker(s.address, ncores=3) w.data["x-5"] = 6 w.data["y"] = 1 yield w @@ -533,12 +533,12 @@ def test_broadcast_nanny(s, a, b): @gen_test() def test_worker_name(): s = yield Scheduler(validate=True, port=0) - w = yield Worker(s.ip, s.port, name="alice") + w = yield Worker(s.address, name="alice") assert s.workers[w.address].name == "alice" assert s.aliases["alice"] == w.address with pytest.raises(ValueError): - w2 = yield Worker(s.ip, s.port, name="alice") + w2 = yield Worker(s.address, name="alice") yield w2.close() yield w.close() @@ -550,8 +550,8 @@ def test_coerce_address(): with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): s = yield Scheduler(validate=True, port=0) print("scheduler:", s.address, s.listen_address) - a = Worker(s.ip, s.port, name="alice") - b = Worker(s.ip, s.port, name=123) + a = Worker(s.address, name="alice") + b = Worker(s.address, name=123) c = Worker("127.0.0.1", s.port, name="charlie") yield [a, b, c] @@ -594,7 +594,7 @@ def test_file_descriptors_dont_leak(s): proc = psutil.Process() before = proc.num_fds() - w = yield Worker(s.ip, s.port) + w = yield Worker(s.address) yield w.close() during = proc.num_fds() @@ -661,7 +661,7 @@ def test_scatter_no_workers(c, s): yield c.scatter(123, timeout=0.1) assert time() < start + 1.5 - w = Worker(s.ip, s.port, ncores=3) + w = Worker(s.address, ncores=3) yield [c.scatter(data={"y": 2}, timeout=5), w._start()] assert w.data["y"] == 2 @@ -670,7 +670,7 @@ def test_scatter_no_workers(c, s): @gen_cluster(ncores=[]) def 
test_scheduler_sees_memory_limits(s): - w = yield Worker(s.ip, s.port, ncores=3, memory_limit=12345) + w = yield Worker(s.address, ncores=3, memory_limit=12345) assert s.workers[w.address].memory_limit == 12345 yield w.close() @@ -788,7 +788,7 @@ def test_file_descriptors(c, s): num_fds_1 = proc.num_fds() N = 20 - nannies = yield [Nanny(s.ip, s.port, loop=s.loop) for i in range(N)] + nannies = yield [Nanny(s.address, loop=s.loop) for i in range(N)] while len(s.ncores) < N: yield gen.sleep(0.1) @@ -926,7 +926,7 @@ def test_worker_arrives_with_processing_data(c, s, a, b): while not any(w.processing for w in s.workers.values()): yield gen.sleep(0.01) - w = Worker(s.ip, s.port, ncores=1) + w = Worker(s.address, ncores=1) w.put_key_in_memory(y.key, 3) yield w @@ -977,7 +977,7 @@ def test_no_workers_to_memory(c, s): while not s.tasks: yield gen.sleep(0.01) - w = Worker(s.ip, s.port, ncores=1) + w = Worker(s.address, ncores=1) w.put_key_in_memory(y.key, 3) yield w @@ -1007,7 +1007,7 @@ def test_no_worker_to_memory_restrictions(c, s, a, b): while not s.tasks: yield gen.sleep(0.01) - w = Worker(s.ip, s.port, ncores=1, name="alice") + w = Worker(s.address, ncores=1, name="alice") w.put_key_in_memory(y.key, 3) yield w diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index cb56fc0f263..8edeb8e339c 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -172,7 +172,7 @@ def test_new_worker_steals(c, s, a): while len(a.task_state) < 10: yield gen.sleep(0.01) - b = yield Worker(s.ip, s.port, loop=s.loop, ncores=1, memory_limit=TOTAL_MEMORY) + b = yield Worker(s.address, loop=s.loop, ncores=1, memory_limit=TOTAL_MEMORY) result = yield total assert result == sum(map(inc, range(100))) @@ -277,7 +277,7 @@ def test_steal_resource_restrictions(c, s, a): yield gen.sleep(0.01) assert len(a.task_state) == 101 - b = yield Worker(s.ip, s.port, loop=s.loop, ncores=1, resources={"A": 4}) + b = yield Worker(s.address, loop=s.loop, ncores=1, resources={"A": 4}) start = time() while not b.task_state or len(a.task_state) == 101: @@ -536,7 +536,7 @@ def test_steal_twice(c, s, a, b): yield gen.sleep(0.01) # Army of new workers arrives to help - workers = yield [Worker(s.ip, s.port, loop=s.loop) for _ in range(20)] + workers = yield [Worker(s.address, loop=s.loop) for _ in range(20)] yield wait(futures) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 5fad86b2665..16a685fb09b 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -437,7 +437,7 @@ def f(dask_worker=None): @gen_cluster(client=True, ncores=[]) def test_Executor(c, s): with ThreadPoolExecutor(2) as e: - w = Worker(s.ip, s.port, executor=e) + w = Worker(s.address, executor=e) assert w.executor is e w = yield w From 2dc778d34ed68660885e8be8b87cf94c021458da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Fri, 10 May 2019 20:21:05 -0400 Subject: [PATCH 0280/1550] Cap worker's memory limit by the hard limit of the maximum resident memory (#2665) --- distributed/tests/test_worker.py | 24 +++++++++++++++++++++- distributed/worker.py | 34 +++++++++++++++++++------------- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 16a685fb09b..77df078fa89 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -25,7 +25,7 @@ from distributed.client import wait from distributed.scheduler import Scheduler from 
distributed.metrics import time -from distributed.worker import Worker, error_message, logger +from distributed.worker import Worker, error_message, logger, parse_memory_limit from distributed.utils import tmpfile, format_bytes from distributed.utils_test import ( inc, @@ -1401,3 +1401,25 @@ def test_host_address(c, s): assert "127.0.0.3" in n.address assert "127.0.0.3" in n.worker_address yield n.close() + + +def test_resource_limit(): + assert parse_memory_limit("250MiB", 1, total_cores=1) == 1024 * 1024 * 250 + + # get current limit + resource = pytest.importorskip("resource") + try: + hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1] + except OSError: + pytest.skip("resource could not get the RSS limit") + memory_limit = psutil.virtual_memory().total + if hard_limit > memory_limit or hard_limit < 0: + hard_limit = memory_limit + + # decrease memory limit by one byte + new_limit = hard_limit - 1 + try: + resource.setrlimit(resource.RLIMIT_RSS, (new_limit, new_limit)) + assert parse_memory_limit(hard_limit, 1, total_cores=1) == new_limit + except OSError: + pytest.skip("resource could not set the RSS limit") diff --git a/distributed/worker.py b/distributed/worker.py index 64eca0cc770..5302d94c6de 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -12,6 +12,7 @@ import sys import warnings import weakref +import psutil import dask from dask.core import istask @@ -76,15 +77,7 @@ no_value = "--no-value-sentinel--" -try: - import psutil - - TOTAL_MEMORY = psutil.virtual_memory().total -except ImportError: - logger.warning("Please install psutil to estimate worker memory use") - TOTAL_MEMORY = 8e9 - psutil = None - +TOTAL_MEMORY = psutil.virtual_memory().total IN_PLAY = ("waiting", "ready", "executing", "long-running") PENDING = ("waiting", "ready", "constrained") @@ -2933,17 +2926,30 @@ class Reschedule(Exception): def parse_memory_limit(memory_limit, ncores, total_cores=_ncores): if memory_limit is None: return None + if memory_limit == "auto": memory_limit = int(TOTAL_MEMORY * min(1, ncores / total_cores)) with ignoring(ValueError, TypeError): - x = float(memory_limit) - if isinstance(x, float) and x <= 1: - return int(x * TOTAL_MEMORY) + memory_limit = float(memory_limit) + if isinstance(memory_limit, float) and memory_limit <= 1: + memory_limit = int(memory_limit * TOTAL_MEMORY) if isinstance(memory_limit, (unicode, str)): - return parse_bytes(memory_limit) + memory_limit = parse_bytes(memory_limit) else: - return int(memory_limit) + memory_limit = int(memory_limit) + + # should be less than hard RSS limit + try: + import resource + + hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1] + if hard_limit > 0: + memory_limit = min(memory_limit, hard_limit) + except (ImportError, OSError): + pass + + return memory_limit @gen.coroutine From 1cbd324248db1973f1cfeb393e2d0dc7a5490d41 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 10 May 2019 21:42:44 -0500 Subject: [PATCH 0281/1550] Add WeakSet _instances attributes to all classes (#2673) These help us track leaking workers, schedulers, and clients --- distributed/client.py | 3 +++ distributed/nanny.py | 3 +++ distributed/scheduler.py | 3 +++ distributed/tests/test_client.py | 11 +++++++--- distributed/tests/test_scheduler.py | 4 +++- distributed/tests/test_worker.py | 9 ++------ distributed/utils_test.py | 24 ++++++++++----------- distributed/worker.py | 33 +++++++++++------------------ 8 files changed, 46 insertions(+), 44 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 
97728929f33..e31b808abcf 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -577,6 +577,8 @@ class resembles executors in ``concurrent.futures`` but also allows distributed.deploy.local.LocalCluster: """ + _instances = weakref.WeakSet() + def __init__( self, address=None, @@ -710,6 +712,7 @@ def __init__( ext(self) self.start(timeout=timeout) + Client._instances.add(self) from distributed.recreate_exceptions import ReplayExceptionClient diff --git a/distributed/nanny.py b/distributed/nanny.py index 4bc0eeef6f9..7be630cb4e9 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -9,6 +9,7 @@ import threading import uuid import warnings +import weakref import dask from tornado import gen @@ -42,6 +43,7 @@ class Nanny(ServerNode): them as necessary. """ + _instances = weakref.WeakSet() process = None status = None @@ -149,6 +151,7 @@ def __init__( ) self._listen_address = listen_address + Nanny._instances.add(self) self.status = "init" def __repr__(self): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 87cc4fda8a4..68a80ac664b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -12,6 +12,7 @@ import pickle import random import six +import weakref import psutil import sortedcontainers @@ -819,6 +820,7 @@ class Scheduler(ServerNode): """ default_port = 8786 + _instances = weakref.WeakSet() def __init__( self, @@ -1113,6 +1115,7 @@ def __init__( ext(self) setproctitle("dask-scheduler [not started]") + Scheduler._instances.add(self) ################## # Administration # diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index bce1066c0d2..5144c0f0868 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4829,9 +4829,7 @@ def f(x): def test_get_client_no_cluster(): # Clean up any global workers added by other tests. This test requires that # there are no global workers. 
- from distributed.worker import _global_workers - - del _global_workers[:] + Worker._instances.clear() msg = "No global client found and no address provided" with pytest.raises(ValueError, match=r"^{}$".format(msg)): @@ -5707,5 +5705,12 @@ def test_direct_to_workers(s, loop): assert "gather" not in str(resp) +@gen_cluster(client=True) +def test_instances(c, s, a, b): + assert list(Client._instances) == [c] + assert list(Scheduler._instances) == [s] + assert set(Worker._instances) == {a, b} + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index da750dd9196..ceb992d6e22 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -821,7 +821,9 @@ def test_file_descriptors(c, s): yield [n.close() for n in nannies] assert not s.rpc.open - assert not any(occ for addr, occ in c.rpc.occupied.items() if occ != s.address) + assert not any( + occ for addr, occ in c.rpc.occupied.items() if occ != s.address + ), list(c.rpc._created) assert not s.stream_comms start = time() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 77df078fa89..d8ca4d31481 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -934,14 +934,9 @@ def f(): @gen_cluster() def test_global_workers(s, a, b): - from distributed.worker import _global_workers - - n = len(_global_workers) - w = _global_workers[-1]() + n = len(Worker._instances) + w = first(Worker._instances) assert w is a or w is b - yield a.close() - yield b.close() - assert len(_global_workers) == n - 2 @pytest.mark.skipif(WINDOWS, reason="file descriptors") diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 0a1cf447cfd..4cef981cd08 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -60,7 +60,8 @@ iscoroutinefunction, thread_state, ) -from .worker import Worker, TOTAL_MEMORY, _global_workers +from .worker import Worker, TOTAL_MEMORY +from .nanny import Nanny try: import dask.array # register config @@ -109,14 +110,13 @@ def invalid_python_script(tmpdir_factory): @gen.coroutine def cleanup_global_workers(): - for w in _global_workers: - w = w() - w.close(report=False, executor_wait=False) + for worker in Worker._instances: + worker.close(report=False, executor_wait=False) @pytest.fixture def loop(): - del _global_workers[:] + Worker._instances.clear() _global_clients.clear() with pristine_loop() as loop: # Monkey-patch IOLoop.start to wait for loop stop @@ -146,7 +146,7 @@ def start(): pass else: is_stopped.wait() - del _global_workers[:] + Worker._instances.clear() start = time() while set(_global_clients): @@ -511,8 +511,6 @@ def wait_until_closed(): def run_nanny(q, scheduler_q, **kwargs): - from distributed import Nanny - with log_errors(): with pristine_loop() as loop: scheduler_addr = scheduler_q.get() @@ -924,7 +922,10 @@ def _(func): func = gen.coroutine(func) def test_func(): - del _global_workers[:] + Client._instances.clear() + Worker._instances.clear() + Scheduler._instances.clear() + Nanny._instances.clear() _global_clients.clear() Comm._instances.clear() active_threads_start = set(threading._active) @@ -1029,12 +1030,11 @@ def coro(): pass del w.data DequeHandler.clear_all_instances() - for w in _global_workers: - w = w() + for w in Worker._instances: w.close(report=False, executor_wait=False) if w.status == "running": w.close() - del _global_workers[:] + Worker._instances.clear() 
if PY3 and not WINDOWS and check_new_threads: start = time() diff --git a/distributed/worker.py b/distributed/worker.py index 5302d94c6de..f5f4f2bdd50 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -19,9 +19,9 @@ from dask.compatibility import apply try: - from cytoolz import pluck, partial, merge + from cytoolz import pluck, partial, merge, first except ImportError: - from toolz import pluck, partial, merge + from toolz import pluck, partial, merge, first from tornado.gen import Return from tornado import gen from tornado.ioloop import IOLoop @@ -87,8 +87,6 @@ DEFAULT_EXTENSIONS = [PubSubWorkerExtension] -_global_workers = [] - class Worker(ServerNode): """ Worker node in a Dask distributed cluster @@ -275,6 +273,8 @@ class Worker(ServerNode): distributed.nanny.Nanny """ + _instances = weakref.WeakSet() + def __init__( self, scheduler_ip=None, @@ -630,7 +630,7 @@ def __init__( ) self.periodic_callbacks["profile-cycle"] = pc - _global_workers.append(weakref.ref(self)) + Worker._instances.add(self) ################## # Administrative # @@ -1022,23 +1022,12 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): self.rpc.close() self._closed.set() - self._remove_from_global_workers() self.status = "closed" yield ServerNode.close(self) setproctitle("dask-worker [closed]") - def __del__(self): - self._remove_from_global_workers() - - def _remove_from_global_workers(self): - for ref in list(_global_workers): - if ref() is self: - _global_workers.remove(ref) - if ref() is None: - _global_workers.remove(ref) - @gen.coroutine def terminate(self, comm, report=True): yield self.close(report=report) @@ -2807,11 +2796,10 @@ def get_worker(): try: return thread_state.execution_state["worker"] except AttributeError: - for ref in _global_workers[::-1]: - worker = ref() - if worker: - return worker - raise ValueError("No workers found") + try: + return first(Worker._instances) + except StopIteration: + raise ValueError("No workers found") def get_client(address=None, timeout=3, resolve_address=True): @@ -3299,3 +3287,6 @@ def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=True): else: response = {"status": "OK", "result": to_serialize(result)} raise Return(response) + + +_global_workers = Worker._instances From ee1008416eb73418c6b656a1ab62af62c2087ffc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 12 May 2019 10:03:36 -0500 Subject: [PATCH 0282/1550] Organize thread/process/instance checking in utils_test.py (#2687) This collects various state checking functionality in one place. It also makes some modifications that were previously causing Dask tests to fail. 
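A rough sketch of the consolidation pattern described above (illustrative only, not part of the patch): the individual leak checks become context managers that a single clean() helper composes, so a test enters one context instead of repeating the bookkeeping inline. The helper names mirror the ones added in the diff below; the bodies here are simplified placeholders rather than the real implementations in distributed/utils_test.py.

    from contextlib import contextmanager

    @contextmanager
    def check_thread_leak():
        # simplified placeholder: snapshot threading._active before the
        # test, yield, then assert no unexpected threads remain
        yield

    @contextmanager
    def check_process_leak():
        # simplified placeholder: snapshot child processes, yield, then
        # assert they all exited
        yield

    @contextmanager
    def check_instances():
        # simplified placeholder: clear the Client/Worker/Scheduler/Nanny
        # _instances WeakSets, yield, then verify everything tracked was
        # closed
        yield

    @contextmanager
    def null():
        # no-op stand-in used when a particular check is disabled
        yield

    @contextmanager
    def clean(threads=True, processes=True, instances=True):
        # single entry point composing the individual checks
        with check_thread_leak() if threads else null():
            with check_process_leak() if processes else null():
                with check_instances() if instances else null():
                    yield

A test helper then wraps its body in a single "with clean(...) as loop:" block, as the updated cluster and gen_cluster helpers in the diff below do, instead of performing each check by hand.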
--- distributed/utils_test.py | 546 +++++++++++++++++++------------------- 1 file changed, 280 insertions(+), 266 deletions(-) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 4cef981cd08..e40912fcb33 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -148,15 +148,7 @@ def start(): is_stopped.wait() Worker._instances.clear() - start = time() - while set(_global_clients): - sleep(0.1) - assert time() < start + 10 - _cleanup_dangling() - - assert_no_leaked_processes() - _global_clients.clear() @@ -630,158 +622,123 @@ def cluster( nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, scheduler_kwargs={} ): ws = weakref.WeakSet() - - reset_config() - Comm._instances.clear() - - for name, level in logging_levels.items(): - logging.getLogger(name).setLevel(level) - enable_proctitle_on_children() - with pristine_loop() as loop: - with check_active_rpc(loop, active_rpc_timeout): - if nanny: - _run_worker = run_nanny - else: - _run_worker = run_worker + with clean(timeout=active_rpc_timeout, threads=False) as loop: + if nanny: + _run_worker = run_nanny + else: + _run_worker = run_worker - # The scheduler queue will receive the scheduler's address - scheduler_q = mp_context.Queue() + # The scheduler queue will receive the scheduler's address + scheduler_q = mp_context.Queue() - # Launch scheduler - scheduler = mp_context.Process( - name="Dask cluster test: Scheduler", - target=run_scheduler, - args=(scheduler_q, nworkers + 1), - kwargs=scheduler_kwargs, + # Launch scheduler + scheduler = mp_context.Process( + name="Dask cluster test: Scheduler", + target=run_scheduler, + args=(scheduler_q, nworkers + 1), + kwargs=scheduler_kwargs, + ) + ws.add(scheduler) + scheduler.daemon = True + scheduler.start() + + # Launch workers + workers = [] + for i in range(nworkers): + q = mp_context.Queue() + fn = "_test_worker-%s" % uuid.uuid4() + kwargs = merge( + {"ncores": 1, "local_dir": fn, "memory_limit": TOTAL_MEMORY}, + worker_kwargs, ) - ws.add(scheduler) - scheduler.daemon = True - scheduler.start() - - # Launch workers - workers = [] - for i in range(nworkers): - q = mp_context.Queue() - fn = "_test_worker-%s" % uuid.uuid4() - kwargs = merge( - {"ncores": 1, "local_dir": fn, "memory_limit": TOTAL_MEMORY}, - worker_kwargs, - ) - proc = mp_context.Process( - name="Dask cluster test: Worker", - target=_run_worker, - args=(q, scheduler_q), - kwargs=kwargs, - ) - ws.add(proc) - workers.append({"proc": proc, "queue": q, "dir": fn}) + proc = mp_context.Process( + name="Dask cluster test: Worker", + target=_run_worker, + args=(q, scheduler_q), + kwargs=kwargs, + ) + ws.add(proc) + workers.append({"proc": proc, "queue": q, "dir": fn}) + for worker in workers: + worker["proc"].start() + try: for worker in workers: - worker["proc"].start() - try: - for worker in workers: - worker["address"] = worker["queue"].get(timeout=5) - except Empty: - raise pytest.xfail.Exception("Worker failed to start in test") + worker["address"] = worker["queue"].get(timeout=5) + except Empty: + raise pytest.xfail.Exception("Worker failed to start in test") - saddr = scheduler_q.get() + saddr = scheduler_q.get() - start = time() + start = time() + try: try: - try: - security = scheduler_kwargs["security"] - rpc_kwargs = { - "connection_args": security.get_connection_args("client") - } - except KeyError: - rpc_kwargs = {} - - with rpc(saddr, **rpc_kwargs) as s: - while True: - ncores = loop.run_sync(s.ncores) - if len(ncores) == nworkers: - break - if time() - start > 5: - raise 
Exception("Timeout on cluster creation") - - # avoid sending processes down to function - yield {"address": saddr}, [ - {"address": w["address"], "proc": weakref.ref(w["proc"])} - for w in workers - ] - finally: - logger.debug("Closing out test cluster") - - loop.run_sync( - lambda: disconnect_all( - [w["address"] for w in workers], - timeout=0.5, - rpc_kwargs=rpc_kwargs, - ) - ) - loop.run_sync( - lambda: disconnect(saddr, timeout=0.5, rpc_kwargs=rpc_kwargs) + security = scheduler_kwargs["security"] + rpc_kwargs = {"connection_args": security.get_connection_args("client")} + except KeyError: + rpc_kwargs = {} + + with rpc(saddr, **rpc_kwargs) as s: + while True: + ncores = loop.run_sync(s.ncores) + if len(ncores) == nworkers: + break + if time() - start > 5: + raise Exception("Timeout on cluster creation") + + # avoid sending processes down to function + yield {"address": saddr}, [ + {"address": w["address"], "proc": weakref.ref(w["proc"])} + for w in workers + ] + finally: + logger.debug("Closing out test cluster") + + loop.run_sync( + lambda: disconnect_all( + [w["address"] for w in workers], timeout=0.5, rpc_kwargs=rpc_kwargs ) + ) + loop.run_sync(lambda: disconnect(saddr, timeout=0.5, rpc_kwargs=rpc_kwargs)) - scheduler.terminate() - scheduler_q.close() - scheduler_q._reader.close() - scheduler_q._writer.close() + scheduler.terminate() + scheduler_q.close() + scheduler_q._reader.close() + scheduler_q._writer.close() - for w in workers: - w["proc"].terminate() - w["queue"].close() - w["queue"]._reader.close() - w["queue"]._writer.close() + for w in workers: + w["proc"].terminate() + w["queue"].close() + w["queue"]._reader.close() + w["queue"]._writer.close() - scheduler.join(2) - del scheduler - for proc in [w["proc"] for w in workers]: - proc.join(timeout=2) + scheduler.join(2) + del scheduler + for proc in [w["proc"] for w in workers]: + proc.join(timeout=2) - with ignoring(UnboundLocalError): - del worker, w, proc - del workers[:] + with ignoring(UnboundLocalError): + del worker, w, proc + del workers[:] - for fn in glob("_test_worker-*"): - with ignoring(OSError): - shutil.rmtree(fn) + for fn in glob("_test_worker-*"): + with ignoring(OSError): + shutil.rmtree(fn) - try: - client = default_client() - except ValueError: - pass - else: - client.close() + try: + client = default_client() + except ValueError: + pass + else: + client.close() start = time() while list(ws): sleep(0.01) assert time() < start + 1, "Workers still around after one second" - for i in range(5): - if all(c.closed() for c in Comm._instances): - break - else: - sleep(0.1) - else: - L = [c for c in Comm._instances if not c.closed()] - Comm._instances.clear() - print("Unclosed Comms", L) - # raise ValueError("Unclosed Comms", L) - - assert_no_leaked_processes() - - -def assert_no_leaked_processes(): - for i in range(20): - if mp_context.active_children(): - sleep(0.1) - else: - assert not mp_context.active_children() - @gen.coroutine def disconnect(addr, timeout=3, rpc_kwargs=None): @@ -922,147 +879,95 @@ def _(func): func = gen.coroutine(func) def test_func(): - Client._instances.clear() - Worker._instances.clear() - Scheduler._instances.clear() - Nanny._instances.clear() - _global_clients.clear() - Comm._instances.clear() - active_threads_start = set(threading._active) - - reset_config() - - dask.config.set({"distributed.comm.timeouts.connect": "5s"}) - # Restore default logging levels - # XXX use pytest hooks/fixtures instead? 
- for name, level in logging_levels.items(): - logging.getLogger(name).setLevel(level) - result = None workers = [] + with clean(threads=check_new_threads, timeout=active_rpc_timeout) as loop: - with pristine_loop() as loop: - with check_active_rpc(loop, active_rpc_timeout): - - @gen.coroutine - def coro(): - with dask.config.set(config): - s = False - for i in range(5): - try: - s, ws = yield start_cluster( - ncores, - scheduler, - loop, - security=security, - Worker=Worker, - scheduler_kwargs=scheduler_kwargs, - worker_kwargs=worker_kwargs, - ) - except Exception as e: - logger.error( - "Failed to start gen_cluster, retrying", - exc_info=True, - ) - else: - workers[:] = ws - args = [s] + workers - break - if s is False: - raise Exception("Could not start cluster") - if client: - c = yield Client( - s.address, - loop=loop, + @gen.coroutine + def coro(): + with dask.config.set(config): + s = False + for i in range(5): + try: + s, ws = yield start_cluster( + ncores, + scheduler, + loop, security=security, - asynchronous=True, - **client_kwargs + Worker=Worker, + scheduler_kwargs=scheduler_kwargs, + worker_kwargs=worker_kwargs, ) - args = [c] + args - try: - future = func(*args) - if timeout: - future = gen.with_timeout( - timedelta(seconds=timeout), future - ) - result = yield future - if s.validate: - s.validate_state() - finally: - if client and c.status not in ("closing", "closed"): - yield c._close(fast=s.status == "closed") - yield end_cluster(s, workers) - yield gen.with_timeout( - timedelta(seconds=1), cleanup_global_workers() + except Exception as e: + logger.error( + "Failed to start gen_cluster, retrying", + exc_info=True, ) - - try: - c = yield default_client() - except ValueError: - pass - else: - yield c._close(fast=True) - - for i in range(5): - if all(c.closed() for c in Comm._instances): - break - else: - yield gen.sleep(0.05) else: - L = [c for c in Comm._instances if not c.closed()] - Comm._instances.clear() - # raise ValueError("Unclosed Comms", L) - print("Unclosed Comms", L) - - raise gen.Return(result) - - result = loop.run_sync( - coro, timeout=timeout * 2 if timeout else timeout - ) + workers[:] = ws + args = [s] + workers + break + if s is False: + raise Exception("Could not start cluster") + if client: + c = yield Client( + s.address, + loop=loop, + security=security, + asynchronous=True, + **client_kwargs + ) + args = [c] + args + try: + future = func(*args) + if timeout: + future = gen.with_timeout( + timedelta(seconds=timeout), future + ) + result = yield future + if s.validate: + s.validate_state() + finally: + if client and c.status not in ("closing", "closed"): + yield c._close(fast=s.status == "closed") + yield end_cluster(s, workers) + yield gen.with_timeout( + timedelta(seconds=1), cleanup_global_workers() + ) - for w in workers: - if getattr(w, "data", None): try: - w.data.clear() - except EnvironmentError: - # zict backends can fail if their storage directory - # was already removed + c = yield default_client() + except ValueError: pass - del w.data - DequeHandler.clear_all_instances() - for w in Worker._instances: - w.close(report=False, executor_wait=False) - if w.status == "running": - w.close() - Worker._instances.clear() - - if PY3 and not WINDOWS and check_new_threads: - start = time() - while True: - bad = [ - t - for t, v in threading._active.items() - if t not in active_threads_start - and "Threaded" not in v.name - and "watch message" not in v.name - and "TCP-Executor" not in v.name - ] - if not bad: - break - else: - sleep(0.01) - if time() > 
start + 5: - from distributed import profile + else: + yield c._close(fast=True) - tid = bad[0] - thread = threading._active[tid] - call_stacks = profile.call_stack(sys._current_frames()[tid]) - assert False, (thread, call_stacks) - _cleanup_dangling() - with ignoring(AttributeError): - del thread_state.on_event_loop_thread + for i in range(5): + if all(c.closed() for c in Comm._instances): + break + else: + yield gen.sleep(0.05) + else: + L = [c for c in Comm._instances if not c.closed()] + Comm._instances.clear() + # raise ValueError("Unclosed Comms", L) + print("Unclosed Comms", L) - assert_no_leaked_processes() + raise gen.Return(result) + + result = loop.run_sync( + coro, timeout=timeout * 2 if timeout else timeout + ) + + for w in workers: + if getattr(w, "data", None): + try: + w.data.clear() + except EnvironmentError: + # zict backends can fail if their storage directory + # was already removed + pass + del w.data return result @@ -1510,3 +1415,112 @@ def gen_tls_cluster(**kwargs): return gen_cluster( scheduler="tls://127.0.0.1", security=tls_only_security(), **kwargs ) + + +@contextmanager +def check_thread_leak(): + active_threads_start = set(threading._active) + + yield + + start = time() + while True: + bad = [ + t + for t, v in threading._active.items() + if t not in active_threads_start + and "Threaded" not in v.name + and "watch message" not in v.name + and "TCP-Executor" not in v.name + ] + if not bad: + break + else: + sleep(0.01) + if time() > start + 5: + from distributed import profile + + tid = bad[0] + thread = threading._active[tid] + call_stacks = profile.call_stack(sys._current_frames()[tid]) + assert False, (thread, call_stacks) + + +@contextmanager +def check_process_leak(): + start_children = set(mp_context.active_children()) + + yield + + for i in range(50): + if not set(mp_context.active_children()) - start_children: + break + else: + sleep(0.2) + else: + assert not mp_context.active_children() + + _cleanup_dangling() + + +@contextmanager +def check_instances(): + Client._instances.clear() + Worker._instances.clear() + Scheduler._instances.clear() + Nanny._instances.clear() + _global_clients.clear() + Comm._instances.clear() + + yield + + start = time() + while set(_global_clients): + sleep(0.1) + assert time() < start + 10 + + _global_clients.clear() + + for w in Worker._instances: + w.close(report=False, executor_wait=False) + if w.status == "running": + w.close() + Worker._instances.clear() + + for i in range(5): + if all(c.closed() for c in Comm._instances): + break + else: + sleep(0.1) + else: + L = [c for c in Comm._instances if not c.closed()] + Comm._instances.clear() + print("Unclosed Comms", L) + # raise ValueError("Unclosed Comms", L) + + DequeHandler.clear_all_instances() + + +@contextmanager +def clean(threads=not WINDOWS, processes=True, instances=True, timeout=1): + @contextmanager + def null(): + yield + + with check_thread_leak() if threads else null(): + with pristine_loop() as loop: + with check_process_leak() if processes else null(): + with check_instances() if instances else null(): + with check_active_rpc(loop, timeout): + reset_config() + + dask.config.set({"distributed.comm.timeouts.connect": "5s"}) + # Restore default logging levels + # XXX use pytest hooks/fixtures instead? 
+ for name, level in logging_levels.items(): + logging.getLogger(name).setLevel(level) + + yield loop + + with ignoring(AttributeError): + del thread_state.on_event_loop_thread From 3142dda225baaf1db0b8c73f240f7fad8941fc1e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 13 May 2019 08:11:20 -0500 Subject: [PATCH 0283/1550] Learn bandwidth over time (#2658) In order to schedule tasks intelligently we need to know how long communications will take. To do this, we need to estimate the bandwidth of the network. This can vary by orders of magnitude depending on hardwware. Previously we asked the user to specify this in configuration. Now we learn it over time. Each worker keeps an exponentially weighted moving average for all of its data communications. It sends this information to the scheduler as part of the heartbeats (which include lots of other diagnostic information). The scheduler updates its own measurement accordingly. --- distributed/scheduler.py | 9 ++++++--- distributed/stealing.py | 3 +-- distributed/tests/test_scheduler.py | 28 +++++++++++++++++++++------- distributed/tests/test_steal.py | 8 ++++---- distributed/utils.py | 2 ++ distributed/worker.py | 11 +++++++++-- 6 files changed, 43 insertions(+), 18 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 68a80ac664b..446b479c769 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -54,6 +54,7 @@ no_default, DequeHandler, parse_timedelta, + parse_bytes, PeriodicCallback, shutting_down, ) @@ -72,7 +73,6 @@ logger = logging.getLogger(__name__) -BANDWIDTH = dask.config.get("distributed.scheduler.bandwidth") ALLOWED_FAILURES = dask.config.get("distributed.scheduler.allowed-failures") LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") @@ -868,6 +868,7 @@ def __init__( else: self.idle_timeout = None self.time_started = time() + self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) self.security = security or Security() assert isinstance(self.security, Security) @@ -1359,6 +1360,8 @@ def heartbeat_worker( host_info = host_info or {} self.host_info[host]["last-seen"] = local_now + frac = 1 / 20 / len(self.workers) + self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"] * frac ws = self.workers.get(address) if not ws: @@ -3336,7 +3339,7 @@ def get_comm_cost(self, ts, ws): Get the estimated communication cost (in s.) to compute the task on the given worker. 
""" - return sum(dts.nbytes for dts in ts.dependencies - ws.has_what) / BANDWIDTH + return sum(dts.nbytes for dts in ts.dependencies - ws.has_what) / self.bandwidth def get_task_duration(self, ts, default=0.5): """ @@ -4522,7 +4525,7 @@ def worker_objective(self, ts, ws): [dts.get_nbytes() for dts in ts.dependencies if ws not in dts.who_has] ) stack_time = ws.occupancy / ws.ncores - start_time = comm_bytes / BANDWIDTH + stack_time + start_time = comm_bytes / self.bandwidth + stack_time if ts.actor: return (len(ws.actors), start_time, ws.nbytes) diff --git a/distributed/stealing.py b/distributed/stealing.py index d361305b105..dc8c989e39d 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -15,7 +15,6 @@ except ImportError: from toolz import topk -BANDWIDTH = 100e6 LATENCY = 10e-3 log_2 = log(2) @@ -134,7 +133,7 @@ def steal_time_ratio(self, ts): nbytes = sum(dep.get_nbytes() for dep in ts.dependencies) - transfer_time = nbytes / BANDWIDTH + LATENCY + transfer_time = nbytes / self.scheduler.bandwidth + LATENCY split = ts.prefix if split in fast_tasks: return None, None diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index ceb992d6e22..9f61e5e710e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -5,7 +5,7 @@ from collections import defaultdict from datetime import timedelta import json -from operator import add, mul +import operator import sys from time import sleep @@ -18,7 +18,7 @@ from distributed import Nanny, Worker, Client, wait, fire_and_forget from distributed.core import connect, rpc -from distributed.scheduler import Scheduler, BANDWIDTH +from distributed.scheduler import Scheduler from distributed.client import wait from distributed.metrics import time from distributed.protocol.pickle import dumps @@ -63,7 +63,7 @@ def test_respect_data_in_memory(c, s, a): assert s.tasks[y.key].who_has == {s.workers[a.address]} - z = delayed(add)(x, y) + z = delayed(operator.add)(x, y) f2 = c.persist(z) while f2.key not in s.tasks or not s.tasks[f2.key]: assert s.tasks[y.key].who_has @@ -427,7 +427,10 @@ def test_filtered_communication(s, a, b): yield f.write( { "op": "update-graph", - "tasks": {"x": dumps_task((inc, 1)), "z": dumps_task((add, "x", 10))}, + "tasks": { + "x": dumps_task((inc, 1)), + "z": dumps_task((operator.add, "x", 10)), + }, "dependencies": {"x": [], "z": ["x"]}, "client": "f", "keys": ["z"], @@ -903,8 +906,8 @@ def test_learn_occupancy_multiple_workers(c, s, a, b): @gen_cluster(client=True) def test_include_communication_in_occupancy(c, s, a, b): s.task_duration["slowadd"] = 0.001 - x = c.submit(mul, b"0", int(BANDWIDTH), workers=a.address) - y = c.submit(mul, b"1", int(BANDWIDTH * 1.5), workers=b.address) + x = c.submit(operator.mul, b"0", int(s.bandwidth), workers=a.address) + y = c.submit(operator.mul, b"1", int(s.bandwidth * 1.5), workers=b.address) z = c.submit(slowadd, x, y, delay=1) while z.key not in s.tasks or not s.tasks[z.key].processing_on: @@ -1375,7 +1378,7 @@ def test_dont_recompute_if_persisted_3(c, s, a, b): x = delayed(inc)(1, dask_key_name="x") y = delayed(inc)(2, dask_key_name="y") z = delayed(inc)(y, dask_key_name="z") - w = delayed(add)(x, z, dask_key_name="w") + w = delayed(operator.add)(x, z, dask_key_name="w") ww = w.persist() yield wait(ww) @@ -1513,6 +1516,17 @@ def test_idle_timeout(c, s, a, b): assert b.status == "closed" +@gen_cluster(client=True, config={"distributed.scheduler.bandwidth": "100 GB"}) +def test_bandwidth(c, s, a, b): + start = 
s.bandwidth + x = c.submit(operator.mul, b"0", 20000, workers=a.address) + y = c.submit(lambda x: x, x, workers=b.address) + yield y + yield b.heartbeat() + assert s.bandwidth < start # we've learned that we're slower + assert b.latency + + @gen_cluster() def test_workerstate_clean(s, a, b): ws = s.workers[a.address].clean() diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 8edeb8e339c..7348d164c72 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -14,7 +14,7 @@ from distributed import Nanny, Worker, wait, worker_client from distributed.config import config from distributed.metrics import time -from distributed.scheduler import BANDWIDTH, key_split +from distributed.scheduler import key_split from distributed.utils_test import ( slowinc, slowadd, @@ -394,7 +394,7 @@ def assert_balanced(inp, expected, c, s, *workers): ts = s.tasks[dat.key] # Ensure scheduler state stays consistent old_nbytes = ts.nbytes - ts.nbytes = BANDWIDTH * t + ts.nbytes = s.bandwidth * t for ws in ts.who_has: ws.nbytes += ts.nbytes - old_nbytes else: @@ -499,8 +499,8 @@ def test_restart(c, s, a, b): def test_steal_communication_heavy_tasks(c, s, a, b): steal = s.extensions["stealing"] s.task_duration["slowadd"] = 0.001 - x = c.submit(mul, b"0", int(BANDWIDTH), workers=a.address) - y = c.submit(mul, b"1", int(BANDWIDTH), workers=b.address) + x = c.submit(mul, b"0", int(s.bandwidth), workers=a.address) + y = c.submit(mul, b"1", int(s.bandwidth), workers=b.address) futures = [ c.submit( diff --git a/distributed/utils.py b/distributed/utils.py index d6cc5ba62cf..765035e5c13 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1221,6 +1221,8 @@ def parse_bytes(s): >>> parse_bytes('MB') 1000000 """ + if isinstance(s, (int, float)): + return int(s) s = s.replace(" ", "") if not s[0].isdigit(): s = "1" + s diff --git a/distributed/worker.py b/distributed/worker.py index f5f4f2bdd50..8b8c139f356 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -400,6 +400,8 @@ def __init__( self.outgoing_count = 0 self.outgoing_current_count = 0 self.repetitively_busy = 0 + self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) + self.latency = 0.001 self._client = None profile_cycle_interval = kwargs.pop( @@ -673,6 +675,7 @@ def get_metrics(self): in_memory=len(self.data), ready=len(self.ready), in_flight=len(self.in_flight_tasks), + bandwidth=self.bandwidth, ) custom = {k: metric(self) for k, metric in self.metrics.items()} @@ -742,6 +745,7 @@ def _register_with_scheduler(self): response = yield future _end = time() middle = (_start + _end) / 2 + self.latency = (_end - start) * 0.05 + self.latency * 0.95 self.scheduler_delay = response["time"] - middle self.status = "running" break @@ -1837,7 +1841,8 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): ) total_bytes = sum(self.nbytes.get(dep, 0) for dep in response["data"]) - duration = (stop - start) or 0.5 + duration = (stop - start) or 0.010 + bandwidth = total_bytes / duration self.incoming_transfer_log.append( { "start": start + self.scheduler_delay, @@ -1848,10 +1853,12 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): dep: self.nbytes.get(dep, None) for dep in response["data"] }, "total": total_bytes, - "bandwidth": total_bytes / duration, + "bandwidth": bandwidth, "who": worker, } ) + if total_bytes > 10000: + self.bandwidth = self.bandwidth * 0.95 + bandwidth * 0.05 if self.digests is not None: 
self.digests["transfer-bandwidth"].add(total_bytes / duration) self.digests["transfer-duration"].add(duration) From e3ffb9d12824b5a5434c5c15a2b6edfa7bf9ea12 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 9 May 2019 08:01:35 -0500 Subject: [PATCH 0284/1550] Use config accessor method for "scheduler-address" (#2676) --- distributed/cli/dask_worker.py | 8 ++++++-- distributed/cli/tests/test_dask_worker.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index a0bc801a960..1e0ebe24176 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -8,8 +8,8 @@ import warnings import click +import dask from distributed import Nanny, Worker -from distributed.config import config from distributed.utils import get_ip_interface, parse_timedelta from distributed.worker import _ncores from distributed.security import Security @@ -322,7 +322,11 @@ def del_pid_file(): kwargs["service_ports"] = {"nanny": nanny_port} t = Worker - if not scheduler and not scheduler_file and "scheduler-address" not in config: + if ( + not scheduler + and not scheduler_file + and dask.config.get("scheduler-address", None) is None + ): raise ValueError( "Need to provide scheduler address like\n" "dask-worker SCHEDULER_ADDRESS:8786" diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 72084e53141..e1cfc8d5ad3 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -6,6 +6,7 @@ import requests import sys +import os from time import sleep from toolz import first @@ -142,6 +143,17 @@ def test_scheduler_file(loop, nanny): assert time() < start + 10 +def test_scheduler_address_env(loop, monkeypatch): + monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", "tcp://127.0.0.1:8786") + with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-worker", "--no-bokeh"]): + with Client(os.environ["DASK_SCHEDULER_ADDRESS"], loop=loop) as c: + start = time() + while not c.scheduler_info()["workers"]: + sleep(0.1) + assert time() < start + 10 + + def test_nprocs_requires_nanny(loop): with popen(["dask-scheduler", "--no-bokeh"]) as sched: with popen( From 61be3a78375b730b78ce4e50aecea2fe432cdddd Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 13 May 2019 16:48:31 -0500 Subject: [PATCH 0285/1550] bump version to 1.28.1 --- docs/source/changelog.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 1b584c02219..613773a0c1c 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,14 @@ Changelog ========= +1.28.1 - 2019-05-13 +------------------- + +This is a small bugfix release due to a config change upstream. + +- Use config accessor method for "scheduler-address" (#2676) `James Bourbeau`_ + + 1.28.0 - 2019-05-08 ------------------- @@ -1030,4 +1038,4 @@ significantly without many new features. .. _`Peter Andreas Entschev`: https://github.com/pentschev .. _`condoratberlin`: https://github.com/condoratberlin .. _`K.-Michael Aye`: https://github.com/michaelaye -.. _`@plbertrand`: https://github.com/plbertrand \ No newline at end of file +.. 
_`@plbertrand`: https://github.com/plbertrand From a622ce4ac48861b07a2738baba361b9886ee1ca3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 13 May 2019 18:13:48 -0500 Subject: [PATCH 0286/1550] Remove support for Iterators and Queues (#2671) These add non-trivial code complexity, and don't seem to be commonly used (based on bug reports and SO questions). They're also a bit odd on our tests (there are some lingering threads as a result. This commit removes functionality for them and replaces them with informative warnings pointing people towards normal for loops. --- distributed/client.py | 156 +++++------------------ distributed/tests/test_client.py | 206 +------------------------------ distributed/tests/test_utils.py | 23 +--- distributed/utils.py | 38 +----- distributed/utils_test.py | 2 +- docs/source/queues.rst | 196 +---------------------------- 6 files changed, 40 insertions(+), 581 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index e31b808abcf..7b6e14aa4ef 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -55,7 +55,6 @@ from .cfexecutor import ClientExecutor from .compatibility import ( Queue as pyQueue, - Empty, isqueue, html_escape, StopAsyncIteration, @@ -77,7 +76,6 @@ sync, funcname, ignoring, - queue_to_iterator, tokey, log_errors, str_graph, @@ -1424,24 +1422,6 @@ def submit(self, func, *args, **kwargs): return futures[skey] - def _threaded_map(self, q_out, func, qs_in, **kwargs): - """ Internal function for mapping Queue """ - if isqueue(qs_in[0]): - get = pyQueue.get - elif isinstance(qs_in[0], Iterator): - get = next - else: - raise NotImplementedError() - - while True: - try: - args = [get(q) for q in qs_in] - except StopIteration as e: - q_out.put(e) - break - f = self.submit(func, *args, **kwargs) - q_out.put(f) - def map(self, func, *iterables, **kwargs): """ Map a function on a sequence of arguments @@ -1450,7 +1430,8 @@ def map(self, func, *iterables, **kwargs): Parameters ---------- func: callable - iterables: Iterables, Iterators, or Queues + iterables: Iterables + List-like objects to map over. They should have the same length. key: str, list Prefix for task names if string. Explicit names if list. pure: bool (defaults to True) @@ -1489,20 +1470,10 @@ def map(self, func, *iterables, **kwargs): if all(map(isqueue, iterables)) or all( isinstance(i, Iterator) for i in iterables ): - maxsize = kwargs.pop("maxsize", 0) - q_out = pyQueue(maxsize=maxsize) - t = threading.Thread( - target=self._threaded_map, - name="Threaded map()", - args=(q_out, func, iterables), - kwargs=kwargs, + raise TypeError( + "Dask no longer supports mapping over Iterators or Queues." 
+ "Consider using a normal for loop and Client.submit" ) - t.daemon = True - t.start() - if isqueue(iterables[0]): - return q_out - else: - return queue_to_iterator(q_out) key = kwargs.pop("key", None) key = key or funcname(func) @@ -1738,22 +1709,7 @@ def _gather_remote(self, direct, local_worker): raise gen.Return(response) - def _threaded_gather(self, qin, qout, **kwargs): - """ Internal function for gathering Queue """ - while True: - L = [qin.get()] - while qin.empty(): - try: - L.append(qin.get_nowait()) - except Empty: - break - results = self.gather(L, **kwargs) - for item in results: - qout.put(item) - - def gather( - self, futures, errors="raise", maxsize=0, direct=None, asynchronous=None - ): + def gather(self, futures, errors="raise", direct=None, asynchronous=None): """ Gather futures from distributed memory Accepts a future, nested container of futures, iterator, or queue. @@ -1763,7 +1719,7 @@ def gather( ---------- futures: Collection of futures This can be a possibly nested collection of Future objects. - Collections can be lists, sets, iterators, queues or dictionaries + Collections can be lists, sets, or dictionaries errors: string Either 'raise' or 'skip' if we should raise if a future has erred or skip its inclusion in the output collection @@ -1771,9 +1727,6 @@ def gather( Whether or not to connect directly to the workers, or to ask the scheduler to serve as intermediary. This can also be set when creating the Client. - maxsize: int - If the input is a queue then this produces an output queue with a - maximum size. Returns ------- @@ -1790,25 +1743,16 @@ def gather( >>> c.gather([x, [x], x]) # support lists and dicts # doctest: +SKIP [3, [3], 3] - >>> seq = c.gather(iter([x, x])) # support iterators # doctest: +SKIP - >>> next(seq) # doctest: +SKIP - 3 - See Also -------- Client.scatter: Send data out to cluster """ if isqueue(futures): - qout = pyQueue(maxsize=maxsize) - t = threading.Thread( - target=self._threaded_gather, - name="Threaded gather()", - args=(futures, qout), - kwargs={"errors": errors, "direct": direct}, + raise TypeError( + "Dask no longer supports gathering over Iterators and Queues. " + "Consider using a normal for loop and Client.submit/gather" ) - t.daemon = True - t.start() - return qout + elif isinstance(futures, Iterator): return (self.gather(f, errors=errors, direct=direct) for f in futures) else: @@ -1935,27 +1879,6 @@ def _scatter( out = list(out.values())[0] raise gen.Return(out) - def _threaded_scatter(self, q_or_i, qout, **kwargs): - """ Internal function for scattering Iterable/Queue data """ - while True: - if isqueue(q_or_i): - L = [q_or_i.get()] - while not q_or_i.empty(): - try: - L.append(q_or_i.get_nowait()) - except Empty: - break - else: - try: - L = [next(q_or_i)] - except StopIteration as e: - qout.put(e) - break - - futures = self.scatter(L, **kwargs) - for future in futures: - qout.put(future) - def scatter( self, data, @@ -1963,7 +1886,6 @@ def scatter( broadcast=False, direct=None, hash=True, - maxsize=0, timeout=no_default, asynchronous=None, ): @@ -1976,7 +1898,7 @@ def scatter( Parameters ---------- - data: list, iterator, dict, Queue, or object + data: list, dict, or object Data to scatter out to workers. Output type matches input type. workers: list of tuples (optional) Optionally constrain locations of data. @@ -1988,8 +1910,6 @@ def scatter( Whether or not to connect directly to the workers, or to ask the scheduler to serve as intermediary. This can also be set when creating the Client. 
- maxsize: int (optional) - Maximum size of queue if using queues, 0 implies infinite hash: bool (optional) Whether or not to hash data to determine key. If False then this uses a random key @@ -2018,12 +1938,6 @@ def scatter( >>> c.scatter([1, 2, 3], workers=[('hostname', 8788)]) # doctest: +SKIP - Handle streaming sequences of data with iterators or queues - - >>> seq = c.scatter(iter([1, 2, 3])) # doctest: +SKIP - >>> next(seq) # doctest: +SKIP - , - Broadcast data to all workers >>> [future] = c.scatter([element], broadcast=True) # doctest: +SKIP @@ -2041,38 +1955,26 @@ def scatter( if timeout == no_default: timeout = self._timeout if isqueue(data) or isinstance(data, Iterator): - logger.debug("Starting thread for streaming data") - qout = pyQueue(maxsize=maxsize) - - t = threading.Thread( - target=self._threaded_scatter, - name="Threaded scatter()", - args=(data, qout), - kwargs={"workers": workers, "broadcast": broadcast}, + raise TypeError( + "Dask no longer supports mapping over Iterators or Queues." + "Consider using a normal for loop and Client.submit" ) - t.daemon = True - t.start() - if isqueue(data): - return qout - else: - return queue_to_iterator(qout) + if hasattr(thread_state, "execution_state"): # within worker task + local_worker = thread_state.execution_state["worker"] else: - if hasattr(thread_state, "execution_state"): # within worker task - local_worker = thread_state.execution_state["worker"] - else: - local_worker = None - return self.sync( - self._scatter, - data, - workers=workers, - broadcast=broadcast, - direct=direct, - local_worker=local_worker, - timeout=timeout, - asynchronous=asynchronous, - hash=hash, - ) + local_worker = None + return self.sync( + self._scatter, + data, + workers=workers, + broadcast=broadcast, + direct=direct, + local_worker=local_worker, + timeout=timeout, + asynchronous=asynchronous, + hash=hash, + ) @gen.coroutine def _cancel(self, futures, force=False): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5144c0f0868..449d207a91a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5,7 +5,6 @@ from collections import deque from concurrent.futures import CancelledError import gc -import itertools import logging import os import pickle @@ -52,7 +51,7 @@ futures_of, temp_default_client, ) -from distributed.compatibility import PY3, Iterator +from distributed.compatibility import PY3 from distributed.metrics import time from distributed.scheduler import Scheduler, KilledWorker @@ -1321,104 +1320,6 @@ def test_directed_scatter_sync(c, s, a, b, loop): assert len(has_what[a["address"]]) == 0 -def test_iterator_scatter(c): - aa = c.scatter([1, 2, 3]) - assert [1, 2, 3] == c.gather(aa) - - g = (i for i in range(10)) - futures = c.scatter(g) - assert isinstance(futures, Iterator) - - a = next(futures) - assert c.gather(a) == 0 - - futures = list(futures) - assert len(futures) == 9 - assert c.gather(futures) == [1, 2, 3, 4, 5, 6, 7, 8, 9] - - -def test_queue_scatter(c): - from distributed.compatibility import Queue - - q = Queue() - for d in range(10): - q.put(d) - - futures = c.scatter(q) - assert isinstance(futures, Queue) - a = futures.get() - assert c.gather(a) == 0 - - -def test_queue_scatter_gather_maxsize(c): - from distributed.compatibility import Queue - - q = Queue(maxsize=3) - out = c.scatter(q, maxsize=10) - assert out.maxsize == 10 - local = c.gather(q) - assert not local.maxsize - - q = Queue() - out = c.scatter(q) - assert not out.maxsize - local = 
c.gather(out, maxsize=10) - assert local.maxsize == 10 - - q = Queue(maxsize=3) - out = c.scatter(q) - assert not out.maxsize - - -def test_queue_gather(c): - from distributed.compatibility import Queue - - q = Queue() - - qin = list(range(10)) - for d in qin: - q.put(d) - - futures = c.scatter(q) - assert isinstance(futures, Queue) - - ff = c.gather(futures) - assert isinstance(ff, Queue) - - qout = [] - for f in range(10): - qout.append(ff.get()) - assert qout == qin - - -@pytest.mark.skip(reason="intermittent blocking failures") -def test_iterator_gather(c, c2): - i_in = list(range(10)) - - g = (d for d in i_in) - futures = c.scatter(g) - assert isinstance(futures, Iterator) - - ff = c.gather(futures) - assert isinstance(ff, Iterator) - - i_out = list(ff) - assert i_out == i_in - - i_in = ["a", "b", "c", StopIteration("f"), StopIteration, "d", "c"] - - g = (d for d in i_in) - futures = c.scatter(g) - - ff = c.gather(futures) - i_out = list(ff) - assert i_out[:3] == i_in[:3] - # This is because StopIteration('f') != StopIteration('f') - assert isinstance(i_out[3], StopIteration) - assert i_out[3].args == i_in[3].args - assert i_out[4:] == i_in[4:] - - @gen_cluster(client=True) def test_scatter_direct(c, s, a, b): future = yield c.scatter(123, direct=True) @@ -2373,109 +2274,6 @@ def test_traceback_clean(c, s, a, b): tb = tb.tb_next -@gen_cluster(client=True) -def test_map_queue(c, s, a, b): - from distributed.compatibility import Queue, isqueue - - q_1 = Queue(maxsize=2) - q_2 = c.map(inc, q_1) - assert isqueue(q_2) - assert not q_2.maxsize - q_3 = c.map(double, q_2, maxsize=3) - assert isqueue(q_3) - assert q_3.maxsize == 3 - q_4 = yield c._gather(q_3) - assert isqueue(q_4) - - q_1.put(1) - - f = q_4.get() - assert isinstance(f, Future) - result = yield f - assert result == (1 + 1) * 2 - - -@pytest.mark.skipif( - sys.version_info >= (3, 7), reason="replace StopIteration with return" -) -@gen_cluster(client=True) -def test_map_iterator_with_return(c, s, a, b): - def g(): - yield 1 - yield 2 - raise StopIteration(3) # py2.7 compat. 
- - f1 = c.map(lambda x: x, g()) - assert isinstance(f1, Iterator) - - start = time() # ensure that we compute eagerly - while not s.tasks: - yield gen.sleep(0.01) - assert time() < start + 5 - - g1 = g() - try: - while True: - f = next(f1) - n = yield f - assert n == next(g1) - except StopIteration as e: - with pytest.raises(StopIteration) as exc_info: - next(g1) - assert e.args == exc_info.value.args - - -@gen_cluster(client=True) -def test_map_iterator(c, s, a, b): - x = iter([1, 2, 3]) - y = iter([10, 20, 30]) - f1 = c.map(add, x, y) - assert isinstance(f1, Iterator) - - start = time() # ensure that we compute eagerly - while not s.tasks: - yield gen.sleep(0.01) - assert time() < start + 5 - - f2 = c.map(double, f1) - assert isinstance(f2, Iterator) - - future = next(f2) - result = yield future - assert result == (1 + 10) * 2 - futures = list(f2) - results = [] - for f in futures: - r = yield f - results.append(r) - assert results == [(2 + 20) * 2, (3 + 30) * 2] - - items = enumerate(range(10)) - futures = c.map(lambda x: x, items) - assert isinstance(futures, Iterator) - - result = yield next(futures) - assert result == (0, 0) - futures_l = list(futures) - results = [] - for f in futures_l: - r = yield f - results.append(r) - assert results == [(i, i) for i in range(1, 10)] - - -@gen_cluster(client=True) -def test_map_infinite_iterators(c, s, a, b): - futures = c.map(add, [1, 2], itertools.repeat(10)) - assert len(futures) == 2 - - -def test_map_iterator_sync(c): - items = enumerate(range(10)) - futures = c.map(lambda x: x, items) - next(futures).result() == (0, 0) - - @gen_cluster(client=True) def test_map_differnet_lengths(c, s, a, b): assert len(c.map(add, [1, 2], [1, 2, 3])) == 2 @@ -3557,7 +3355,7 @@ def test_get_stops_work_after_error(c): def test_as_completed_list(c): - seq = c.map(inc, iter(range(5))) + seq = c.map(inc, range(5)) seq2 = list(as_completed(seq)) assert set(c.gather(seq2)) == {1, 2, 3, 4, 5} diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index f4423d26e4a..b82dce4e7d9 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -14,7 +14,7 @@ from tornado.ioloop import IOLoop import dask -from distributed.compatibility import Queue, Empty, isqueue, PY2, Iterator +from distributed.compatibility import Queue, Empty, PY2 from distributed.metrics import time from distributed.utils import ( All, @@ -24,8 +24,6 @@ str_graph, truncate_exception, get_traceback, - queue_to_iterator, - iterator_to_queue, _maybe_complex, read_block, seek_delimiter, @@ -183,25 +181,6 @@ def c(x): assert type(tb).__name__ == "traceback" -def test_queue_to_iterator(): - q = Queue() - q.put(1) - q.put(2) - - seq = queue_to_iterator(q) - assert isinstance(seq, Iterator) - assert next(seq) == 1 - assert next(seq) == 2 - - -def test_iterator_to_queue(): - seq = iter([1, 2, 3]) - - q = iterator_to_queue(seq) - assert isqueue(q) - assert q.get() == 1 - - def test_str_graph(): dsk = {"x": 1} assert str_graph(dsk) == dsk diff --git a/distributed/utils.py b/distributed/utils.py index 765035e5c13..6debdedd24e 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -46,7 +46,7 @@ except ImportError: PollIOLoop = None # dropped in tornado 6.0 -from .compatibility import Queue, PY3, PY2, get_thread_identity, unicode +from .compatibility import PY3, PY2, get_thread_identity, unicode from .metrics import time @@ -797,42 +797,6 @@ def truncate_exception(e, n=10000): return e -if sys.version_info >= (3,): - # (re-)raising StopIteration is 
deprecated in 3.6+ - exec( - """def queue_to_iterator(q): - while True: - result = q.get() - if isinstance(result, StopIteration): - return result.value - yield result - """ - ) -else: - # Returning non-None from generator is a syntax error in 2.x - def queue_to_iterator(q): - while True: - result = q.get() - if isinstance(result, StopIteration): - raise result - yield result - - -def _dump_to_queue(seq, q): - for item in seq: - q.put(item) - - -def iterator_to_queue(seq, maxsize=0): - q = Queue(maxsize=maxsize) - - t = threading.Thread(target=_dump_to_queue, args=(seq, q)) - t.daemon = True - t.start() - - return q - - def tokey(o): """ Convert an object to a string. diff --git a/distributed/utils_test.py b/distributed/utils_test.py index e40912fcb33..7b90745ac6d 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -399,7 +399,7 @@ def map_varying(itemslists): def apply(func, *args, **kwargs): return func(*args, **kwargs) - return apply, map(varying, itemslists) + return apply, list(map(varying, itemslists)) @gen.coroutine diff --git a/docs/source/queues.rst b/docs/source/queues.rst index ea1bc76a4f7..34a4cae538a 100644 --- a/docs/source/queues.rst +++ b/docs/source/queues.rst @@ -1,194 +1,10 @@ Data Streams with Queues ======================== -The ``Client`` methods ``scatter``, ``map``, and ``gather`` can consume and -produce standard Python ``Queue`` objects. This is useful for processing -continuous streams of data. However, it does not constitute a full streaming -data processing pipeline like Storm. +This feature is no longer supported. +Instead people may want to look at the following options: -.. raw:: html - - - -Example -------- - -We connect to a local Client. - -.. code-block:: python - - >>> from distributed import Client - >>> client = Client('127.0.0.1:8786') - >>> client - - -We build a couple of toy data processing functions: - -.. code-block:: python - - from time import sleep - from random import random - - def inc(x): - from random import random - sleep(random() * 2) - return x + 1 - - def double(x): - from random import random - sleep(random()) - return 2 * x - -And we set up an input Queue and map our functions across it. - -.. code-block:: python - - >>> from queue import Queue - >>> input_q = Queue() - >>> remote_q = client.scatter(input_q) - >>> inc_q = client.map(inc, remote_q) - >>> double_q = client.map(double, inc_q) - -We will fill the ``input_q`` with local data from some stream, and then -``remote_q``, ``inc_q`` and ``double_q`` will fill with ``Future`` objects as -data gets moved around. - -We gather the futures from the ``double_q`` back to a queue holding local -data in the local process. - -.. code-block:: python - - >>> result_q = client.gather(double_q) - -Insert Data Manually -~~~~~~~~~~~~~~~~~~~~ - -Because we haven't placed any data into any of the queues everything is empty, -including the final output, ``result_q``. - -.. code-block:: python - - >>> result_q.qsize() - 0 - -But when we insert an entry into the ``input_q``, it starts to make its way -through the pipeline and ends up in the ``result_q``. - -.. code-block:: python - - >>> input_q.put(10) - >>> result_q.get() - 22 - -Insert data in a separate thread -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -We simulate a slightly more realistic situation by dumping data into the -``input_q`` in a separate thread. This simulates what you might get if you -were to read from an active data source. - -.. 
code-block:: python - - def load_data(q): - i = 0 - while True: - q.put(i) - sleep(random()) - i += 1 - - >>> from threading import Thread - >>> load_thread = Thread(target=load_data, args=(input_q,)) - >>> load_thread.start() - - >>> result_q.qsize() - 4 - >>> result_q.qsize() - 9 - -We consume data from the ``result_q`` and print results to the screen. - -.. code-block:: python - - >>> while True: - ... item = result_q.get() - ... print(item) - 2 - 4 - 6 - 8 - 10 - 12 - ... - -Limitations ------------ - -* This doesn't do any sort of auto-batching of computations, so ideally you - batch your data to take significantly longer than 1ms to run. -* This isn't a proper streaming system. There is no support outside of what - you see here. In particular there are no policies for dropping data, joining - over time windows, etc.. - -Extensions ----------- - -We can extend this small example to more complex systems that have buffers, -split queues, merge queues, etc. all by manipulating normal Python Queues. - -Here are a couple of useful function to multiplex and merge queues: - -.. code-block:: python - - from queue import Queue - from threading import Thread - - def multiplex(n, q, **kwargs): - """ Convert one queue into several equivalent Queues - - >>> q1, q2, q3 = multiplex(3, in_q) - """ - out_queues = [Queue(**kwargs) for i in range(n)] - def f(): - while True: - x = q.get() - for out_q in out_queues: - out_q.put(x) - t = Thread(target=f) - t.daemon = True - t.start() - return out_queues - - def push(in_q, out_q): - while True: - x = in_q.get() - out_q.put(x) - - def merge(*in_qs, **kwargs): - """ Merge multiple queues together - - >>> out_q = merge(q1, q2, q3) - """ - out_q = Queue(**kwargs) - threads = [Thread(target=push, args=(q, out_q)) for q in in_qs] - for t in threads: - t.daemon = True - t.start() - return out_q - -With useful functions like these we can build out more sophisticated data -processing pipelines that split off and join back together. By creating queues -with ``maxsize=`` we can control buffering and apply back pressure. - -.. raw:: html - - +1. Use normal for loops with Client.submit/gather and as_completed +2. Use :doc:`asynchronous async/await ` code and a few coroutines +3. 
Try out the `Streamz `_ project, + which has Dask support From bb80c5e997d020c5b472819fd0751c8234d2b8cc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 13 May 2019 18:19:59 -0500 Subject: [PATCH 0287/1550] Use 'temporary-directory' from dask.config for Worker's directory (#2654) --- distributed/tests/test_worker.py | 9 +++++++++ distributed/worker.py | 8 +++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d8ca4d31481..4541e183e46 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1383,6 +1383,15 @@ def __init__(self, x, y): yield w.close() +@gen_cluster(ncores=[]) +def test_local_dir(s): + with tmpfile() as fn: + with dask.config.set(temporary_directory=fn): + w = yield Worker(s.address) + assert w.local_dir.startswith(fn) + assert "dask-worker-space" in w.local_dir + + @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) diff --git a/distributed/worker.py b/distributed/worker.py index 8b8c139f356..3b5e1cb233a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -282,7 +282,7 @@ def __init__( scheduler_file=None, ncores=None, loop=None, - local_dir="dask-worker-space", + local_dir=None, services=None, service_ports=None, service_kwargs=None, @@ -448,6 +448,12 @@ def __init__( if silence_logs: silence_logging(level=silence_logs) + if local_dir is None: + local_dir = dask.config.get("temporary-directory") or os.getcwd() + if not os.path.exists(local_dir): + os.mkdir(local_dir) + local_dir = os.path.join(local_dir, "dask-worker-space") + with warn_on_duration( "1s", "Creating scratch directories is taking a surprisingly long time. " From fd31ecca8017bae845a73d468de0376c02363fab Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 14 May 2019 12:49:49 -0500 Subject: [PATCH 0288/1550] Cleanup localcluster (#2693) * Remove address handling (handled in scheduler) * Move ip= keyword to host= --- distributed/deploy/local.py | 57 +++++++++++--------------- distributed/deploy/tests/test_local.py | 14 ++++--- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index fb8793d0840..832e8f3e051 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -15,7 +15,6 @@ from ..compatibility import get_thread_identity from ..core import CommClosedError from ..utils import ( - get_ip_interface, sync, ignoring, All, @@ -51,8 +50,10 @@ class LocalCluster(Cluster): silence_logs: logging level Level of logs to print out to stdout. ``logging.WARN`` by default. Use a falsey value like False or None for no change. + host: string + Host address on which the scheduler will listen, defaults to only localhost ip: string - IP address on which the scheduler will listen, defaults to only localhost + Deprecated. See ``host`` above. dashboard_address: str Address on which to listen for the Bokeh diagnostics server like 'localhost:8787' or '0.0.0.0:8787'. Defaults to ':8787'. 
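As a concrete sketch of option 1 in the queues.rst note above, a plain loop with ``Client.submit`` feeding ``as_completed``; the scheduler address and the per-record function are placeholders:

.. code-block:: python

    from dask.distributed import Client, as_completed

    def process(record):
        return record * 2                      # stand-in for real per-record work

    client = Client("tcp://127.0.0.1:8786")    # placeholder scheduler address

    futures = [client.submit(process, record) for record in range(100)]
    for future in as_completed(futures):       # yields futures as they finish
        print(future.result())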
@@ -88,13 +89,9 @@ class LocalCluster(Cluster): >>> c = Client(cluster) # connect to local cluster # doctest: +SKIP - Add a new worker to the cluster + Scale the cluster to three workers - >>> w = cluster.start_worker(ncores=2) # doctest: +SKIP - - Shut down the extra worker - - >>> cluster.stop_worker(w) # doctest: +SKIP + >>> cluster.scale(3) # doctest: +SKIP Pass extra keyword arguments to Bokeh @@ -109,6 +106,7 @@ def __init__( loop=None, start=None, ip=None, + host=None, scheduler_port=0, silence_logs=logging.WARN, dashboard_address=":8787", @@ -125,6 +123,10 @@ def __init__( worker_class=None, **worker_kwargs ): + if ip is not None: + warnings.warn("The ip keyword has been moved to host") + host = ip + if start is not None: msg = ( "The start= parameter is deprecated. " @@ -145,8 +147,8 @@ def __init__( self.processes = processes if protocol is None: - if ip and "://" in ip: - protocol = ip.split("://")[0] + if host and "://" in host: + protocol = host.split("://")[0] elif security: protocol = "tls://" elif not self.processes and not scheduler_port: @@ -155,12 +157,12 @@ def __init__( protocol = "tcp://" if not protocol.endswith("://"): protocol = protocol + "://" - self.protocol = protocol + + if host is None and not protocol.startswith("inproc") and not interface: + host = "127.0.0.1" self.silence_logs = silence_logs self._asynchronous = asynchronous - self.security = security - self.interface = interface services = services or {} worker_services = worker_services or {} if silence_logs: @@ -184,6 +186,8 @@ def __init__( "ncores": threads_per_worker, "services": worker_services, "dashboard_address": worker_dashboard_address, + "interface": interface, + "protocol": protocol, } ) @@ -192,14 +196,16 @@ def __init__( self.scheduler = Scheduler( loop=self.loop, + host=host, services=services, service_kwargs=service_kwargs, security=security, + port=scheduler_port, interface=interface, + protocol=protocol, dashboard_address=dashboard_address, blocked_handlers=blocked_handlers, ) - self.scheduler_port = scheduler_port self.workers = [] self.worker_kwargs = worker_kwargs @@ -210,7 +216,7 @@ def __init__( worker_class = Worker if not processes else Nanny self.worker_class = worker_class - self.start(ip=ip, n_workers=n_workers) + self.start(n_workers=n_workers) clusters_to_close.add(self) @@ -251,32 +257,17 @@ def start(self, **kwargs): self.sync(self._start, **kwargs) @gen.coroutine - def _start(self, ip=None, n_workers=0): + def _start(self, n_workers=0): """ Start all cluster services. """ if self.status == "running": return - if self.protocol == "inproc://": - address = self.protocol - else: - if ip is None: - if self.interface: - ip = get_ip_interface(self.interface) - else: - ip = "127.0.0.1" - - if "://" in ip: - address = ip - else: - address = self.protocol + ip - if self.scheduler_port: - address += ":" + str(self.scheduler_port) - - self.scheduler.start(address) + self.scheduler.start() yield [self._start_worker(**self.worker_kwargs) for i in range(n_workers)] + yield self.scheduler self.status = "running" diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 6e1e71e83b2..ed9e3bb2dbe 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -128,7 +128,7 @@ def test_move_unserializable_data(): assert y.result() is lock -def test_transports(): +def test_transports_inproc(): """ Test the transport chosen by LocalCluster depending on arguments. 
""" @@ -140,6 +140,8 @@ def test_transports(): with Client(c.scheduler.address) as e: assert e.submit(inc, 4).result() == 5 + +def test_transports_tcp(): # Have nannies => need TCP with LocalCluster( 1, processes=True, silence_logs=False, dashboard_address=None @@ -149,6 +151,8 @@ def test_transports(): with Client(c.scheduler.address) as e: assert e.submit(inc, 4).result() == 5 + +def test_transports_tcp_port(): # Scheduler port specified => need TCP with LocalCluster( 1, @@ -417,7 +421,7 @@ def test_remote_access(loop): scheduler_port=0, silence_logs=False, dashboard_address=None, - ip="", + host="", loop=loop, ) as c: sync(loop, assert_can_connect_from_everywhere_4_6, c.scheduler.port) @@ -620,7 +624,7 @@ def test_local_tls(loop): silence_logs=False, security=security, dashboard_address=False, - ip="tls://0.0.0.0", + host="tls://0.0.0.0", loop=loop, ) as c: sync( @@ -690,7 +694,7 @@ def test_local_tls_restart(loop): silence_logs=False, security=security, dashboard_address=False, - ip="tls://0.0.0.0", + host="tls://0.0.0.0", loop=loop, ) as c: with Client(c.scheduler.address, loop=loop, security=security) as client: @@ -750,7 +754,7 @@ def test_protocol_tcp(loop): ) def test_protocol_ip(loop): with LocalCluster( - ip="tcp://127.0.0.2", loop=loop, n_workers=0, processes=False + host="tcp://127.0.0.2", loop=loop, n_workers=0, processes=False ) as cluster: assert cluster.scheduler.address.startswith("tcp://127.0.0.2") From 8e449d392e91eff0a3454ee98ef362de8f78cc4f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 15 May 2019 09:35:43 -0500 Subject: [PATCH 0289/1550] Support computation on delayed(None) (#2697) Previously this conflicted with our sentinel value Fixes #2696 --- distributed/tests/test_collections.py | 10 ++++++++++ distributed/worker.py | 6 +++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 985b6f78fe9..dea4296769d 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -6,6 +6,7 @@ pytest.importorskip("numpy") pytest.importorskip("pandas") +import dask import dask.dataframe as dd import dask.bag as db from distributed.client import wait @@ -185,3 +186,12 @@ def test_sparse_arrays(c, s, a, b): future = c.compute(s.sum(axis=0)[:10]) yield future + + +@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +def test_delayed_none(c, s, w): + x = dask.delayed(None) + y = dask.delayed(123) + [xx, yy] = c.compute([x, y]) + assert (yield xx) is None + assert (yield yy) == 123 diff --git a/distributed/worker.py b/distributed/worker.py index 3b5e1cb233a..f4189393091 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1216,7 +1216,7 @@ def add_task( function=None, args=None, kwargs=None, - task=None, + task=no_value, who_has=None, nbytes=None, priority=None, @@ -3007,7 +3007,7 @@ def get_data_from_worker( job_counter = [0] -def _deserialize(function=None, args=None, kwargs=None, task=None): +def _deserialize(function=None, args=None, kwargs=None, task=no_value): """ Deserialize task inputs and regularize to func, args, kwargs """ if function is not None: function = pickle.loads(function) @@ -3016,7 +3016,7 @@ def _deserialize(function=None, args=None, kwargs=None, task=None): if kwargs: kwargs = pickle.loads(kwargs) - if task is not None: + if task is not no_value: assert not function and not args and not kwargs function = execute_task args = (task,) From 73362eaf4bb1f941284658ae87f51dbaae753a4c Mon Sep 17 00:00:00 2001 From: 
Daniel Farrell Date: Wed, 15 May 2019 10:04:30 -0700 Subject: [PATCH 0290/1550] Add method to wait for n workers before continuing (#2688) --- distributed/client.py | 11 +++++++++++ distributed/tests/test_client.py | 13 +++++++++++++ 2 files changed, 24 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index 7b6e14aa4ef..8dd85d7795e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1051,6 +1051,17 @@ def _update_scheduler_info(self): except EnvironmentError: logger.debug("Not able to query scheduler for identity") + @gen.coroutine + def _wait_for_workers(self, n_workers=0): + info = yield self.scheduler.identity() + while n_workers and len(info["workers"]) < n_workers: + yield gen.sleep(0.1) + info = yield self.scheduler.identity() + + def wait_for_workers(self, n_workers=0): + """Blocking call to wait for n workers before continuing""" + return self.sync(self._wait_for_workers, n_workers) + def _heartbeat(self): if self.scheduler_comm: self.scheduler_comm.send({"op": "heartbeat-client"}) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 449d207a91a..dfe677ddd1e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5510,5 +5510,18 @@ def test_instances(c, s, a, b): assert set(Worker._instances) == {a, b} +@gen_cluster(client=True) +def test_wait_for_workers(c, s, a, b): + future = c.wait_for_workers(n_workers=3) + yield gen.sleep(0.22) # 2 chances + assert not future.done() + + w = yield Worker(s.address) + start = time() + yield future + assert time() < start + 1 + yield w.close() + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 From 0ce8f2bcb84d306e9a095d75497857dce30145b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Dugr=C3=A9?= Date: Wed, 15 May 2019 15:56:44 -0400 Subject: [PATCH 0291/1550] Modify styling of histograms for many-worker dashboard plots (#2695) Fixes #2691 --- distributed/bokeh/scheduler.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index 2dd60f0690f..c078b612dd9 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -219,9 +219,10 @@ def __init__(self, scheduler, **kwargs): ) self.root = figure( - title="Tasks Processing", + title="Tasks Processing (Histogram)", id="bk-nprocessing-histogram-plot", name="processing_hist", + y_axis_label="frequency", **kwargs ) @@ -237,7 +238,8 @@ def __init__(self, scheduler, **kwargs): right="right", bottom=0, top="top", - color="blue", + color="deepskyblue", + fill_alpha=0.5, ) @without_property_validation @@ -259,11 +261,13 @@ def __init__(self, scheduler, **kwargs): ) self.root = figure( - title="Bytes Stored", + title="Bytes Stored (Histogram)", name="nbytes_hist", id="bk-nbytes-histogram-plot", + y_axis_label="frequency", **kwargs ) + self.root.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") self.root.xaxis.major_label_orientation = -math.pi / 12 @@ -279,7 +283,8 @@ def __init__(self, scheduler, **kwargs): right="right", bottom=0, top="top", - color="blue", + color="deepskyblue", + fill_alpha=0.5, ) @without_property_validation @@ -289,7 +294,7 @@ def update(self): d = {"left": x[:-1], "right": x[1:], "top": counts} self.source.data.update(d) - self.root.title.text = "Bytes stored: " + format_bytes(nbytes.sum()) + self.root.title.text = "Bytes stored (Histogram): " + format_bytes(nbytes.sum()) class 
CurrentLoad(DashboardComponent): From d4f478672137481ad43abc6a44383a595e8485b5 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 16 May 2019 08:23:42 -0500 Subject: [PATCH 0292/1550] Handle heartbeat when worker has just left (#2702) --- distributed/scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 446b479c769..1d6a41a9acc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1353,6 +1353,9 @@ def heartbeat_worker( address = self.coerce_address(address, resolve_address) address = normalize_address(address) host = get_address_host(address) + if address not in self.workers: + logger.info("Received heartbeat from removed worker: %s", address) + return local_now = time() now = now or time() From 4feb90d7ddd4860648c161c1497bb55da0fc1b2a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 16 May 2019 08:23:54 -0500 Subject: [PATCH 0293/1550] Except errors in Nanny's memory monitor if process no longer exists (#2701) --- distributed/nanny.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 7be630cb4e9..842ec765d7f 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -326,9 +326,9 @@ def memory_monitor(self): return try: proc = psutil.Process(process.pid) - except psutil.NoSuchProcess: + memory = proc.memory_info().rss + except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): return - memory = proc.memory_info().rss frac = memory / self.memory_limit if self.memory_terminate_fraction and frac > self.memory_terminate_fraction: logger.warning( From d9a0897cd3abc6f0c921f50b5e12dbb10fc5aac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Dugr=C3=A9?= Date: Thu, 16 May 2019 16:25:15 -0400 Subject: [PATCH 0294/1550] Disable pan tool for the Progress, Byte Stored and Tasks Processing plot (#2703) --- distributed/bokeh/scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index c078b612dd9..471f93dd4b5 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -223,6 +223,7 @@ def __init__(self, scheduler, **kwargs): id="bk-nprocessing-histogram-plot", name="processing_hist", y_axis_label="frequency", + tools="", **kwargs ) @@ -265,6 +266,7 @@ def __init__(self, scheduler, **kwargs): name="nbytes_hist", id="bk-nbytes-histogram-plot", y_axis_label="frequency", + tools="", **kwargs ) @@ -943,6 +945,7 @@ def __init__(self, scheduler, **kwargs): x_range=x_range, y_range=y_range, toolbar_location=None, + tools="", **kwargs ) self.root.line( # just to define early ranges From a42721656418ff4848c44e230b8f033d2db58a63 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 16 May 2019 16:36:07 -0500 Subject: [PATCH 0295/1550] Cleanly stop periodic callbacks in Client (#2705) Previously we did this only in the asynchronous code, which left a gap during which a heartbeat could sneak out. Now we call it explicitly at the beginning of the synchronous close command. 
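The ordering described in the commit message above is the whole fix: a Tornado ``PeriodicCallback`` keeps firing until it is stopped, so the stop has to happen in the synchronous ``close`` entry point before any asynchronous teardown is scheduled. A minimal sketch of the pattern with a hypothetical class, not the real ``Client``:

.. code-block:: python

    from tornado.ioloop import IOLoop, PeriodicCallback

    class HeartbeatingClient:
        """Hypothetical stand-in used only to illustrate the ordering."""

        def __init__(self):
            self.loop = IOLoop.current()
            # fires every 500 ms until explicitly stopped
            self.heartbeat_pc = PeriodicCallback(self.send_heartbeat, 500)
            self.heartbeat_pc.start()

        def send_heartbeat(self):
            print("heartbeat")        # stand-in for a message to the scheduler

        def close(self):
            # Stop periodic callbacks first, synchronously, so no heartbeat
            # can sneak out while the rest of the teardown is in flight.
            self.heartbeat_pc.stop()
            # ... asynchronous teardown would be scheduled after this point ...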
--- distributed/client.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 8dd85d7795e..4aab8102d0c 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1204,10 +1204,11 @@ def _close(self, fast=False): """ Send close signal and wait until scheduler completes """ self.status = "closing" + for pc in self._periodic_callbacks.values(): + pc.stop() + with log_errors(): _del_global_client(self) - for pc in self._periodic_callbacks.values(): - pc.stop() self._scheduler_identity = {} with ignoring(AttributeError): # clear the dask.config set keys @@ -1289,6 +1290,9 @@ def close(self, timeout=no_default): return self.status = "closing" + for pc in self._periodic_callbacks.values(): + pc.stop() + if self.asynchronous: future = self._close() if timeout: From 7ebe65980e7fef90fd25cc0d35e2fcfc0c266881 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Thu, 16 May 2019 17:36:36 -0400 Subject: [PATCH 0296/1550] Change the main workers bokeh page to /status (#2689) This matches the behavior on the scheduler --- distributed/bokeh/scheduler.py | 4 ++-- distributed/bokeh/worker.py | 16 ++++++++-------- distributed/bokeh/worker_html.py | 11 ++++++++++- distributed/cli/tests/test_dask_worker.py | 4 +++- docs/source/diagnosing-performance.rst | 2 +- 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index 471f93dd4b5..cce94b356ae 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -152,7 +152,7 @@ def __init__(self, scheduler, **kwargs): # fig.xaxis[0].formatter = NumeralTickFormatter(format='0.0s') fig.x_range.start = 0 - tap = TapTool(callback=OpenURL(url="http://@bokeh_address/main")) + tap = TapTool(callback=OpenURL(url="http://@bokeh_address/")) hover = HoverTool() hover.tooltips = "@worker : @occupancy s." 
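Both hunks above rely on Bokeh's ``@column`` templating: ``OpenURL`` (attached to a ``TapTool``) and ``HoverTool`` substitute values from the clicked or hovered row of the ``ColumnDataSource``. A standalone sketch of that mechanism with made-up column values:

.. code-block:: python

    from bokeh.models import ColumnDataSource, HoverTool, OpenURL, TapTool
    from bokeh.plotting import figure

    source = ColumnDataSource({
        "y": [0, 1],
        "occupancy": [0.5, 1.2],
        "worker": ["w1", "w2"],
        "bokeh_address": ["127.0.0.1:8789", "127.0.0.1:8790"],
    })
    fig = figure(title="Occupancy")
    fig.rect(x="occupancy", y="y", width="occupancy", height=0.9, source=source)

    # Clicking a glyph opens the URL with @bokeh_address filled in from that row
    fig.add_tools(TapTool(callback=OpenURL(url="http://@bokeh_address/status")))
    # Hovering shows the same row's worker name and occupancy
    fig.add_tools(HoverTool(tooltips="@worker : @occupancy s"))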
@@ -368,7 +368,7 @@ def __init__(self, scheduler, width=600, **kwargs): fig.yaxis.visible = False fig.ygrid.visible = False - tap = TapTool(callback=OpenURL(url="http://@bokeh_address/main")) + tap = TapTool(callback=OpenURL(url="http://@bokeh_address/")) fig.add_tools(tap) fig.toolbar.logo = None diff --git a/distributed/bokeh/worker.py b/distributed/bokeh/worker.py index c7ced4d90fc..ed7b68b76b4 100644 --- a/distributed/bokeh/worker.py +++ b/distributed/bokeh/worker.py @@ -51,7 +51,7 @@ BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "theme.yaml")) -template_variables = {"pages": ["main", "system", "profile", "crossfilter"]} +template_variables = {"pages": ["status", "system", "profile", "crossfilter"]} class StateTable(DashboardComponent): @@ -410,9 +410,9 @@ def process_msg(self, msg): def func(k): return msg["keys"].get(k, 0) - main_key = max(msg["keys"], key=func) - typ = self.worker.types.get(main_key, object).__name__ - keyname = key_split(main_key) + status_key = max(msg["keys"], key=func) + typ = self.worker.types.get(status_key, object).__name__ + keyname = key_split(status_key) d = { "nbytes": msg["total"], "duration": msg["duration"], @@ -659,7 +659,7 @@ def update(self): from bokeh.application import Application -def main_doc(worker, extra, doc): +def status_doc(worker, extra, doc): with log_errors(): statetable = StateTable(worker) executing_ts = ExecutingTimeSeries(worker, sizing_mode="scale_width") @@ -685,7 +685,7 @@ def main_doc(worker, extra, doc): ) ) doc.template = env.get_template("simple.html") - doc.template_variables["active_page"] = "main" + doc.template_variables["active_page"] = "status" doc.template_variables.update(extra) doc.theme = BOKEH_THEME @@ -773,7 +773,7 @@ def __init__(self, worker, io_loop=None, prefix="", **kwargs): extra.update(template_variables) - main = Application(FunctionHandler(partial(main_doc, worker, extra))) + status = Application(FunctionHandler(partial(status_doc, worker, extra))) crossfilter = Application( FunctionHandler(partial(crossfilter_doc, worker, extra)) ) @@ -787,7 +787,7 @@ def __init__(self, worker, io_loop=None, prefix="", **kwargs): ) self.apps = { - "/main": main, + "/status": status, "/counters": counters, "/crossfilter": crossfilter, "/system": systemmonitor, diff --git a/distributed/bokeh/worker_html.py b/distributed/bokeh/worker_html.py index 3ddf9490c4d..c818c8fb1e6 100644 --- a/distributed/bokeh/worker_html.py +++ b/distributed/bokeh/worker_html.py @@ -67,7 +67,16 @@ def get(self): self.set_header("Content-Type", "text/plain") -routes = [(r"metrics", PrometheusHandler), (r"health", HealthHandler)] +class OldRoute(RequestHandler): + def get(self): + self.redirect("/status") + + +routes = [ + (r"metrics", PrometheusHandler), + (r"health", HealthHandler), + (r"main", OldRoute), +] def get_handlers(server): diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index c26c99f2350..5ed668e758a 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -260,8 +260,10 @@ def test_bokeh_non_standard_ports(loop): start = time() while True: try: - response = requests.get("http://127.0.0.1:4833/main") + response = requests.get("http://127.0.0.1:4833/status") assert response.ok + redirect_resp = requests.get("http://127.0.0.1:4833/main") + redirect_resp.ok break except Exception: sleep(0.5) diff --git a/docs/source/diagnosing-performance.rst b/docs/source/diagnosing-performance.rst index 28d7d9aba44..773a5d2316b 100644 --- 
a/docs/source/diagnosing-performance.rst +++ b/docs/source/diagnosing-performance.rst @@ -105,7 +105,7 @@ attributes including 4. Keys moved 5. Peer -These are made available to users through the ``/main`` page of the Worker's +These are made available to users through the ``/status`` page of the Worker's diagnostic dashboard. You can capture their state explicitly by running a command on the workers: From fc48c435f8f366c335c92dd6fc58af38065edcec Mon Sep 17 00:00:00 2001 From: Sam Grayson Date: Fri, 17 May 2019 17:53:12 -0500 Subject: [PATCH 0297/1550] Support uploading files with multiple modules (#2587) --- distributed/tests/test_client.py | 96 ++++++++++++++++++++++++-------- distributed/utils.py | 15 ++--- distributed/utils_test.py | 15 +++++ 3 files changed, 94 insertions(+), 32 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index dfe677ddd1e..4cd196fa2f4 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -9,6 +9,7 @@ import os import pickle import random +import subprocess import sys import threading from threading import Semaphore @@ -79,6 +80,7 @@ wait_for, async_wait_for, pristine_loop, + save_sys_modules, ) from distributed.utils_test import ( # noqa: F401 client as c, @@ -1484,7 +1486,7 @@ def g(): return myfile.f() - try: + with save_sys_modules(): for value in [123, 456]: with tmp_text("myfile.py", "def f():\n return {}".format(value)) as fn: yield c.upload_file(fn) @@ -1492,10 +1494,6 @@ def g(): x = c.submit(g, pure=False) result = yield x assert result == value - finally: - # Ensure that this test won't impact the others - if "myfile" in sys.modules: - del sys.modules["myfile"] @gen_cluster(client=True) @@ -1511,28 +1509,80 @@ def g(): return myfile.f() - try: + with save_sys_modules(): + try: + for value in [123, 456]: + with tmp_text( + "myfile.py", "def f():\n return {}".format(value) + ) as fn_my_file: + with zipfile.ZipFile("myfile.zip", "w") as z: + z.write(fn_my_file, arcname=os.path.basename(fn_my_file)) + yield c.upload_file("myfile.zip") + + x = c.submit(g, pure=False) + result = yield x + assert result == value + finally: + if os.path.exists("myfile.zip"): + os.remove("myfile.zip") + + +@gen_cluster(client=True) +def test_upload_file_egg(c, s, a, b): + def g(): + import package_1, package_2 + + return package_1.a, package_2.b + + # c.upload_file tells each worker to + # - put this file in their local_dir + # - modify their sys.path to include it + # we don't care about the local_dir + # but we do care about restoring the path + + with save_sys_modules(): for value in [123, 456]: - with tmp_text( - "myfile.py", "def f():\n return {}".format(value) - ) as fn_my_file: - with zipfile.ZipFile("myfile.zip", "w") as z: - z.write(fn_my_file, arcname=os.path.basename(fn_my_file)) - yield c.upload_file("myfile.zip") + with tmpfile() as dirname: + os.mkdir(dirname) + + with open(os.path.join(dirname, "setup.py"), "w") as f: + f.write("from setuptools import setup, find_packages\n") + f.write( + 'setup(name="my_package", packages=find_packages(), version="{}")\n'.format( + value + ) + ) + + # test a package with an underscore in the name + package_1 = os.path.join(dirname, "package_1") + os.mkdir(package_1) + with open(os.path.join(package_1, "__init__.py"), "w") as f: + f.write("a = {}\n".format(value)) + + # test multiple top-level packages + package_2 = os.path.join(dirname, "package_2") + os.mkdir(package_2) + with open(os.path.join(package_2, "__init__.py"), "w") as f: + f.write("b = 
{}\n".format(value)) + + # compile these into an egg + subprocess.check_call( + [sys.executable, "setup.py", "bdist_egg"], cwd=dirname + ) + + egg_root = os.path.join(dirname, "dist") + # first file ending with '.egg' + egg_name = [ + fname for fname in os.listdir(egg_root) if fname.endswith(".egg") + ][0] + egg_path = os.path.join(egg_root, egg_name) + + yield c.upload_file(egg_path) + os.remove(egg_path) x = c.submit(g, pure=False) result = yield x - assert result == value - finally: - # Ensure that this test won't impact the others - if os.path.exists("myfile.zip"): - os.remove("myfile.zip") - if "myfile" in sys.modules: - del sys.modules["myfile"] - for path in sys.path: - if os.path.basename(path) == "myfile.zip": - sys.path.remove(path) - break + assert result == (value, value) @gen_cluster(client=True) diff --git a/distributed/utils.py b/distributed/utils.py index 6debdedd24e..466a96fdfd5 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -23,7 +23,7 @@ import threading import warnings import weakref - +import pkgutil import six import tblib.pickling_support @@ -1066,14 +1066,11 @@ def import_file(path): if ext in (".egg", ".zip", ".pyz"): if path not in sys.path: sys.path.insert(0, path) - if ext == ".egg": - import pkg_resources - - pkgs = pkg_resources.find_distributions(path) - for pkg in pkgs: - names_to_import.append(pkg.project_name) - elif ext in (".zip", ".pyz"): - names_to_import.append(name) + if sys.version_info >= (3, 6): + names = (mod_info.name for mod_info in pkgutil.iter_modules([path])) + else: + names = (mod_info[1] for mod_info in pkgutil.iter_modules([path])) + names_to_import.extend(names) loaded = [] if not names_to_import: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 7b90745ac6d..c44f4177472 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1417,6 +1417,21 @@ def gen_tls_cluster(**kwargs): ) +@contextmanager +def save_sys_modules(): + old_modules = sys.modules + old_path = sys.path + try: + yield + finally: + for i, elem in enumerate(sys.path): + if elem not in old_path: + del sys.path[i] + for elem in sys.modules.keys(): + if elem not in old_modules: + del sys.modules[elem] + + @contextmanager def check_thread_leak(): active_threads_start = set(threading._active) From 138842c9769c273b6edc6c086c80009102596198 Mon Sep 17 00:00:00 2001 From: Magnus Nord Date: Mon, 20 May 2019 17:54:37 +0200 Subject: [PATCH 0298/1550] Fix two typos in Pub class docstring (#2714) --- distributed/pubsub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 5e086492923..f9cf1f6f7c3 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -234,8 +234,8 @@ class Pub(object): disappear without notice. When using a Pub or Sub from a Client all communications will be routed - through the scheduler. This can cause some performance degredation. Pubs - an Subs only operate at top-speed when they are both on workers. + through the scheduler. This can cause some performance degradation. Pubs + and Subs only operate at top-speed when they are both on workers. 
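The docstring corrected above also points at the intended usage: a ``Pub``/``Sub`` pair is fastest when both ends run inside worker tasks, so messages move worker-to-worker instead of through the scheduler. A rough sketch of that pattern; the topic name, sleep, and message count are illustrative only:

.. code-block:: python

    from time import sleep
    from dask.distributed import Client, Pub, Sub

    def producer(n):
        pub = Pub("numbers")          # publishes directly from a worker
        sleep(1)                      # crude: give the subscriber time to register
        for i in range(n):
            pub.put(i)
        return n

    def consumer(n):
        sub = Sub("numbers")          # receives directly on a worker
        return [sub.get() for _ in range(n)]

    client = Client()                 # local cluster, for illustration
    received = client.submit(consumer, 5)
    sent = client.submit(producer, 5)
    print(received.result())          # expected: [0, 1, 2, 3, 4]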
Parameters ---------- From 1a96f70d9c2d23d85a8550f5bfdf60c26bb4ed4f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 20 May 2019 12:12:52 -0500 Subject: [PATCH 0299/1550] Remove special casing of Scikit-Learn BaseEstimator serialization (#2713) Fixes https://github.com/dask/dask/issues/4769 --- distributed/protocol/__init__.py | 8 -------- distributed/protocol/tests/test_sklearn.py | 4 ++++ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index cf1a3df8994..04691ce605d 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -60,14 +60,6 @@ def _register_arrow(): from . import arrow -@dask_serialize.register_lazy("sklearn") -@dask_deserialize.register_lazy("sklearn") -def _register_sklearn(): - import sklearn.base - - register_generic(sklearn.base.BaseEstimator) - - @dask_serialize.register_lazy("torch") @dask_deserialize.register_lazy("torch") @dask_serialize.register_lazy("torchvision") diff --git a/distributed/protocol/tests/test_sklearn.py b/distributed/protocol/tests/test_sklearn.py index 051a0440f3a..2a3835168ee 100644 --- a/distributed/protocol/tests/test_sklearn.py +++ b/distributed/protocol/tests/test_sklearn.py @@ -7,6 +7,10 @@ from distributed.protocol import serialize, deserialize +@pytest.mark.xfail( + reason="We no longer special-case the BaseEstimator " + "super class. It's hard to guarantee support for all subclasseses" +) def test_basic(): est = sklearn.linear_model.LinearRegression() est.fit([[0, 0], [1, 1], [2, 2]], [0, 1, 2]) From f47ed2e610590c644da052af64a34bffa1552a92 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 21 May 2019 17:45:13 -0500 Subject: [PATCH 0300/1550] Refer to LocalCluster in Client docstring (#2719) --- distributed/client.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 4aab8102d0c..6b5bc2811bb 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -504,12 +504,19 @@ class AllExit(Exception): class Client(Node): - """ Connect to and drive computation on a distributed Dask cluster - - The Client connects users to a dask.distributed compute cluster. It - provides an asynchronous user interface around functions and futures. This - class resembles executors in ``concurrent.futures`` but also allows - ``Future`` objects within ``submit/map`` calls. + """ Connect to and submit computation to a Dask cluster + + The Client connects users to a Dask cluster. It provides an asynchronous + user interface around functions and futures. This class resembles + executors in ``concurrent.futures`` but also allows ``Future`` objects + within ``submit/map`` calls. When a Client is instantiated it takes over + all ``dask.compute`` and ``dask.persist`` calls by default. + + It is also common to create a Client without specifying the scheduler + address , like ``Client()``. In this case the Client creates a + ``LocalCluster`` in the background and connects to that. Any extra + keywords are passed from Client to LocalCluster in this case. See the + LocalCluster documentation for more information. 
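The paragraph above is easiest to see as the three common ways of constructing a ``Client``; the addresses below are placeholders:

.. code-block:: python

    from dask.distributed import Client, LocalCluster

    # 1. No address: Client starts a LocalCluster in the background and
    #    forwards extra keywords (n_workers, threads_per_worker, ...) to it.
    client = Client(n_workers=2, threads_per_worker=2)

    # 2. Address of an already-running scheduler.
    # client = Client("tcp://192.0.2.10:8786")

    # 3. Any object with a scheduler_address attribute, e.g. a LocalCluster.
    # cluster = LocalCluster(n_workers=2)
    # client = Client(cluster)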
Parameters ---------- From 62f604e7e567a1cc7806226adc5d7f288dc2fbad Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 22 May 2019 15:05:39 -0500 Subject: [PATCH 0301/1550] Add docstring to Scheduler.check_idle_saturated (#2721) --- distributed/scheduler.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1d6a41a9acc..801c0c849d7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4357,6 +4357,19 @@ def reschedule(self, key=None, worker=None): ############################## def check_idle_saturated(self, ws, occ=None): + """ Update the status of the idle and saturated state + + The scheduler keeps track of workers that are .. + + - Saturated: have enough work to stay busy + - Idle: do not have enough work to stay busy + + They are considered saturated if they both have enough tasks to occupy + all of their cores, and if the expected runtime of those tasks is large + enough. + + This is useful for load balancing and adaptivity. + """ if self.total_ncores == 0 or ws.status == "closed": return if occ is None: From 28ce1eda0f6ab4940ce4daa1f309b29a496e6834 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Wed, 22 May 2019 16:06:50 -0400 Subject: [PATCH 0302/1550] Proxy worker dashboards from scheduler dashboard (#2715) --- dev-requirements.txt | 1 + distributed/bokeh/proxy.py | 130 ++++++++++++++++++ distributed/bokeh/scheduler.py | 35 +++-- distributed/bokeh/scheduler_html.py | 5 +- distributed/bokeh/templates/task.html | 2 +- distributed/bokeh/templates/worker-table.html | 4 +- distributed/bokeh/templates/workers.html | 1 - .../bokeh/tests/test_scheduler_bokeh.py | 44 +++++- .../bokeh/tests/test_scheduler_bokeh_html.py | 7 +- distributed/cli/tests/test_dask_worker.py | 14 +- distributed/scheduler.py | 1 + 11 files changed, 222 insertions(+), 22 deletions(-) create mode 100644 distributed/bokeh/proxy.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 7d684343ca7..8cc8f7d256d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -10,3 +10,4 @@ jupyter_client >= 4.4.0 ipykernel >= 4.5.2 pytest >= 3.0.5 prometheus_client >= 0.6.0 +jupyter-server-proxy >= 1.1.0 diff --git a/distributed/bokeh/proxy.py b/distributed/bokeh/proxy.py new file mode 100644 index 00000000000..9353e383112 --- /dev/null +++ b/distributed/bokeh/proxy.py @@ -0,0 +1,130 @@ +import logging + +from tornado import web + +logger = logging.getLogger(__name__) + +try: + from jupyter_server_proxy.handlers import ProxyHandler + + class GlobalProxyHandler(ProxyHandler): + """ + A tornado request handler that proxies HTTP and websockets + from a port to any valid endpoint'. 
+ """ + + def initialize(self, server=None, extra=None): + self.scheduler = server + self.extra = extra or {} + + async def http_get(self, port, host, proxied_path): + # route here first + # incoming URI /proxy/{port}/{host}/{proxied_path} + + self.host = host + + # rewrite uri for jupyter-server-proxy handling + uri = "/proxy/%s/%s" % (str(port), proxied_path) + self.request.uri = uri + + # slash is removed during regex in handler + proxied_path = "/%s" % proxied_path + + worker = "%s:%s" % (self.host, str(port)) + if not check_worker_dashboard_exits(self.scheduler, worker): + msg = "Worker <%s> does not exist" % worker + self.set_status(400) + self.finish(msg) + return + return await self.proxy(port, proxied_path) + + async def open(self, port, host, proxied_path): + # finally, proxy to other address/port + return await self.proxy_open(host, port, proxied_path) + + def post(self, port, proxied_path): + return self.proxy(port, proxied_path) + + def put(self, port, proxied_path): + return self.proxy(port, proxied_path) + + def delete(self, port, proxied_path): + return self.proxy(port, proxied_path) + + def head(self, port, proxied_path): + return self.proxy(port, proxied_path) + + def patch(self, port, proxied_path): + return self.proxy(port, proxied_path) + + def options(self, port, proxied_path): + return self.proxy(port, proxied_path) + + def proxy(self, port, proxied_path): + # router here second + # returns ProxyHandler coroutine + return super().proxy(self.host, port, proxied_path) + + +except ImportError: + logger.info( + "To route to workers diagnostics web server " + "please install jupyter-server-proxy: " + "pip install jupyter-server-proxy" + ) + + class GlobalProxyHandler(web.RequestHandler): + """Minimal Proxy handler when jupyter-server-proxy is not installed + """ + + def initialize(self, server=None, extra=None): + self.server = server + self.extra = extra or {} + + def get(self, port, host, proxied_path): + worker_url = "%s:%s/%s" % (host, str(port), proxied_path) + msg = """ +

            <p> Try navigating to <a href=http://%s>%s</a> for your worker dashboard </p>
+
+            <p>
+            Dask tried to proxy you to that page through your
+            Scheduler's dashboard connection, but you don't have
+            jupyter-server-proxy installed.  You may want to install it
+            with either conda or pip, and then restart your scheduler.
+            </p>
+
+            <p><pre> conda install jupyter-server-proxy -c conda-forge </pre></p>
+            <p><pre> pip install jupyter-server-proxy</pre></p>
+
+            <p>
+            The link above should work though if your workers are on a
+            sufficiently open network.  This is common on single machines,
+            but less common in production clusters.  Your IT administrators
+            will know more
+            </p>
        + """ % ( + worker_url, + worker_url, + ) + self.write(msg) + + +def check_worker_dashboard_exits(scheduler, worker): + """Check addr:port exists as a worker in scheduler list + + Parameters + ---------- + worker : str + addr:port + + Returns + ------- + bool + """ + addr, port = worker.split(":") + workers = list(scheduler.workers.values()) + for w in workers: + bokeh_port = w.services.get("bokeh", "") + if addr == w.host and port == str(bokeh_port): + return True + return False diff --git a/distributed/bokeh/scheduler.py b/distributed/bokeh/scheduler.py index cce94b356ae..e0f5bfffab6 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/bokeh/scheduler.py @@ -130,7 +130,8 @@ def __init__(self, scheduler, **kwargs): "y": [1, 2], "ms": [1, 2], "color": ["red", "blue"], - "bokeh_address": ["", ""], + "dashboard_port": ["", ""], + "dashboard_host": ["", ""], } ) @@ -152,7 +153,9 @@ def __init__(self, scheduler, **kwargs): # fig.xaxis[0].formatter = NumeralTickFormatter(format='0.0s') fig.x_range.start = 0 - tap = TapTool(callback=OpenURL(url="http://@bokeh_address/")) + tap = TapTool( + callback=OpenURL(url="./proxy/@dashboard_port/@dashboard_host/status") + ) hover = HoverTool() hover.tooltips = "@worker : @occupancy s." @@ -166,10 +169,8 @@ def update(self): with log_errors(): workers = list(self.scheduler.workers.values()) - bokeh_addresses = [] - for ws in workers: - addr = self.scheduler.get_worker_service_addr(ws.address, "bokeh") - bokeh_addresses.append("%s:%d" % addr if addr is not None else "") + dashboard_host = [ws.host for ws in workers] + dashboard_port = [ws.services.get("bokeh", "") for ws in workers] y = list(range(len(workers))) occupancy = [ws.occupancy for ws in workers] @@ -199,7 +200,8 @@ def update(self): "worker": [ws.address for ws in workers], "ms": ms, "color": color, - "bokeh_address": bokeh_addresses, + "dashboard_host": dashboard_host, + "dashboard_port": dashboard_port, "x": x, "y": y, } @@ -317,7 +319,8 @@ def __init__(self, scheduler, width=600, **kwargs): "worker": ["a", "b"], "y": [1, 2], "nbytes-color": ["blue", "blue"], - "bokeh_address": ["", ""], + "dashboard_port": ["", ""], + "dashboard_host": ["", ""], } ) @@ -368,7 +371,11 @@ def __init__(self, scheduler, width=600, **kwargs): fig.yaxis.visible = False fig.ygrid.visible = False - tap = TapTool(callback=OpenURL(url="http://@bokeh_address/")) + tap = TapTool( + callback=OpenURL( + url="./proxy/@dashboard_port/@dashboard_host/status" + ) + ) fig.add_tools(tap) fig.toolbar.logo = None @@ -395,10 +402,8 @@ def update(self): with log_errors(): workers = list(self.scheduler.workers.values()) - bokeh_addresses = [] - for ws in workers: - addr = self.scheduler.get_worker_service_addr(ws.address, "bokeh") - bokeh_addresses.append("%s:%d" % addr if addr is not None else "") + dashboard_host = [ws.host for ws in workers] + dashboard_port = [ws.services.get("bokeh", "") for ws in workers] y = list(range(len(workers))) nprocessing = [len(ws.processing) for ws in workers] @@ -442,7 +447,8 @@ def update(self): "nbytes-half": [nb / 2 for nb in nbytes], "nbytes-color": nbytes_color, "nbytes_text": nbytes_text, - "bokeh_address": bokeh_addresses, + "dashboard_host": dashboard_host, + "dashboard_port": dashboard_port, "worker": [ws.address for ws in workers], "y": y, } @@ -1579,6 +1585,7 @@ def __init__(self, scheduler, io_loop=None, prefix="", **kwargs): self.prefix = prefix self.server_kwargs = kwargs + self.server_kwargs["prefix"] = prefix or None self.apps = { diff --git 
a/distributed/bokeh/scheduler_html.py b/distributed/bokeh/scheduler_html.py index d1ba2646ed6..1d3635c37c5 100644 --- a/distributed/bokeh/scheduler_html.py +++ b/distributed/bokeh/scheduler_html.py @@ -7,6 +7,7 @@ from tornado import web from ..utils import log_errors, format_bytes, format_time +from .proxy import GlobalProxyHandler dirname = os.path.dirname(__file__) @@ -42,6 +43,7 @@ def get(self, worker): self.render( "worker.html", title="Worker: " + worker, + scheduler=self.server, Worker=worker, **toolz.merge(self.server.__dict__, ns, self.extra) ) @@ -55,7 +57,7 @@ def get(self, task): "task.html", title="Task: " + task, Task=task, - server=self.server, + scheduler=self.server, **toolz.merge(self.server.__dict__, ns, self.extra) ) @@ -249,6 +251,7 @@ def get(self): (r"individual-plots.json", IndividualPlots), (r"metrics", PrometheusHandler), (r"health", HealthHandler), + (r"proxy/(\d+)/(.*?)/(.*)", GlobalProxyHandler), ] diff --git a/distributed/bokeh/templates/task.html b/distributed/bokeh/templates/task.html index f396a4cba8f..8c292da4e43 100644 --- a/distributed/bokeh/templates/task.html +++ b/distributed/bokeh/templates/task.html @@ -122,7 +122,7 @@

        Transition Log

        Recommended Action - {% for key, start, finish, recommendations, time in server.story(Task) %} + {% for key, start, finish, recommendations, time in scheduler.story(Task) %} {{ fromtimestamp(time) }} {{key}} diff --git a/distributed/bokeh/templates/worker-table.html b/distributed/bokeh/templates/worker-table.html index 90b59c08c54..8a86f8debd1 100644 --- a/distributed/bokeh/templates/worker-table.html +++ b/distributed/bokeh/templates/worker-table.html @@ -1,4 +1,4 @@ - +
        @@ -20,7 +20,7 @@ {% if 'bokeh' in ws.services %} - + {% else %} {% end %} diff --git a/distributed/bokeh/templates/workers.html b/distributed/bokeh/templates/workers.html index 6a2b7fc9345..f300855ac98 100644 --- a/distributed/bokeh/templates/workers.html +++ b/distributed/bokeh/templates/workers.html @@ -5,7 +5,6 @@

        Scheduler {{scheduler.address}}

        LogsBokeh - {% set worker_list = list(workers.values()) %} {% include "worker-table.html" %} diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/bokeh/tests/test_scheduler_bokeh.py index f3a57586c72..057aa679655 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh.py @@ -89,7 +89,7 @@ def test_basic(c, s, a, b): data = ss.source.data assert len(first(data.values())) if component is Occupancy: - assert all(addr.startswith("127.0.0.1:") for addr in data["bokeh_address"]) + assert all(addr == "127.0.0.1" for addr in data["dashboard_host"]) @gen_cluster(client=True) @@ -581,3 +581,45 @@ def test_root_redirect(c, s, a, b): ) assert response.code == 200 assert "/status" in response.effective_url + + +@gen_cluster( + client=True, + scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, + worker_kwargs={"services": {"bokeh": BokehWorker}}, + timeout=180, +) +def test_proxy_to_workers(c, s, a, b): + try: + import jupyter_server_proxy # noqa: F401 + + proxy_exists = True + except ImportError: + proxy_exists = False + + dashboard_port = s.services["bokeh"].port + http_client = AsyncHTTPClient() + response = yield http_client.fetch("http://localhost:%d/" % dashboard_port) + assert response.code == 200 + assert "/status" in response.effective_url + + for w in [a, b]: + host = w.ip + port = w.service_ports["bokeh"] + proxy_url = "http://localhost:%d/proxy/%s/%s/status" % ( + dashboard_port, + port, + host, + ) + direct_url = "http://localhost:%s/status" % port + http_client = AsyncHTTPClient() + response_proxy = yield http_client.fetch(proxy_url) + response_direct = yield http_client.fetch(direct_url) + + assert response_proxy.code == 200 + if proxy_exists: + assert b"Crossfilter" in response_proxy.body + else: + assert b"pip install jupyter-server-proxy" in response_proxy.body + assert response_direct.code == 200 + assert b"Crossfilter" in response_direct.body diff --git a/distributed/bokeh/tests/test_scheduler_bokeh_html.py b/distributed/bokeh/tests/test_scheduler_bokeh_html.py index 96fe3c2f5d2..691121f7514 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh_html.py +++ b/distributed/bokeh/tests/test_scheduler_bokeh_html.py @@ -14,9 +14,14 @@ from dask.sizeof import sizeof from distributed.utils_test import gen_cluster, slowinc, inc from distributed.bokeh.scheduler import BokehScheduler +from distributed.bokeh.worker import BokehWorker -@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) +@gen_cluster( + client=True, + scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, + worker_kwargs={"services": {"bokeh": BokehWorker}}, +) def test_connect(c, s, a, b): future = c.submit(lambda x: x + 1, 1) x = c.submit(slowinc, 1, delay=1, retries=5) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 5ed668e758a..aac27061b21 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -249,8 +249,14 @@ def func(dask_worker): def test_bokeh_non_standard_ports(loop): pytest.importorskip("bokeh") + try: + import jupyter_server_proxy # noqa: F401 - with popen(["dask-scheduler", "--port", "3449", "--no-bokeh"]): + proxy_exists = True + except ImportError: + proxy_exists = False + + with popen(["dask-scheduler", "--port", "3449"]): with popen( ["dask-worker", "tcp://127.0.0.1:3449", "--dashboard-address", ":4833"] ) as proc: @@ -264,9 +270,15 @@ def 
test_bokeh_non_standard_ports(loop): assert response.ok redirect_resp = requests.get("http://127.0.0.1:4833/main") redirect_resp.ok + # TEST PROXYING WORKS + if proxy_exists: + url = "http://127.0.0.1:8787/proxy/4833/127.0.0.1/status" + response = requests.get(url) + assert response.ok break except Exception: sleep(0.5) assert time() < start + 20 + with pytest.raises(Exception): requests.get("http://localhost:4833/status/") diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 801c0c849d7..cce66fb1767 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -856,6 +856,7 @@ def __init__( ) self.digests = None self.service_specs = services or {} + self.service_kwargs = service_kwargs or {} self.services = {} self.scheduler_file = scheduler_file worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") From 6134c754b08b35fd3e98d6128b9cdb2f28bb5300 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 22 May 2019 15:08:48 -0500 Subject: [PATCH 0303/1550] Replace register_worker_callbacks with worker plugins (#2453) * Add worker plugins * add docstring * Replace legacy worker_callbacks with worker_plugins * add and test name keyword * fix missing import * black * respond to feedback * Handle errors again * Expand docstring --- distributed/client.py | 92 ++++++++++++++++++++---- distributed/scheduler.py | 17 +++-- distributed/tests/test_worker.py | 7 +- distributed/tests/test_worker_plugins.py | 68 ++++++++++++++++++ distributed/worker.py | 57 ++++++++++++--- 5 files changed, 206 insertions(+), 35 deletions(-) create mode 100644 distributed/tests/test_worker_plugins.py diff --git a/distributed/client.py b/distributed/client.py index 6b5bc2811bb..6b22f5b8ee5 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -89,6 +89,7 @@ parse_timedelta, shutting_down, Any, + has_keyword, ) from .versions import get_versions @@ -3854,17 +3855,6 @@ def _get_task_stream( else: raise gen.Return(msgs) - @gen.coroutine - def _register_worker_callbacks(self, setup=None): - responses = yield self.scheduler.register_worker_callbacks(setup=dumps(setup)) - results = {} - for key, resp in responses.items(): - if resp["status"] == "OK": - results[key] = resp["result"] - elif resp["status"] == "error": - six.reraise(*clean_exception(**resp)) - raise gen.Return(results) - def register_worker_callbacks(self, setup=None): """ Registers a setup callback function for all current and future workers. @@ -3883,7 +3873,85 @@ def register_worker_callbacks(self, setup=None): setup : callable(dask_worker: Worker) -> None Function to register and run on all workers """ - return self.sync(self._register_worker_callbacks, setup=setup) + return self.register_worker_plugin(_WorkerSetupPlugin(setup)) + + @gen.coroutine + def _register_worker_plugin(self, plugin=None, name=None): + responses = yield self.scheduler.register_worker_plugin( + plugin=dumps(plugin), name=name + ) + for response in responses.values(): + if response["status"] == "error": + exc = response["exception"] + typ = type(exc) + tb = response["traceback"] + six.reraise(typ, exc, tb) + raise gen.Return(responses) + + def register_worker_plugin(self, plugin=None, name=None): + """ + Registers a lifecycle worker plugin for all current and future workers. + + This registers a new object to handle setup and teardown for workers in + this cluster. The plugin will instantiate itself on all currently + connected workers. It will also be run on any worker that connects in + the future. 
+ + The plugin should be an object with ``setup`` and ``teardown`` methods. + It must be serializable with the pickle or cloudpickle modules. + + If the plugin has a ``name`` attribute, or if the ``name=`` keyword is + used then that will control idempotency. A a plugin with that name has + already registered then any future plugins will not run. + + For alternatives to plugins, you may also wish to look into preload + scripts. + + Parameters + ---------- + plugin: object + The plugin object to pass to the workers + name: str, optional + A name for the plugin. + Registering a plugin with the same name will have no effect. + + Examples + -------- + >>> class MyPlugin: + ... def __init__(self, *args, **kwargs): + ... pass # the constructor is up to you + ... def setup(self, worker: dask.distributed.Worker): + ... pass + ... def teardown(self, worker: dask.distributed.Worker): + ... pass + + >>> plugin = MyPlugin(1, 2, 3) + >>> client.register_worker_plugin(plugin) + + You can get access to the plugin with the ``get_worker`` function + + >>> client.register_worker_plugin(other_plugin, name='my-plugin') + >>> def f(): + ... worker = get_worker() + ... plugin = worker.plugins['my-plugin'] + ... return plugin.my_state + + >>> future = client.run(f) + """ + return self.sync(self._register_worker_plugin, plugin=plugin, name=name) + + +class _WorkerSetupPlugin(object): + """ This is used to support older setup functions as callbacks """ + + def __init__(self, setup): + self._setup = setup + + def setup(self, worker): + if has_keyword(self._setup, "dask_worker"): + return self._setup(dask_worker=worker) + else: + return self._setup() class Executor(Client): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cce66fb1767..57d768f95f6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1003,7 +1003,7 @@ def __init__( self.log = deque( maxlen=dask.config.get("distributed.scheduler.transition-log-length") ) - self.worker_setups = [] + self.worker_plugins = [] worker_handlers = { "task-finished": self.handle_task_finished, @@ -1062,7 +1062,7 @@ def __init__( "heartbeat_worker": self.heartbeat_worker, "get_task_status": self.get_task_status, "get_task_stream": self.get_task_stream, - "register_worker_callbacks": self.register_worker_callbacks, + "register_worker_plugin": self.register_worker_plugin, } self._transitions = { @@ -1510,7 +1510,7 @@ def add_worker( "status": "OK", "time": time(), "heartbeat-interval": heartbeat_interval(len(self.workers)), - "worker-setups": self.worker_setups, + "worker-plugins": self.worker_plugins, } ) yield self.handle_worker(comm=comm, worker=address) @@ -3407,14 +3407,13 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None): return ts.collect(start=start, stop=stop, count=count) @gen.coroutine - def register_worker_callbacks(self, comm, setup=None): + def register_worker_plugin(self, comm, plugin, name=None): """ Registers a setup function, and call it on every worker """ - if setup is None: - raise gen.Return({}) - - self.worker_setups.append(setup) + self.worker_plugins.append(plugin) - responses = yield self.broadcast(msg=dict(op="run", function=setup)) + responses = yield self.broadcast( + msg=dict(op="plugin-add", plugin=plugin, name=name) + ) raise gen.Return(responses) ##################### diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 4541e183e46..9fc967eef5a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1312,7 +1312,6 
@@ def test_startup2(): return os.getenv("MY_ENV_VALUE", None) == "WORKER_ENV_VALUE" # Nothing has been run yet - assert len(s.worker_setups) == 0 result = yield c.run(test_import) assert list(result.values()) == [False] * 2 result = yield c.run(test_startup2) @@ -1327,7 +1326,6 @@ def test_startup2(): # Add a preload function response = yield c.register_worker_callbacks(setup=mystartup) assert len(response) == 2 - assert len(s.worker_setups) == 1 # Check it has been ran on existing worker result = yield c.run(test_import) @@ -1342,7 +1340,6 @@ def test_startup2(): # Register another preload function response = yield c.register_worker_callbacks(setup=mystartup2) assert len(response) == 2 - assert len(s.worker_setups) == 2 # Check it has been run result = yield c.run(test_startup2) @@ -1356,7 +1353,9 @@ def test_startup2(): assert list(result.values()) == [True] yield worker.close() - # Final exception test + +@gen_cluster(client=True) +def test_register_worker_callbacks_err(c, s, a, b): with pytest.raises(ZeroDivisionError): yield c.register_worker_callbacks(setup=lambda: 1 / 0) diff --git a/distributed/tests/test_worker_plugins.py b/distributed/tests/test_worker_plugins.py new file mode 100644 index 00000000000..25388459788 --- /dev/null +++ b/distributed/tests/test_worker_plugins.py @@ -0,0 +1,68 @@ +from distributed.utils_test import gen_cluster +from distributed import Worker + + +class MyPlugin: + name = "MyPlugin" + + def __init__(self, data): + self.data = data + + def setup(self, worker): + assert isinstance(worker, Worker) + self.worker = worker + self.worker._my_plugin_status = "setup" + self.worker._my_plugin_data = self.data + + def teardown(self, worker): + assert isinstance(worker, Worker) + self.worker._my_plugin_status = "teardown" + + +@gen_cluster(client=True, ncores=[]) +def test_create_with_client(c, s): + yield c.register_worker_plugin(MyPlugin(123)) + + worker = Worker(s.address, loop=s.loop) + yield worker._start() + assert worker._my_plugin_status == "setup" + assert worker._my_plugin_data == 123 + + yield worker._close() + assert worker._my_plugin_status == "teardown" + + +@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) +def test_create_on_construction(c, s, a, b): + assert len(a.plugins) == len(b.plugins) == 1 + assert a._my_plugin_status == "setup" + assert a._my_plugin_data == 5 + + +@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) +def test_idempotence_with_name(c, s, a, b): + a._my_plugin_data = 100 + + yield c.register_worker_plugin(MyPlugin(5)) + + assert a._my_plugin_data == 100 # call above has no effect + + +@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) +def test_duplicate_with_no_name(c, s, a, b): + assert len(a.plugins) == len(b.plugins) == 1 + + plugin = MyPlugin(10) + plugin.name = "other-name" + + yield c.register_worker_plugin(plugin) + + assert len(a.plugins) == len(b.plugins) == 2 + + assert a._my_plugin_data == 10 + + yield c.register_worker_plugin(plugin) + assert len(a.plugins) == len(b.plugins) == 2 + + yield c.register_worker_plugin(plugin, name="foo") + assert len(a.plugins) == len(b.plugins) == 3 diff --git a/distributed/worker.py b/distributed/worker.py index f4189393091..3fcc477bf48 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -10,6 +10,7 @@ import random import threading import sys +import uuid import warnings import weakref import psutil @@ -307,6 +308,7 @@ def __init__( protocol=None, dashboard_address=None, nanny=None, + plugins=(), 
low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), **kwargs ): @@ -576,6 +578,7 @@ def __init__( "versions": self.versions, "actor_execute": self.actor_execute, "actor_attribute": self.actor_attribute, + "plugin-add": self.plugin_add, } stream_handlers = { @@ -638,6 +641,9 @@ def __init__( ) self.periodic_callbacks["profile-cycle"] = pc + self.plugins = {} + self._pending_plugins = plugins + Worker._instances.add(self) ################## @@ -763,16 +769,9 @@ def _register_with_scheduler(self): if response["status"] != "OK": raise ValueError("Unexpected response from register: %r" % (response,)) else: - # Retrieve eventual init functions and run them - for function_bytes in response["worker-setups"]: - setup_function = pickle.loads(function_bytes) - if has_arg(setup_function, "dask_worker"): - result = setup_function(dask_worker=self) - else: - result = setup_function() - logger.info( - "Init function %s ran: output=%s" % (setup_function, result) - ) + yield [ + self.plugin_add(plugin=plugin) for plugin in response["worker-plugins"] + ] logger.info(" Registered to: %26s", self.scheduler.address) logger.info("-" * 49) @@ -968,6 +967,9 @@ def _start(self, addr_or_port=0): setproctitle("dask-worker [%s]" % self.address) + yield [self.plugin_add(plugin=plugin) for plugin in self._pending_plugins] + self._pending_plugins = () + yield self._register_with_scheduler() self.start_periodic_callbacks() @@ -998,6 +1000,12 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): self.status = "closing" setproctitle("dask-worker [closing]") + yield [ + plugin.teardown(self) + for plugin in self.plugins.values() + if hasattr(plugin, "teardown") + ] + self.stop() for pc in self.periodic_callbacks.values(): pc.stop() @@ -2206,6 +2214,35 @@ def run(self, comm, function, args=(), wait=True, kwargs=None): def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) + @gen.coroutine + def plugin_add(self, comm=None, plugin=None, name=None): + with log_errors(pdb=False): + if isinstance(plugin, bytes): + plugin = pickle.loads(plugin) + if not name: + if hasattr(plugin, "name"): + name = plugin.name + else: + name = funcname(plugin) + "-" + str(uuid.uuid4()) + + assert name + + if name in self.plugins: + return {"status": "repeat"} + else: + self.plugins[name] = plugin + + logger.info("Starting Worker plugin %s" % name) + try: + result = plugin.setup(worker=self) + if isinstance(result, gen.Future): + result = yield result + except Exception as e: + msg = error_message(e) + return msg + else: + return {"status": "OK"} + @gen.coroutine def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={}): separate_thread = kwargs.pop("separate_thread", True) From 6e0c0a6b90b1d3c3f686f0c968e9cf3d0c354413 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 22 May 2019 16:45:53 -0500 Subject: [PATCH 0304/1550] Add SpecificationCluster (#2675) This is intended to be a base for LocalCluster (and others) that want to specify more heterogeneous information about workers. This forces the use of Python 3 and introduces more asyncio and async def handling. This cleans up a number of intermittent testing failures and improves our testing harness hygeine. 
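To make the intended usage concrete, here is a minimal sketch of the specification-based interface this patch introduces, assuming only the {"cls": ..., "options": {...}} spec format shown in the docstring and tests below; the worker names and option values are illustrative and not part of the patch itself.

    from dask.distributed import Client, Scheduler, Worker, Nanny, SpecCluster

    # Each spec entry pairs a class with the keyword options it will be called with.
    scheduler = {"cls": Scheduler, "options": {"port": 0}}
    workers = {
        "threaded-worker": {"cls": Worker, "options": {"ncores": 1}},
        "nanny-worker": {"cls": Nanny, "options": {"ncores": 2}},
    }

    # Synchronous use; the cluster instantiates whatever the spec describes.
    with SpecCluster(scheduler=scheduler, workers=workers) as cluster:
        with Client(cluster) as client:
            assert client.submit(lambda x: x + 1, 10).result() == 11

Because the spec is just a mapping of names to classes and options, subclasses can mix worker types or substitute resource-manager classes (pods, batch jobs, and so on) without changing the cluster logic itself.
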
--- .../setup_conda_environment.cmd | 2 +- continuous_integration/travis/install.sh | 2 +- distributed/__init__.py | 2 +- distributed/cli/dask_worker.py | 2 +- distributed/client.py | 24 +- distributed/comm/tcp.py | 1 + distributed/core.py | 8 +- distributed/deploy/__init__.py | 1 + distributed/deploy/adaptive.py | 4 +- distributed/deploy/cluster.py | 32 +- distributed/deploy/local.py | 326 ++---------------- distributed/deploy/spec.py | 297 ++++++++++++++++ distributed/deploy/tests/py3_test_deploy.py | 15 +- distributed/deploy/tests/test_adaptive.py | 56 ++- distributed/deploy/tests/test_local.py | 120 ++++--- distributed/deploy/tests/test_spec_cluster.py | 115 ++++++ distributed/deploy/utils_test.py | 17 +- distributed/nanny.py | 34 +- distributed/scheduler.py | 12 +- distributed/tests/test_as_completed.py | 4 +- distributed/tests/test_asyncprocess.py | 1 + distributed/tests/test_client.py | 39 ++- distributed/tests/test_nanny.py | 9 +- distributed/tests/test_scheduler.py | 16 +- distributed/tests/test_worker.py | 31 +- distributed/tests/test_worker_client.py | 4 +- distributed/utils.py | 2 + distributed/utils_test.py | 84 ++--- distributed/worker.py | 15 +- 29 files changed, 774 insertions(+), 501 deletions(-) create mode 100644 distributed/deploy/spec.py create mode 100644 distributed/deploy/tests/test_spec_cluster.py diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index cd201ff46d5..5748a8cf20c 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -50,7 +50,7 @@ call activate %CONDA_ENV% %PIP_INSTALL% git+https://github.com/joblib/joblib.git --upgrade %PIP_INSTALL% git+https://github.com/dask/zict --upgrade -%PIP_INSTALL% pytest-repeat pytest-timeout pytest-faulthandler sortedcollections +%PIP_INSTALL% pytest-repeat pytest-timeout pytest-faulthandler sortedcollections pytest-asyncio @rem Display final environment (for reproducing) %CONDA% list diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index bba69dd3ac8..f1ff25a9bfa 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -59,7 +59,7 @@ conda install -q \ conda install -c defaults -c conda-forge libunwind conda install --no-deps -c defaults -c numba -c conda-forge stacktrace -pip install -q pytest-repeat pytest-faulthandler +pip install -q pytest-repeat pytest-faulthandler pytest-asyncio pip install -q git+https://github.com/dask/dask.git --upgrade --no-deps pip install -q git+https://github.com/joblib/joblib.git --upgrade --no-deps diff --git a/distributed/__init__.py b/distributed/__init__.py index 7b2bc4ab082..2a632607cf9 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -4,7 +4,7 @@ from dask.config import config from .actor import Actor, ActorFuture from .core import connect, rpc -from .deploy import LocalCluster, Adaptive +from .deploy import LocalCluster, Adaptive, SpecCluster from .diagnostics import progress from .client import ( Client, diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index e383095b382..439bdaf4a62 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -114,7 +114,7 @@ @click.option( "--name", type=str, - default="", + default=None, help="A unique name for this worker like 'worker-1'. 
" "If used with --nprocs then the process number " "will be appended like name-0, name-1, name-2, ...", diff --git a/distributed/client.py b/distributed/client.py index 6b22f5b8ee5..afe6f6ef39f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2,8 +2,8 @@ import atexit from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor -from concurrent.futures._base import DoneAndNotDoneFutures, CancelledError +from concurrent.futures import ThreadPoolExecutor, CancelledError +from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import contextmanager import copy from datetime import timedelta @@ -44,6 +44,8 @@ from tornado.ioloop import IOLoop from tornado.queues import Queue +from asyncio import iscoroutine + from .batched import BatchedSend from .utils_comm import ( WrappedKey, @@ -1309,7 +1311,13 @@ def close(self, timeout=no_default): if self._start_arg is None: with ignoring(AttributeError): - self.cluster.close() + f = self.cluster.close() + if iscoroutine(f): + + async def _(): + await f + + self.sync(_) sync(self.loop, self._close, fast=True) @@ -1644,10 +1652,11 @@ def wait(k): st = self.futures[key] exception = st.exception traceback = st.traceback - except (AttributeError, KeyError): - six.reraise(CancelledError, CancelledError(key), None) + except (KeyError, AttributeError): + exc = CancelledError(key) else: six.reraise(type(exception), exception, traceback) + raise exc if errors == "skip": bad_keys.add(key) bad_data[key] = None @@ -4134,7 +4143,10 @@ def _track_future(self, future): except CancelledError: pass if self.with_results: - result = yield future._result(raiseit=False) + try: + result = yield future._result(raiseit=False) + except CancelledError as exc: + result = exc with self.lock: self.futures[future] -= 1 if not self.futures[future]: diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 85dbe2ce278..d5351c7d565 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -431,6 +431,7 @@ def start(self): break else: raise exc + self.get_host_port() # trigger assignment to self.bound_address def stop(self): tcp_server, self.tcp_server = self.tcp_server, None diff --git a/distributed/core.py b/distributed/core.py index 9b1d408a038..17685c9d2d5 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -489,14 +489,16 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): @gen.coroutine def close(self): - self.listener.stop() + for pc in self.periodic_callbacks.values(): + pc.stop() + if self.listener: + self.listener.stop() for i in range(20): # let comms close naturally for a second if not self._comms: break else: yield gen.sleep(0.05) - for comm in self._comms: - comm.close() + yield [comm.close() for comm in self._comms] for cb in self._ongoing_coroutines: cb.cancel() for i in range(10): diff --git a/distributed/deploy/__init__.py b/distributed/deploy/__init__.py index 35abf0a6439..9b5e478c303 100644 --- a/distributed/deploy/__init__.py +++ b/distributed/deploy/__init__.py @@ -4,6 +4,7 @@ from .cluster import Cluster from .local import LocalCluster +from .spec import SpecCluster from .adaptive import Adaptive with ignoring(ImportError): diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 8c260609638..793e80d984c 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -272,7 +272,7 @@ def _retire_workers(self, workers=None): logger.info("Retiring workers %s", workers) f = self.cluster.scale_down(workers) - if 
gen.is_future(f): + if hasattr(f, "__await__"): yield f raise gen.Return(workers) @@ -354,7 +354,7 @@ def _adapt(self): if status == "up": f = self.cluster.scale_up(**recommendations) self.log.append((time(), "up", recommendations)) - if gen.is_future(f): + if hasattr(f, "__await__"): yield f elif status == "down": diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index f170d4ea5ad..8425b836a4d 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -1,12 +1,23 @@ +from datetime import timedelta import logging import os from weakref import ref import dask +from tornado import gen from .adaptive import Adaptive -from ..utils import format_bytes, PeriodicCallback, log_errors, ignoring +from ..compatibility import get_thread_identity +from ..utils import ( + format_bytes, + PeriodicCallback, + log_errors, + ignoring, + sync, + thread_state, +) + logger = logging.getLogger(__name__) @@ -215,3 +226,22 @@ def update(): def _ipython_display_(self, **kwargs): return self._widget()._ipython_display_(**kwargs) + + @property + def asynchronous(self): + return ( + self._asynchronous + or getattr(thread_state, "asynchronous", False) + or hasattr(self.loop, "_thread_identity") + and self.loop._thread_identity == get_thread_identity() + ) + + def sync(self, func, *args, **kwargs): + if kwargs.pop("asynchronous", None) or self.asynchronous: + callback_timeout = kwargs.pop("callback_timeout", None) + future = func(*args, **kwargs) + if callback_timeout is not None: + future = gen.with_timeout(timedelta(seconds=callback_timeout), future) + return future + else: + return sync(self.loop, func, *args, **kwargs) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 832e8f3e051..17150fdf70f 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -1,29 +1,14 @@ from __future__ import print_function, division, absolute_import import atexit -from datetime import timedelta import logging import math import warnings import weakref -import toolz from dask.utils import factors -from tornado import gen - -from .cluster import Cluster -from ..compatibility import get_thread_identity -from ..core import CommClosedError -from ..utils import ( - sync, - ignoring, - All, - silence_logging, - LoopRunner, - log_errors, - thread_state, - parse_timedelta, -) + +from .spec import SpecCluster from ..nanny import Nanny from ..scheduler import Scheduler from ..worker import Worker, parse_memory_limit, _ncores @@ -31,7 +16,7 @@ logger = logging.getLogger(__name__) -class LocalCluster(Cluster): +class LocalCluster(SpecCluster): """ Create local Scheduler and Workers This creates a "cluster" of a scheduler and workers running on the local @@ -105,8 +90,8 @@ def __init__( processes=True, loop=None, start=None, - ip=None, host=None, + ip=None, scheduler_port=0, silence_logs=logging.WARN, dashboard_address=":8787", @@ -127,15 +112,6 @@ def __init__( warnings.warn("The ip keyword has been moved to host") host = ip - if start is not None: - msg = ( - "The start= parameter is deprecated. " - "LocalCluster always starts. " - "For asynchronous operation use the following: \n\n" - " cluster = yield LocalCluster(asynchronous=True)" - ) - raise ValueError(msg) - if diagnostics_port is not None: warnings.warn( "diagnostics_port has been deprecated. 
" @@ -161,12 +137,8 @@ def __init__( if host is None and not protocol.startswith("inproc") and not interface: host = "127.0.0.1" - self.silence_logs = silence_logs - self._asynchronous = asynchronous services = services or {} worker_services = worker_services or {} - if silence_logs: - self._old_logging_level = silence_logging(level=silence_logs) if n_workers is None and threads_per_worker is None: if processes: n_workers, threads_per_worker = nprocesses_nthreads(_ncores) @@ -188,268 +160,42 @@ def __init__( "dashboard_address": worker_dashboard_address, "interface": interface, "protocol": protocol, + "security": security, + "silence_logs": silence_logs, } ) - self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) - self.loop = self._loop_runner.loop - - self.scheduler = Scheduler( - loop=self.loop, - host=host, - services=services, - service_kwargs=service_kwargs, - security=security, - port=scheduler_port, - interface=interface, - protocol=protocol, - dashboard_address=dashboard_address, - blocked_handlers=blocked_handlers, + scheduler = { + "cls": Scheduler, + "options": dict( + host=host, + services=services, + service_kwargs=service_kwargs, + security=security, + port=scheduler_port, + interface=interface, + protocol=protocol, + dashboard_address=dashboard_address, + blocked_handlers=blocked_handlers, + ), + } + + worker = { + "cls": worker_class or (Worker if not processes else Nanny), + "options": worker_kwargs, + } + + workers = {i: worker for i in range(n_workers)} + + super(LocalCluster, self).__init__( + scheduler=scheduler, + workers=workers, + worker=worker, + loop=loop, + asynchronous=asynchronous, + silence_logs=silence_logs, ) - - self.workers = [] - self.worker_kwargs = worker_kwargs - if security: - self.worker_kwargs["security"] = security - - if not worker_class: - worker_class = Worker if not processes else Nanny - self.worker_class = worker_class - - self.start(n_workers=n_workers) - - clusters_to_close.add(self) - - def __repr__(self): - return "LocalCluster(%r, workers=%d, ncores=%d)" % ( - self.scheduler_address, - len(self.workers), - sum(w.ncores for w in self.workers), - ) - - def __await__(self): - return self._started.__await__() - - @property - def asynchronous(self): - return ( - self._asynchronous - or getattr(thread_state, "asynchronous", False) - or hasattr(self.loop, "_thread_identity") - and self.loop._thread_identity == get_thread_identity() - ) - - def sync(self, func, *args, **kwargs): - if kwargs.pop("asynchronous", None) or self.asynchronous: - callback_timeout = kwargs.pop("callback_timeout", None) - future = func(*args, **kwargs) - if callback_timeout is not None: - future = gen.with_timeout(timedelta(seconds=callback_timeout), future) - return future - else: - return sync(self.loop, func, *args, **kwargs) - - def start(self, **kwargs): - self._loop_runner.start() - if self._asynchronous: - self._started = self._start(**kwargs) - else: - self.sync(self._start, **kwargs) - - @gen.coroutine - def _start(self, n_workers=0): - """ - Start all cluster services. 
- """ - if self.status == "running": - return - - self.scheduler.start() - - yield [self._start_worker(**self.worker_kwargs) for i in range(n_workers)] - yield self.scheduler - - self.status = "running" - - raise gen.Return(self) - - @gen.coroutine - def _start_worker(self, death_timeout=60, **kwargs): - if self.status and self.status.startswith("clos"): - warnings.warn( - "Tried to start a worker while status=='%s'" % self.status, stacklevel=2 - ) - return - - if self.processes: - kwargs["quiet"] = True - - w = yield self.worker_class( - self.scheduler.address, - loop=self.loop, - death_timeout=death_timeout, - silence_logs=self.silence_logs, - **kwargs - ) - - self.workers.append(w) - - while w.status != "closed" and w.worker_address not in self.scheduler.workers: - yield gen.sleep(0.01) - - if w.status == "closed" and self.scheduler.status == "running": - self.workers.remove(w) - raise gen.TimeoutError("Worker failed to start") - - raise gen.Return(w) - - def start_worker(self, **kwargs): - """ Add a new worker to the running cluster - - Parameters - ---------- - port: int (optional) - Port on which to serve the worker, defaults to 0 or random - ncores: int (optional) - Number of threads to use. Defaults to number of logical cores - - Examples - -------- - >>> c = LocalCluster() # doctest: +SKIP - >>> c.start_worker(ncores=2) # doctest: +SKIP - - Returns - ------- - The created Worker or Nanny object. Can be discarded. - """ - return self.sync(self._start_worker, **kwargs) - - @gen.coroutine - def _stop_worker(self, w): - yield w.close() - if w in self.workers: - self.workers.remove(w) - - def stop_worker(self, w): - """ Stop a running worker - - Examples - -------- - >>> c = LocalCluster() # doctest: +SKIP - >>> w = c.start_worker(ncores=2) # doctest: +SKIP - >>> c.stop_worker(w) # doctest: +SKIP - """ - self.sync(self._stop_worker, w) - - @gen.coroutine - def _close(self, timeout="2s"): - # Can be 'closing' as we're called by close() below - if self.status == "closed": - return - self.status = "closing" - - with ignoring(gen.TimeoutError, CommClosedError, OSError): - yield gen.with_timeout( - timedelta(seconds=parse_timedelta(timeout)), - self.scheduler.close(close_workers=True), - ) - - with ignoring(gen.TimeoutError): - yield gen.with_timeout( - timedelta(seconds=parse_timedelta(timeout)), - All([self._stop_worker(w) for w in self.workers]), - ) - del self.workers[:] - self.status = "closed" - - def close(self, timeout=20): - """ Close the cluster """ - if self.status == "closed": - return - - try: - result = self.sync(self._close, callback_timeout=timeout) - except RuntimeError: # IOLoop is closed - result = None - - if hasattr(self, "_old_logging_level"): - if self.asynchronous: - result.add_done_callback( - lambda _: silence_logging(self._old_logging_level) - ) - else: - silence_logging(self._old_logging_level) - - if not self.asynchronous: - self._loop_runner.stop() - - return result - - @gen.coroutine - def scale_up(self, n, **kwargs): - """ Bring the total count of workers up to ``n`` - - This function/coroutine should bring the total number of workers up to - the number ``n``. - - This can be implemented either as a function or as a Tornado coroutine. 
- """ - with log_errors(): - kwargs2 = toolz.merge(self.worker_kwargs, kwargs) - yield [ - self._start_worker(**kwargs2) - for i in range(n - len(self.scheduler.workers)) - ] - - # clean up any closed worker - self.workers = [w for w in self.workers if w.status != "closed"] - - @gen.coroutine - def scale_down(self, workers): - """ Remove ``workers`` from the cluster - - Given a list of worker addresses this function should remove those - workers from the cluster. This may require tracking which jobs are - associated to which worker address. - - This can be implemented either as a function or as a Tornado coroutine. - """ - with log_errors(): - # clean up any closed worker - self.workers = [w for w in self.workers if w.status != "closed"] - workers = set(workers) - - # we might be given addresses - if all(isinstance(w, str) for w in workers): - workers = {w for w in self.workers if w.worker_address in workers} - - # stop the provided workers - yield [self._stop_worker(w) for w in workers] - - def __del__(self): - self.close() - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - @gen.coroutine - def __aenter__(self): - yield self._started - raise gen.Return(self) - - @gen.coroutine - def __aexit__(self, typ, value, traceback): - yield self._close() - - @property - def scheduler_address(self): - try: - return self.scheduler.address - except ValueError: - return "" + self.scale(n_workers) def nprocesses_nthreads(n): diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py new file mode 100644 index 00000000000..9a4385e5054 --- /dev/null +++ b/distributed/deploy/spec.py @@ -0,0 +1,297 @@ +import asyncio +import weakref + +from tornado import gen + +from .cluster import Cluster +from ..utils import LoopRunner, silence_logging, ignoring +from ..scheduler import Scheduler + + +class SpecCluster(Cluster): + """ Cluster that requires a full specification of workers + + The SpecCluster class expects a full specification of the Scheduler and + Workers to use. It removes any handling of user inputs (like threads vs + processes, number of cores, and so on) and any handling of cluster resource + managers (like pods, jobs, and so on). Instead, it expects this + information to be passed in scheduler and worker specifications. This + class does handle all of the logic around asynchronously cleanly setting up + and tearing things down at the right times. Hopefully it can form a base + for other more user-centric classes. + + Parameters + ---------- + workers: dict + A dictionary mapping names to worker classes and their specifications + See example below + scheduler: dict, optional + A similar mapping for a scheduler + worker: dict + A specification of a single worker. + This is used for any new workers that are created. + asynchronous: bool + If this is intended to be used directly within an event loop with + async/await + silence_logs: bool + Whether or not we should silence logging when setting up the cluster. + + Examples + -------- + To create a SpecCluster you specify how to set up a Scheduler and Workers + + >>> from dask.distributed import Scheduler, Worker, Nanny + >>> scheduler = {'cls': Scheduler, 'options': {"dashboard_address": ':8787'}} + >>> workers = { + ... 'my-worker': {"cls": Worker, "options": {"ncores": 1}}, + ... 'my-nanny': {"cls": Nanny, "options": {"ncores": 2}}, + ... 
} + >>> cluster = SpecCluster(scheduler=scheduler, workers=workers) + + The worker spec is stored as the ``.worker_spec`` attribute + + >>> cluster.worker_spec + { + 'my-worker': {"cls": Worker, "options": {"ncores": 1}}, + 'my-nanny': {"cls": Nanny, "options": {"ncores": 2}}, + } + + While the instantiation of this spec is stored in the ``.workers`` + attribute + + >>> cluster.workers + { + 'my-worker': + 'my-nanny': + } + + Should the spec change, we can await the cluster or call the + ``._correct_state`` method to align the actual state to the specified + state. + + We can also ``.scale(...)`` the cluster, which adds new workers of a given + form. + + >>> worker = {'cls': Worker, 'options': {}} + >>> cluster = SpecCluster(scheduler=scheduler, worker=worker) + >>> cluster.worker_spec + {} + + >>> cluster.scale(3) + >>> cluster.worker_spec + { + 0: {'cls': Worker, 'options': {}}, + 1: {'cls': Worker, 'options': {}}, + 2: {'cls': Worker, 'options': {}}, + } + + Note that above we are using the standard ``Worker`` and ``Nanny`` classes, + however in practice other classes could be used that handle resource + management like ``KubernetesPod`` or ``SLURMJob``. The spec does not need + to conform to the expectations of the standard Dask Worker class. It just + needs to be called with the provided options, support ``__await__`` and + ``close`` methods and the ``worker_address`` property.. + + Also note that uniformity of the specification is not required. Other API + could be added externally (in subclasses) that adds workers of different + specifications into the same dictionary. + """ + + def __init__( + self, + workers=None, + scheduler=None, + worker=None, + asynchronous=False, + loop=None, + silence_logs=False, + ): + self._created = weakref.WeakSet() + if scheduler is None: + try: + from distributed.bokeh.scheduler import BokehScheduler + except ImportError: + services = {} + else: + services = {("bokeh", 8787): BokehScheduler} + scheduler = {"cls": Scheduler, "options": {"services": services}} + + self.scheduler_spec = scheduler + self.worker_spec = workers or {} + self.new_spec = worker + self.workers = {} + self._i = 0 + self._asynchronous = asynchronous + + if silence_logs: + self._old_logging_level = silence_logging(level=silence_logs) + + self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) + self.loop = self._loop_runner.loop + + self.scheduler = self.scheduler_spec["cls"]( + loop=self.loop, **self.scheduler_spec["options"] + ) + self.status = "created" + self._correct_state_waiting = None + + if not self.asynchronous: + self._loop_runner.start() + self.sync(self._start) + self.sync(self._correct_state) + self.sync(self._wait_for_workers) + + async def _start(self): + while self.status == "starting": + await asyncio.sleep(0.01) + if self.status == "running": + return + if self.status == "closed": + raise ValueError("Cluster is closed") + + self._lock = asyncio.Lock() + self.status = "starting" + self.scheduler = await self.scheduler + self.status = "running" + + def _correct_state(self): + if self._correct_state_waiting: + # If people call this frequently, we only want to run it once + return self._correct_state_waiting + else: + task = asyncio.ensure_future(self._correct_state_internal()) + self._correct_state_waiting = task + return task + + async def _correct_state_internal(self): + async with self._lock: + self._correct_state_waiting = None + + pre = list(set(self.workers)) + to_close = set(self.workers) - set(self.worker_spec) + if to_close: + await 
self.scheduler.retire_workers(workers=list(to_close)) + tasks = [self.workers[w].close() for w in to_close] + await asyncio.wait(tasks) + for task in tasks: # for tornado gen.coroutine support + await task + for name in to_close: + del self.workers[name] + + to_open = set(self.worker_spec) - set(self.workers) + workers = [] + for name in to_open: + d = self.worker_spec[name] + cls, opts = d["cls"], d.get("options", {}) + if "name" not in opts: + opts = opts.copy() + opts["name"] = name + worker = cls(self.scheduler.address, **opts) + self._created.add(worker) + workers.append(worker) + if workers: + await asyncio.wait(workers) + for w in workers: + w._cluster = weakref.ref(self) + await w # for tornado gen.coroutine support + self.workers.update(dict(zip(to_open, workers))) + + def __await__(self): + async def _(): + if self.status == "created": + await self._start() + await self.scheduler + await self._correct_state() + if self.workers: + await asyncio.wait(list(self.workers.values())) # maybe there are more + await self._wait_for_workers() + return self + + return _().__await__() + + async def _wait_for_workers(self): + # TODO: this function needs to query scheduler and worker state + # remotely without assuming that they are local + while {d["name"] for d in self.scheduler.identity()["workers"].values()} != set( + self.workers + ): + if ( + any(w.status == "closed" for w in self.workers.values()) + and self.scheduler.status == "running" + ): + raise gen.TimeoutError("Worker unexpectedly closed") + await asyncio.sleep(0.1) + + async def __aenter__(self): + await self + return self + + async def __aexit__(self, typ, value, traceback): + await self.close() + + async def _close(self): + while self.status == "closing": + await asyncio.sleep(0.1) + if self.status == "closed": + return + self.status = "closing" + + async with self._lock: + await self.scheduler.close(close_workers=True) + self.scale(0) + await self._correct_state() + for w in self._created: + assert w.status == "closed" + + if hasattr(self, "_old_logging_level"): + silence_logging(self._old_logging_level) + + self.status = "closed" + + def close(self): + with ignoring(RuntimeError): # loop closed during process shutdown + return self.sync(self._close) + + def __del__(self): + if self.status != "closed": + self.close() + + def __enter__(self): + self.sync(self._correct_state) + self.sync(self._wait_for_workers) + assert self.status == "running" + return self + + def __exit__(self, typ, value, traceback): + self.close() + self._loop_runner.stop() + + def scale(self, n): + while len(self.worker_spec) > n: + self.worker_spec.popitem() + + while len(self.worker_spec) < n: + while self._i in self.worker_spec: + self._i += 1 + self.worker_spec[self._i] = self.new_spec + + self.loop.add_callback(self._correct_state) + + async def scale_down(self, workers): + workers = set(workers) + + # TODO: this is linear cost. 
We should be indexing by name or something + to_close = [w for w in self.workers.values() if w.address in workers] + for k, v in self.workers.items(): + if v.worker_address in workers: + del self.worker_spec[k] + + await self + + scale_up = scale # backwards compatibility + + def __repr__(self): + return "SpecCluster(%r, workers=%d)" % ( + self.scheduler_address, + len(self.workers), + ) diff --git a/distributed/deploy/tests/py3_test_deploy.py b/distributed/deploy/tests/py3_test_deploy.py index 4c8fb2f86de..7a66ecf942c 100644 --- a/distributed/deploy/tests/py3_test_deploy.py +++ b/distributed/deploy/tests/py3_test_deploy.py @@ -1,14 +1,13 @@ from distributed import LocalCluster from distributed.utils_test import loop # noqa: F401 +import pytest -def test_async_with(loop): - async def f(): - async with LocalCluster(processes=False, asynchronous=True) as cluster: - w = cluster.workers - assert w +@pytest.mark.asyncio +async def test_async_with(): + async with LocalCluster(processes=False, asynchronous=True) as cluster: + w = cluster.workers + assert w - assert not w - - loop.run_sync(f) + assert not w diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 50c4f0a45a3..8915c721353 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -2,12 +2,13 @@ from time import sleep +import pytest from toolz import frequencies, pluck from tornado import gen from tornado.ioloop import IOLoop -from distributed import Client, wait, Adaptive, LocalCluster -from distributed.utils_test import gen_cluster, gen_test, slowinc, inc +from distributed import Client, wait, Adaptive, LocalCluster, SpecCluster, Worker +from distributed.utils_test import gen_cluster, gen_test, slowinc, inc, clean from distributed.utils_test import loop, nodebug # noqa: F401 from distributed.metrics import time @@ -162,19 +163,17 @@ def scale_down(self, workers): assert len(s.workers) == 2 +@pytest.mark.xfail(reason="need to rework adaptive") @gen_test(timeout=30) def test_min_max(): - loop = IOLoop.current() cluster = yield LocalCluster( 0, scheduler_port=0, silence_logs=False, processes=False, dashboard_address=None, - loop=loop, asynchronous=True, ) - yield cluster._start() try: adapt = Adaptive( cluster.scheduler, @@ -184,7 +183,7 @@ def test_min_max(): interval="20 ms", wait_count=10, ) - c = yield Client(cluster, asynchronous=True, loop=loop) + c = yield Client(cluster, asynchronous=True) start = time() while not cluster.scheduler.workers: @@ -359,17 +358,18 @@ def test_no_more_workers_than_tasks(): def test_basic_no_loop(): - try: - with LocalCluster( - 0, scheduler_port=0, silence_logs=False, dashboard_address=None - ) as cluster: - with Client(cluster) as client: - cluster.adapt() - future = client.submit(lambda x: x + 1, 1) - assert future.result() == 2 - loop = cluster.loop - finally: - loop.add_callback(loop.stop) + with clean(threads=False): + try: + with LocalCluster( + 0, scheduler_port=0, silence_logs=False, dashboard_address=None + ) as cluster: + with Client(cluster) as client: + cluster.adapt() + future = client.submit(lambda x: x + 1, 1) + assert future.result() == 2 + loop = cluster.loop + finally: + loop.add_callback(loop.stop) @gen_test(timeout=None) @@ -408,25 +408,17 @@ def test_target_duration(): @gen_test(timeout=None) def test_worker_keys(): """ Ensure that redefining adapt with a lower maximum removes workers """ - cluster = yield LocalCluster( - 0, + cluster = yield SpecCluster( + workers={ + "a-1": {"cls": 
Worker}, + "a-2": {"cls": Worker}, + "b-1": {"cls": Worker}, + "b-2": {"cls": Worker}, + }, asynchronous=True, - processes=False, - scheduler_port=0, - silence_logs=False, - dashboard_address=None, ) try: - yield [ - cluster.start_worker(name="a-1"), - cluster.start_worker(name="a-2"), - cluster.start_worker(name="b-1"), - cluster.start_worker(name="b-2"), - ] - - while len(cluster.scheduler.workers) != 4: - yield gen.sleep(0.01) def key(ws): return ws.name.split("-")[0] diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index ed9e3bb2dbe..4498611d7e8 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -17,6 +17,7 @@ from distributed.deploy.local import LocalCluster, nprocesses_nthreads from distributed.metrics import time from distributed.utils_test import ( + clean, inc, gen_test, slowinc, @@ -46,7 +47,7 @@ def test_simple(loop): x = e.submit(inc, 1) x.result() assert x.key in c.scheduler.tasks - assert any(w.data == {x.key: 2} for w in c.workers) + assert any(w.data == {x.key: 2} for w in c.workers.values()) assert e.loop is c.loop @@ -87,10 +88,10 @@ def test_procs(): silence_logs=False, ) as c: assert len(c.workers) == 2 - assert all(isinstance(w, Worker) for w in c.workers) + assert all(isinstance(w, Worker) for w in c.workers.values()) with Client(c.scheduler.address) as e: - assert all(w.ncores == 3 for w in c.workers) - assert all(isinstance(w, Worker) for w in c.workers) + assert all(w.ncores == 3 for w in c.workers.values()) + assert all(isinstance(w, Worker) for w in c.workers.values()) repr(c) with LocalCluster( @@ -102,12 +103,12 @@ def test_procs(): silence_logs=False, ) as c: assert len(c.workers) == 2 - assert all(isinstance(w, Nanny) for w in c.workers) + assert all(isinstance(w, Nanny) for w in c.workers.values()) with Client(c.scheduler.address) as e: assert all(v == 3 for v in e.ncores().values()) - c.start_worker() - assert all(isinstance(w, Nanny) for w in c.workers) + c.scale(3) + assert all(isinstance(w, Nanny) for w in c.workers.values()) repr(c) @@ -171,7 +172,7 @@ def test_transports_tcp_port(): @pytest.mark.skipif("sys.version_info[0] == 2", reason="") class LocalTest(ClusterTest, unittest.TestCase): Cluster = partial(LocalCluster, silence_logs=False, dashboard_address=None) - kwargs = {"dashboard_address": None} + kwargs = {"dashboard_address": None, "processes": False} @pytest.mark.skipif("sys.version_info[0] == 2", reason="") @@ -208,12 +209,13 @@ def test_duplicate_clients(): for msg in info.list ) yield c1.close() + yield c2.close() def test_Client_kwargs(loop): with Client(loop=loop, processes=False, n_workers=2, silence_logs=False) as c: assert len(c.cluster.workers) == 2 - assert all(isinstance(w, Worker) for w in c.cluster.workers) + assert all(isinstance(w, Worker) for w in c.cluster.workers.values()) assert c.cluster.status == "closed" @@ -230,14 +232,14 @@ def test_defaults(): with LocalCluster( scheduler_port=0, silence_logs=False, dashboard_address=None ) as c: - assert sum(w.ncores for w in c.workers) == _ncores - assert all(isinstance(w, Nanny) for w in c.workers) + assert sum(w.ncores for w in c.workers.values()) == _ncores + assert all(isinstance(w, Nanny) for w in c.workers.values()) with LocalCluster( processes=False, scheduler_port=0, silence_logs=False, dashboard_address=None ) as c: - assert sum(w.ncores for w in c.workers) == _ncores - assert all(isinstance(w, Worker) for w in c.workers) + assert sum(w.ncores for w in c.workers.values()) == _ncores 
+ assert all(isinstance(w, Worker) for w in c.workers.values()) assert len(c.workers) == 1 with LocalCluster( @@ -248,7 +250,7 @@ def test_defaults(): else: # n_workers not a divisor of _ncores => threads are overcommitted expected_total_threads = max(2, _ncores + 1) - assert sum(w.ncores for w in c.workers) == expected_total_threads + assert sum(w.ncores for w in c.workers.values()) == expected_total_threads with LocalCluster( threads_per_worker=_ncores * 2, @@ -264,7 +266,7 @@ def test_defaults(): silence_logs=False, dashboard_address=None, ) as c: - assert all(w.ncores == 1 for w in c.workers) + assert all(w.ncores == 1 for w in c.workers.values()) with LocalCluster( threads_per_worker=2, n_workers=3, @@ -273,18 +275,19 @@ def test_defaults(): dashboard_address=None, ) as c: assert len(c.workers) == 3 - assert all(w.ncores == 2 for w in c.workers) + assert all(w.ncores == 2 for w in c.workers.values()) def test_worker_params(): with LocalCluster( + processes=False, n_workers=2, scheduler_port=0, silence_logs=False, dashboard_address=None, memory_limit=500, ) as c: - assert [w.memory_limit for w in c.workers] == [500] * 2 + assert [w.memory_limit for w in c.workers.values()] == [500] * 2 def test_memory_limit_none(): @@ -302,24 +305,28 @@ def test_memory_limit_none(): def test_cleanup(): - c = LocalCluster(2, scheduler_port=0, silence_logs=False, dashboard_address=None) - port = c.scheduler.port - c.close() - c2 = LocalCluster( - 2, scheduler_port=port, silence_logs=False, dashboard_address=None - ) - c.close() + with clean(threads=False): + c = LocalCluster( + 2, scheduler_port=0, silence_logs=False, dashboard_address=None + ) + port = c.scheduler.port + c.close() + c2 = LocalCluster( + 2, scheduler_port=port, silence_logs=False, dashboard_address=None + ) + c2.close() def test_repeated(): - with LocalCluster( - 0, scheduler_port=8448, silence_logs=False, dashboard_address=None - ) as c: - pass - with LocalCluster( - 0, scheduler_port=8448, silence_logs=False, dashboard_address=None - ) as c: - pass + with clean(threads=False): + with LocalCluster( + 0, scheduler_port=8448, silence_logs=False, dashboard_address=None + ) as c: + pass + with LocalCluster( + 0, scheduler_port=8448, silence_logs=False, dashboard_address=None + ) as c: + pass @pytest.mark.parametrize("processes", [True, False]) @@ -373,15 +380,15 @@ def test_scale_up_and_down(): assert not cluster.workers - yield cluster.scale_up(2) + cluster.scale(2) + yield cluster assert len(cluster.workers) == 2 assert len(cluster.scheduler.ncores) == 2 - addr = cluster.workers[0].address - yield cluster.scale_down([addr]) + cluster.scale(1) + yield cluster assert len(cluster.workers) == 1 - assert addr not in cluster.scheduler.ncores yield c.close() yield cluster.close() @@ -437,7 +444,7 @@ def test_memory(loop, n_workers): dashboard_address=None, loop=loop, ) as cluster: - assert sum(w.memory_limit for w in cluster.workers) <= TOTAL_MEMORY + assert sum(w.memory_limit for w in cluster.workers.values()) <= TOTAL_MEMORY @pytest.mark.parametrize("n_workers", [None, 3]) @@ -486,11 +493,13 @@ def test_bokeh_kwargs(loop): def test_io_loop_periodic_callbacks(loop): - with LocalCluster(loop=loop, silence_logs=False) as cluster: + with LocalCluster( + loop=loop, port=0, dashboard_address=None, silence_logs=False + ) as cluster: assert cluster.scheduler.loop is loop for pc in cluster.scheduler.periodic_callbacks.values(): assert pc.io_loop is loop - for worker in cluster.workers: + for worker in cluster.workers.values(): for pc in 
worker.periodic_callbacks.values(): assert pc.io_loop is loop @@ -772,7 +781,7 @@ def test_worker_class_worker(loop): scheduler_port=0, dashboard_address=None, ) as cluster: - assert all(isinstance(w, MyWorker) for w in cluster.workers) + assert all(isinstance(w, MyWorker) for w in cluster.workers.values()) def test_worker_class_nanny(loop): @@ -786,8 +795,37 @@ class MyNanny(Nanny): scheduler_port=0, dashboard_address=None, ) as cluster: - assert all(isinstance(w, MyNanny) for w in cluster.workers) + assert all(isinstance(w, MyNanny) for w in cluster.workers.values()) + + +@pytest.mark.asyncio +async def test_worker_class_nanny_async(): + class MyNanny(Nanny): + pass + + async with LocalCluster( + n_workers=2, + worker_class=MyNanny, + scheduler_port=0, + dashboard_address=None, + asynchronous=True, + ) as cluster: + assert all(isinstance(w, MyNanny) for w in cluster.workers.values()) if sys.version_info >= (3, 5): from distributed.deploy.tests.py3_test_deploy import * # noqa F401 + + +def test_starts_up_sync(loop): + cluster = LocalCluster( + n_workers=2, + loop=loop, + processes=False, + scheduler_port=0, + dashboard_address=None, + ) + try: + assert len(cluster.scheduler.workers) == 2 + finally: + cluster.close() diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py new file mode 100644 index 00000000000..cfc12427274 --- /dev/null +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -0,0 +1,115 @@ +from dask.distributed import SpecCluster, Worker, Client, Scheduler +from distributed.utils_test import loop # noqa: F401 +import pytest + + +class MyWorker(Worker): + pass + + +class BrokenWorker(Worker): + def __await__(self): + async def _(): + raise Exception("Worker Broken") + + return _().__await__() + + +worker_spec = { + 0: {"cls": Worker, "options": {"ncores": 1}}, + 1: {"cls": Worker, "options": {"ncores": 2}}, + "my-worker": {"cls": MyWorker, "options": {"ncores": 3}}, +} +scheduler = {"cls": Scheduler, "options": {"port": 0}} + + +@pytest.mark.asyncio +async def test_specification(): + async with SpecCluster( + workers=worker_spec, scheduler=scheduler, asynchronous=True + ) as cluster: + assert cluster.worker_spec is worker_spec + + assert len(cluster.workers) == 3 + assert set(cluster.workers) == set(worker_spec) + assert isinstance(cluster.workers[0], Worker) + assert isinstance(cluster.workers[1], Worker) + assert isinstance(cluster.workers["my-worker"], MyWorker) + + assert cluster.workers[0].ncores == 1 + assert cluster.workers[1].ncores == 2 + assert cluster.workers["my-worker"].ncores == 3 + + async with Client(cluster, asynchronous=True) as client: + result = await client.submit(lambda x: x + 1, 10) + assert result == 11 + + for name in cluster.workers: + assert cluster.workers[name].name == name + + +def test_spec_sync(loop): + worker_spec = { + 0: {"cls": Worker, "options": {"ncores": 1}}, + 1: {"cls": Worker, "options": {"ncores": 2}}, + "my-worker": {"cls": MyWorker, "options": {"ncores": 3}}, + } + with SpecCluster(workers=worker_spec, scheduler=scheduler, loop=loop) as cluster: + assert cluster.worker_spec is worker_spec + + assert len(cluster.workers) == 3 + assert set(cluster.workers) == set(worker_spec) + assert isinstance(cluster.workers[0], Worker) + assert isinstance(cluster.workers[1], Worker) + assert isinstance(cluster.workers["my-worker"], MyWorker) + + assert cluster.workers[0].ncores == 1 + assert cluster.workers[1].ncores == 2 + assert cluster.workers["my-worker"].ncores == 3 + + with 
Client(cluster, loop=loop) as client: + assert cluster.loop is cluster.scheduler.loop + assert cluster.loop is client.loop + result = client.submit(lambda x: x + 1, 10).result() + assert result == 11 + + +def test_loop_started(): + cluster = SpecCluster(worker_spec) + + +@pytest.mark.asyncio +async def test_scale(): + worker = {"cls": Worker, "options": {"ncores": 1}} + async with SpecCluster( + asynchronous=True, scheduler=scheduler, worker=worker + ) as cluster: + assert not cluster.workers + assert not cluster.worker_spec + + # Scale up + cluster.scale(2) + assert not cluster.workers + assert cluster.worker_spec + + await cluster + assert len(cluster.workers) == 2 + + # Scale down + cluster.scale(1) + assert len(cluster.workers) == 2 + + await cluster + assert len(cluster.workers) == 1 + + +@pytest.mark.asyncio +async def test_broken_worker(): + with pytest.raises(Exception) as info: + async with SpecCluster( + asynchronous=True, + workers={"good": {"cls": Worker}, "bad": {"cls": BrokenWorker}}, + ) as cluster: + pass + + assert "Broken" in str(info.value) diff --git a/distributed/deploy/utils_test.py b/distributed/deploy/utils_test.py index 9bc8cacccad..9da8d64cd50 100644 --- a/distributed/deploy/utils_test.py +++ b/distributed/deploy/utils_test.py @@ -1,5 +1,7 @@ from ..client import Client +import pytest + class ClusterTest(object): Cluster = None @@ -13,26 +15,15 @@ def tearDown(self): self.client.close() self.cluster.close() + @pytest.mark.xfail() def test_cores(self): + info = self.client.scheduler_info() assert len(self.client.ncores()) == 2 def test_submit(self): future = self.client.submit(lambda x: x + 1, 1) assert future.result() == 2 - def test_start_worker(self): - a = self.client.ncores() - w = self.cluster.start_worker(ncores=3) - b = self.client.ncores() - - assert len(b) == 1 + len(a) - assert any(v == 3 for v in b.values()) - - self.cluster.stop_worker(w) - - c = self.client.ncores() - assert c == a - def test_context_manager(self): with self.Cluster(**self.kwargs) as c: with Client(c) as e: diff --git a/distributed/nanny.py b/distributed/nanny.py index 842ec765d7f..a27f713ea6b 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -18,7 +18,7 @@ from .comm import get_address_host, get_local_address_for, unparse_host_port from .comm.addressing import address_from_user_args -from .core import rpc, RPCClosed, CommClosedError, coerce_to_address +from .core import RPCClosed, CommClosedError, coerce_to_address from .metrics import time from .node import ServerNode from .process import AsyncProcess @@ -30,6 +30,7 @@ silence_logging, json_load_robust, PeriodicCallback, + parse_timedelta, ) from .worker import _ncores, run, parse_memory_limit, Worker @@ -78,6 +79,11 @@ def __init__( protocol=None, **worker_kwargs ): + self.loop = loop or IOLoop.current() + self.security = security or Security() + assert isinstance(self.security, Security) + self.connection_args = self.security.get_connection_args("worker") + self.listen_args = self.security.get_listen_args("worker") if scheduler_file: cfg = json_load_robust(scheduler_file) @@ -88,12 +94,13 @@ def __init__( self.scheduler_addr = coerce_to_address(scheduler_ip) else: self.scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port)) + self._given_worker_port = worker_port self.ncores = ncores or _ncores self.reconnect = reconnect self.validate = validate self.resources = resources - self.death_timeout = death_timeout + self.death_timeout = parse_timedelta(death_timeout) self.preload = preload self.preload_argv = 
preload_argv self.Worker = Worker if worker_class is None else worker_class @@ -105,15 +112,8 @@ def __init__( "distributed.worker.memory.terminate" ) - self.security = security or Security() - assert isinstance(self.security, Security) - self.connection_args = self.security.get_connection_args("worker") - self.listen_args = self.security.get_listen_args("worker") - self.local_dir = local_dir - self.loop = loop or IOLoop.current() - self.scheduler = rpc(self.scheduler_addr, connection_args=self.connection_args) self.services = services self.name = name self.quiet = quiet @@ -135,9 +135,11 @@ def __init__( } super(Nanny, self).__init__( - handlers, io_loop=self.loop, connection_args=self.connection_args + handlers=handlers, io_loop=self.loop, connection_args=self.connection_args ) + self.scheduler = self.rpc(self.scheduler_addr) + if self.memory_limit: pc = PeriodicCallback(self.memory_monitor, 100, io_loop=self.loop) self.periodic_callbacks["memory"] = pc @@ -240,7 +242,6 @@ def kill(self, comm=None, timeout=2): deadline = self.loop.time() + timeout yield self.process.kill(timeout=0.8 * (deadline - self.loop.time())) - yield self._unregister(deadline - self.loop.time()) @gen.coroutine def instantiate(self, comm=None): @@ -376,8 +377,12 @@ def close(self, comm=None, timeout=5, report=None): """ Close the worker process, stop all comms. """ - if self.status in ("closing", "closed"): + while self.status == "closing": + yield gen.sleep(0.01) + + if self.status == "closed": raise gen.Return("OK") + self.status = "closing" logger.info("Closing Nanny at %r", self.address) self.stop() @@ -388,9 +393,10 @@ def close(self, comm=None, timeout=5, report=None): pass self.process = None self.rpc.close() - self.scheduler.close_rpc() self.status = "closed" - raise gen.Return("OK") + if comm: + yield comm.write("OK") + yield ServerNode.close(self) class WorkerProcess(object): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 57d768f95f6..991ff1a2108 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1365,7 +1365,10 @@ def heartbeat_worker( self.host_info[host]["last-seen"] = local_now frac = 1 / 20 / len(self.workers) - self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"] * frac + try: + self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"] * frac + except KeyError: + pass ws = self.workers.get(address) if not ws: @@ -1990,7 +1993,10 @@ def cancel_key(self, key, client, retries=5, force=False): """ Cancel a particular key and all dependents """ # TODO: this should be converted to use the transition mechanism ts = self.tasks.get(key) - cs = self.clients[client] + try: + cs = self.clients[client] + except KeyError: + return if ts is None or not ts.who_wants: # no key yet, lets try again in a moment if retries: self.loop.add_future( @@ -3085,7 +3091,7 @@ def retire_workers( except KeyError: # keys left during replicate pass - workers = {self.workers[w] for w in workers} + workers = {self.workers[w] for w in workers if w in self.workers} if len(workers) > 0: # Keys orphaned by retiring those workers keys = set.union(*[w.has_what for w in workers]) diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index 8e66b58dd4e..a584025ad03 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -1,4 +1,4 @@ -from concurrent.futures._base import CancelledError +from concurrent.futures import CancelledError from operator import add import random from time import sleep 
@@ -226,7 +226,7 @@ def test_as_completed_with_results_no_raise(client): assert y.status == "cancelled" assert z.status == "finished" - assert isinstance(dd[y][0], CancelledError) + assert isinstance(dd[y][0], CancelledError) or dd[y][0] == 6 assert isinstance(dd[x][0][1], RuntimeError) assert dd[z][0] == 2 diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 1e7a5d2804f..3cb3eee14d4 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -49,6 +49,7 @@ def threads_info(q): q.put(threading.current_thread().name) +@pytest.mark.xfail(reason="Intermittent failure") @nodebug @gen_test() def test_simple(): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 4cd196fa2f4..28c2f939eb7 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -22,7 +22,6 @@ import pytest from toolz import identity, isdistinct, concat, pluck, valmap, partial, first, merge from tornado import gen -from tornado.ioloop import IOLoop import dask from dask import delayed @@ -3321,7 +3320,12 @@ def test_get_foo_lost_keys(c, s, u, v, w): @pytest.mark.slow -@gen_cluster(client=True, Worker=Nanny, check_new_threads=False) +@gen_cluster( + client=True, + Worker=Nanny, + check_new_threads=False, + worker_kwargs={"death_timeout": "500ms"}, +) def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 1) with pytest.raises(KilledWorker) as info: @@ -3575,24 +3579,29 @@ def test_reconnect_timeout(c, s): @pytest.mark.skipif( sys.version_info[0] == 2, reason="Semaphore.acquire doesn't support timeout option" ) -@pytest.mark.xfail(reason="TODO: intermittent failures") +# @pytest.mark.xfail(reason="TODO: intermittent failures") @pytest.mark.parametrize("worker,count,repeat", [(Worker, 100, 5), (Nanny, 10, 20)]) def test_open_close_many_workers(loop, worker, count, repeat): psutil = pytest.importorskip("psutil") proc = psutil.Process() - with cluster(nworkers=0, active_rpc_timeout=20) as (s, _): + with cluster(nworkers=0, active_rpc_timeout=2) as (s, _): gc.collect() before = proc.num_fds() done = Semaphore(0) running = weakref.WeakKeyDictionary() + workers = set() + status = True @gen.coroutine def start_worker(sleep, duration, repeat=1): for i in range(repeat): yield gen.sleep(sleep) + if not status: + return w = worker(s["address"], loop=loop) running[w] = None + workers.add(w) yield w addr = w.worker_address running[w] = addr @@ -3621,6 +3630,12 @@ def start_worker(sleep, duration, repeat=1): sleep(0.2) assert time() < start + 10 + status = False + + [c.sync(w.close) for w in list(workers)] + for w in workers: + assert w.status == "closed" + start = time() while proc.num_fds() > before: print("fds:", before, proc.num_fds()) @@ -4232,23 +4247,23 @@ def test_scatter_dict_workers(c, s, a, b): @pytest.mark.slow @gen_test() def test_client_timeout(): - loop = IOLoop.current() c = Client("127.0.0.1:57484", asynchronous=True) - s = Scheduler(loop=loop) + s = Scheduler(loop=c.loop, port=57484) yield gen.sleep(4) try: - s.start(("127.0.0.1", 57484)) + yield s except EnvironmentError: # port in use + yield c.close() return start = time() - while not c.scheduler_comm: - yield gen.sleep(0.1) + yield c + try: assert time() < start + 2 - - yield c.close() - yield s.close() + finally: + yield c.close() + yield s.close() @gen_cluster(client=True) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 4c18b5242a3..be0a05afc20 100644 --- 
a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -110,10 +110,6 @@ def test_nanny_process_failure(c, s): s.stop() -def test_nanny_no_port(): - _ = str(Nanny("127.0.0.1", 8786)) - - @gen_cluster(ncores=[]) def test_run(s): pytest.importorskip("psutil") @@ -319,12 +315,13 @@ def test_scheduler_address_config(c, s): @pytest.mark.slow -@gen_test() +@gen_test(timeout=20) def test_wait_for_scheduler(): with captured_logger("distributed") as log: w = Nanny("127.0.0.1:44737") - w._start() + w.start() yield gen.sleep(6) + yield w.close() log = log.getvalue() assert "error" not in log.lower(), log diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 9f61e5e710e..f5ce276b8d2 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -822,11 +822,12 @@ def test_file_descriptors(c, s): assert num_fds_6 < num_fds_5 + N yield [n.close() for n in nannies] + yield c.close() assert not s.rpc.open - assert not any( - occ for addr, occ in c.rpc.occupied.items() if occ != s.address - ), list(c.rpc._created) + for addr, occ in c.rpc.occupied.items(): + for comm in occ: + assert comm.closed() or comm.peer_address != s.address, comm assert not s.stream_comms start = time() @@ -1141,7 +1142,8 @@ def test_scheduler_file(): assert data["address"] == s.address c = yield Client(scheduler_file=fn, loop=s.loop, asynchronous=True) - yield s.close() + yield c.close() + yield s.close() @pytest.mark.xfail(reason="") @@ -1555,7 +1557,7 @@ def test_close_workers(s, a, b): ) @gen_test() def test_host_address(): - s = yield Scheduler(host="127.0.0.2") + s = yield Scheduler(host="127.0.0.2", port=0) assert "127.0.0.2" in s.address yield s.close() @@ -1563,10 +1565,10 @@ def test_host_address(): @gen_test() def test_dashboard_address(): pytest.importorskip("bokeh") - s = yield Scheduler(dashboard_address="127.0.0.1:8901") + s = yield Scheduler(dashboard_address="127.0.0.1:8901", port=0) assert s.services["bokeh"].port == 8901 yield s.close() - s = yield Scheduler(dashboard_address="127.0.0.1") + s = yield Scheduler(dashboard_address="127.0.0.1", port=0) assert s.services["bokeh"].port yield s.close() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 9fc967eef5a..bf4e483f441 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1,6 +1,7 @@ from __future__ import print_function, division, absolute_import from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta import logging from numbers import Number from operator import add @@ -312,15 +313,20 @@ def test_worker_with_port_zero(): @pytest.mark.slow -def test_worker_waits_for_center_to_come_up(loop): +def test_worker_waits_for_scheduler(loop): @gen.coroutine def f(): - w = yield Worker("127.0.0.1", 8007) + w = Worker("127.0.0.1", 8007) + try: + yield gen.with_timeout(timedelta(seconds=3), w) + except TimeoutError: + pass + else: + assert False + assert w.status not in ("closed", "running") + yield w.close(timeout=0.1) - try: - loop.run_sync(f, timeout=4) - except TimeoutError: - pass + loop.run_sync(f) @gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) @@ -355,12 +361,13 @@ def test_gather(s, a, b): assert a.data["y"] == b.data["y"] -def test_io_loop(loop): - s = Scheduler(loop=loop) - s.listen(0) - assert s.io_loop is loop - w = Worker(s.address, loop=loop) - assert w.io_loop is loop +@pytest.mark.asyncio +async def test_io_loop(): + s = await Scheduler(port=0) + w = await 
Worker(s.address, loop=s.loop) + assert w.io_loop is s.loop + await s.close() + await w.close() @gen_cluster(client=True, ncores=[]) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index 2d4632b0b54..9c4616e9d26 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -202,7 +202,9 @@ def f(x): b = db.from_sequence([1, 2]) b2 = b.map(f) - with Client(loop=loop, processes=False, set_as_default=True) as c: + with Client( + loop=loop, processes=False, set_as_default=True, dashboard_address=None + ) as c: assert dask.base.get_scheduler() == c.get for i in range(2): b2.compute() diff --git a/distributed/utils.py b/distributed/utils.py index 466a96fdfd5..55508a4c574 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1242,6 +1242,8 @@ def parse_timedelta(s, default="seconds"): >>> parse_timedelta(timedelta(seconds=3)) # also supports timedeltas 3 """ + if s is None: + return None if isinstance(s, timedelta): return s.total_seconds() if isinstance(s, Number): diff --git a/distributed/utils_test.py b/distributed/utils_test.py index c44f4177472..d61046f2a48 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -116,40 +116,35 @@ def cleanup_global_workers(): @pytest.fixture def loop(): - Worker._instances.clear() - _global_clients.clear() - with pristine_loop() as loop: - # Monkey-patch IOLoop.start to wait for loop stop - orig_start = loop.start - is_stopped = threading.Event() - is_stopped.set() + with check_instances(): + with pristine_loop() as loop: + # Monkey-patch IOLoop.start to wait for loop stop + orig_start = loop.start + is_stopped = threading.Event() + is_stopped.set() - def start(): - is_stopped.clear() - try: - orig_start() - finally: - is_stopped.set() + def start(): + is_stopped.clear() + try: + orig_start() + finally: + is_stopped.set() - loop.start = start + loop.start = start - yield loop + yield loop - # Stop the loop in case it's still running - try: - sync(loop, cleanup_global_workers, callback_timeout=0.500) - loop.add_callback(loop.stop) - except RuntimeError as e: - if not re.match("IOLoop is clos(ed|ing)", str(e)): - raise - except gen.TimeoutError: - pass - else: - is_stopped.wait() - Worker._instances.clear() - - _cleanup_dangling() - _global_clients.clear() + # Stop the loop in case it's still running + try: + sync(loop, cleanup_global_workers, callback_timeout=0.500) + loop.add_callback(loop.stop) + except RuntimeError as e: + if not re.match("IOLoop is clos(ed|ing)", str(e)): + raise + except gen.TimeoutError: + pass + else: + is_stopped.wait() @pytest.fixture @@ -464,13 +459,13 @@ def background_read(): raise gen.Return(msg) -def run_scheduler(q, nputs, **kwargs): +def run_scheduler(q, nputs, port=0, **kwargs): from distributed import Scheduler # On Python 2.7 and Unix, fork() is used to spawn child processes, # so avoid inheriting the parent's IO loop. 
with pristine_loop() as loop: - scheduler = Scheduler(validate=True, host="127.0.0.1", **kwargs) + scheduler = Scheduler(validate=True, host="127.0.0.1", port=port, **kwargs) done = scheduler.start() for i in range(nputs): @@ -735,9 +730,9 @@ def cluster( client.close() start = time() - while list(ws): - sleep(0.01) - assert time() < start + 1, "Workers still around after one second" + while len(ws): + sleep(0.1) + assert time() < start + 3, ("Workers still around after two seconds", list(ws)) @gen.coroutine @@ -769,15 +764,12 @@ def test_foo(): def _(func): def test_func(): - with pristine_loop() as loop: + with clean() as loop: if iscoroutinefunction(func): cor = func else: cor = gen.coroutine(func) - try: - loop.run_sync(cor, timeout=timeout) - finally: - loop.stop() + loop.run_sync(cor, timeout=timeout) return test_func @@ -798,7 +790,9 @@ def start_cluster( scheduler_kwargs={}, worker_kwargs={}, ): - s = Scheduler(loop=loop, validate=True, security=security, **scheduler_kwargs) + s = Scheduler( + loop=loop, validate=True, security=security, port=0, **scheduler_kwargs + ) done = s.start(scheduler_addr) workers = [ Worker( @@ -1483,6 +1477,9 @@ def check_instances(): Client._instances.clear() Worker._instances.clear() Scheduler._instances.clear() + # assert all(n.status == "closed" for n in Nanny._instances), { + # n: n.status for n in Nanny._instances + # } Nanny._instances.clear() _global_clients.clear() Comm._instances.clear() @@ -1513,6 +1510,11 @@ def check_instances(): print("Unclosed Comms", L) # raise ValueError("Unclosed Comms", L) + assert all(n.status == "closed" or n.status == "init" for n in Nanny._instances), { + n: n.status for n in Nanny._instances + } + + Nanny._instances.clear() DequeHandler.clear_all_instances() diff --git a/distributed/worker.py b/distributed/worker.py index 3fcc477bf48..667bd83490f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -436,7 +436,7 @@ def __init__( self.ncores = ncores or _ncores self.total_resources = resources or {} self.available_resources = (resources or {}).copy() - self.death_timeout = death_timeout + self.death_timeout = parse_timedelta(death_timeout) self.preload = preload if self.preload is None: self.preload = dask.config.get("distributed.worker.preload") @@ -933,7 +933,8 @@ def _start(self, addr_or_port=0): if "://" in listen_host: protocol, listen_host = listen_host.split("://") - self.name = self.name or self.address + if self.name is None: + self.name = self.address preload_modules( self.preload, parameter=self, @@ -976,7 +977,15 @@ def _start(self, addr_or_port=0): raise gen.Return(self) def __await__(self): - return self._start().__await__() + if self.status is not None: + + @gen.coroutine # idempotent + def _(): + raise gen.Return(self) + + return _().__await__() + else: + return self._start().__await__() def start(self, port=0): self.loop.add_callback(self._start, port) From 6339d81e8de97b551c8cc908308c19ab89037df2 Mon Sep 17 00:00:00 2001 From: Matt Nicolls <2540582+nicolls1@users.noreply.github.com> Date: Fri, 24 May 2019 14:36:05 -0500 Subject: [PATCH 0305/1550] Add SchedulerPlugin TaskState example (#2622) --- docs/source/plugins.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/plugins.rst b/docs/source/plugins.rst index edf64362c56..b5f52f8843e 100644 --- a/docs/source/plugins.rst +++ b/docs/source/plugins.rst @@ -49,3 +49,27 @@ for more information on RabbitMQ and how to consume the messages. 
scheduler.add_plugin(plugin) Run with: ``dask-scheduler --preload `` + +Accessing Full Task State +------------------------- + +If you would like to access the full :class:`distributed.scheduler.TaskState` +stored in the scheduler you can do this by passing and storing a reference to +the scheduler as so: + +.. code-block:: python + + from distributed.diagnostics.plugin import SchedulerPlugin + + class MyPlugin(SchedulerPlugin): + def __init__(self, scheduler): + self.scheduler = scheduler + + def transition(self, key, start, finish, *args, **kwargs): + # Get full TaskState + ts = self.scheduler.tasks[key] + + @click.command() + def dask_setup(scheduler): + plugin = MyPlugin(scheduler) + scheduler.add_plugin(plugin) From a818711f97fdf501823c35246b16f24fddd4035a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 27 May 2019 14:47:10 -0500 Subject: [PATCH 0306/1550] Close clusters at exit (#2730) --- distributed/deploy/spec.py | 15 +++++++++++++-- distributed/deploy/tests/test_spec_cluster.py | 11 +++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 9a4385e5054..ad0aea25f6c 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -1,4 +1,5 @@ import asyncio +import atexit import weakref from tornado import gen @@ -97,6 +98,8 @@ class does handle all of the logic around asynchronously cleanly setting up specifications into the same dictionary. """ + _instances = weakref.WeakSet() + def __init__( self, workers=None, @@ -133,6 +136,7 @@ def __init__( loop=self.loop, **self.scheduler_spec["options"] ) self.status = "created" + self._instances.add(self) self._correct_state_waiting = None if not self.asynchronous: @@ -248,9 +252,9 @@ async def _close(self): self.status = "closed" - def close(self): + def close(self, timeout=None): with ignoring(RuntimeError): # loop closed during process shutdown - return self.sync(self._close) + return self.sync(self._close, callback_timeout=timeout) def __del__(self): if self.status != "closed": @@ -295,3 +299,10 @@ def __repr__(self): self.scheduler_address, len(self.workers), ) + + +@atexit.register +def close_clusters(): + for cluster in list(SpecCluster._instances): + with ignoring(gen.TimeoutError): + cluster.close(timeout=10) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index cfc12427274..ac5706afe1c 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,4 +1,5 @@ from dask.distributed import SpecCluster, Worker, Client, Scheduler +from distributed.deploy.spec import close_clusters from distributed.utils_test import loop # noqa: F401 import pytest @@ -113,3 +114,13 @@ async def test_broken_worker(): pass assert "Broken" in str(info.value) + + +@pytest.mark.slow +def test_spec_close_clusters(loop): + workers = {0: {"cls": Worker}} + scheduler = {"cls": Scheduler, "options": {"port": 0}} + cluster = SpecCluster(workers=workers, scheduler=scheduler, loop=loop) + assert cluster in SpecCluster._instances + close_clusters() + assert cluster.status == "closed" From d202e6253ed8ddc7919d0d4f128d88954e9859b8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 28 May 2019 12:25:44 -0500 Subject: [PATCH 0307/1550] Move bokeh module to dashboard (#2724) --- distributed/bokeh.py | 1 + distributed/bokeh/__init__.py | 37 ---------- distributed/cli/dask_mpi.py | 8 +-- distributed/cli/dask_scheduler.py | 23 +++--- 
distributed/cli/dask_worker.py | 18 ++--- distributed/cli/tests/test_dask_scheduler.py | 46 +++++------- distributed/cli/tests/test_dask_worker.py | 56 +++++++++------ distributed/cli/tests/test_tls_cli.py | 17 ++--- distributed/client.py | 6 +- distributed/dashboard/__init__.py | 2 + .../{bokeh => dashboard}/components.py | 0 distributed/{bokeh => dashboard}/core.py | 0 .../{bokeh => dashboard}/export_tool.coffee | 0 .../{bokeh => dashboard}/export_tool.js | 0 .../{bokeh => dashboard}/export_tool.py | 0 distributed/{bokeh => dashboard}/proxy.py | 2 +- distributed/{bokeh => dashboard}/scheduler.py | 4 +- .../{bokeh => dashboard}/scheduler_html.py | 2 +- .../{bokeh => dashboard}/static/css/base.css | 0 .../static/css/status.css | 0 .../static/css/system.css | 0 .../static/images/dask-logo.svg | 0 .../static/images/fa-bars.svg | 0 .../{bokeh => dashboard}/templates/base.html | 0 .../templates/call-stack.html | 0 .../templates/json-index.html | 0 .../{bokeh => dashboard}/templates/logs.html | 0 .../{bokeh => dashboard}/templates/main.html | 0 .../templates/simple.html | 0 .../templates/status.html | 0 .../templates/system.html | 0 .../{bokeh => dashboard}/templates/task.html | 0 .../templates/worker-table.html | 0 .../templates/worker.html | 0 .../templates/workers.html | 0 distributed/dashboard/tests/test_bokeh.py | 5 ++ .../tests/test_components.py | 4 +- .../tests/test_scheduler_bokeh.py | 30 ++++---- .../tests/test_scheduler_bokeh_html.py | 29 ++++---- .../tests/test_worker_bokeh.py | 24 ++++--- .../tests/test_worker_bokeh_html.py | 10 +-- distributed/{bokeh => dashboard}/theme.yaml | 0 distributed/{bokeh => dashboard}/utils.py | 0 distributed/{bokeh => dashboard}/worker.py | 0 .../{bokeh => dashboard}/worker_html.py | 0 distributed/deploy/cluster.py | 4 +- distributed/deploy/spec.py | 4 +- distributed/deploy/tests/test_local.py | 10 +-- .../diagnostics/tests/test_eventstream.py | 8 ++- distributed/scheduler.py | 6 +- distributed/tests/test_client.py | 11 ++- distributed/tests/test_core.py | 5 +- distributed/tests/test_scheduler.py | 14 ++-- distributed/tests/test_worker.py | 22 +++--- distributed/worker.py | 6 +- setup.py | 70 ++++++++++--------- 56 files changed, 236 insertions(+), 248 deletions(-) create mode 100644 distributed/bokeh.py delete mode 100644 distributed/bokeh/__init__.py create mode 100644 distributed/dashboard/__init__.py rename distributed/{bokeh => dashboard}/components.py (100%) rename distributed/{bokeh => dashboard}/core.py (100%) rename distributed/{bokeh => dashboard}/export_tool.coffee (100%) rename distributed/{bokeh => dashboard}/export_tool.js (100%) rename distributed/{bokeh => dashboard}/export_tool.py (100%) rename distributed/{bokeh => dashboard}/proxy.py (98%) rename distributed/{bokeh => dashboard}/scheduler.py (99%) rename distributed/{bokeh => dashboard}/scheduler_html.py (99%) rename distributed/{bokeh => dashboard}/static/css/base.css (100%) rename distributed/{bokeh => dashboard}/static/css/status.css (100%) rename distributed/{bokeh => dashboard}/static/css/system.css (100%) rename distributed/{bokeh => dashboard}/static/images/dask-logo.svg (100%) rename distributed/{bokeh => dashboard}/static/images/fa-bars.svg (100%) rename distributed/{bokeh => dashboard}/templates/base.html (100%) rename distributed/{bokeh => dashboard}/templates/call-stack.html (100%) rename distributed/{bokeh => dashboard}/templates/json-index.html (100%) rename distributed/{bokeh => dashboard}/templates/logs.html (100%) rename distributed/{bokeh => 
dashboard}/templates/main.html (100%) rename distributed/{bokeh => dashboard}/templates/simple.html (100%) rename distributed/{bokeh => dashboard}/templates/status.html (100%) rename distributed/{bokeh => dashboard}/templates/system.html (100%) rename distributed/{bokeh => dashboard}/templates/task.html (100%) rename distributed/{bokeh => dashboard}/templates/worker-table.html (100%) rename distributed/{bokeh => dashboard}/templates/worker.html (100%) rename distributed/{bokeh => dashboard}/templates/workers.html (100%) create mode 100644 distributed/dashboard/tests/test_bokeh.py rename distributed/{bokeh => dashboard}/tests/test_components.py (92%) rename distributed/{bokeh => dashboard}/tests/test_scheduler_bokeh.py (95%) rename distributed/{bokeh => dashboard}/tests/test_scheduler_bokeh_html.py (79%) rename distributed/{bokeh => dashboard}/tests/test_worker_bokeh.py (79%) rename distributed/{bokeh => dashboard}/tests/test_worker_bokeh_html.py (75%) rename distributed/{bokeh => dashboard}/theme.yaml (100%) rename distributed/{bokeh => dashboard}/utils.py (100%) rename distributed/{bokeh => dashboard}/worker.py (100%) rename distributed/{bokeh => dashboard}/worker_html.py (100%) diff --git a/distributed/bokeh.py b/distributed/bokeh.py new file mode 100644 index 00000000000..e27bdffa33e --- /dev/null +++ b/distributed/bokeh.py @@ -0,0 +1 @@ +raise ImportError("The distributed.bokeh module has moved to distributed.dashboard") diff --git a/distributed/bokeh/__init__.py b/distributed/bokeh/__init__.py deleted file mode 100644 index 24e082fa8e0..00000000000 --- a/distributed/bokeh/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -from collections import deque - -from ..metrics import time - -n = 60 -m = 100000 - -messages = { - "workers": { - "interval": 1000, - "deque": deque(maxlen=n), - "times": deque(maxlen=n), - "index": deque(maxlen=n), - "plot-data": { - "time": deque(maxlen=n), - "cpu": deque(maxlen=n), - "memory_percent": deque(maxlen=n), - "network-send": deque(maxlen=n), - "network-recv": deque(maxlen=n), - }, - }, - "tasks": {"interval": 150, "deque": deque(maxlen=100), "times": deque(maxlen=100)}, - "progress": {}, - "processing": {"processing": {}, "memory": 0, "waiting": 0}, - "task-events": { - "interval": 200, - "deque": deque(maxlen=m), - "times": deque(maxlen=m), - "index": deque(maxlen=m), - "rectangles": { - name: deque(maxlen=m) - for name in "start duration key name color worker worker_thread y alpha".split() - }, - "workers": dict(), - "last_seen": [time()], - }, -} diff --git a/distributed/cli/dask_mpi.py b/distributed/cli/dask_mpi.py index 398596508a3..c7669073f79 100644 --- a/distributed/cli/dask_mpi.py +++ b/distributed/cli/dask_mpi.py @@ -7,7 +7,7 @@ from warnings import warn from distributed import Scheduler, Nanny, Worker -from distributed.bokeh.worker import BokehWorker +from distributed.dashboard import BokehWorker from distributed.cli.utils import check_python_3 from distributed.comm.addressing import uri_from_host_port from distributed.utils import get_ip_interface @@ -82,12 +82,12 @@ def main( if rank == 0 and scheduler: try: - from distributed.bokeh.scheduler import BokehScheduler + from distributed.dashboard import BokehScheduler except ImportError: services = {} else: services = { - ("bokeh", bokeh_port): partial(BokehScheduler, prefix=bokeh_prefix) + ("dashboard", bokeh_port): partial(BokehScheduler, prefix=bokeh_prefix) } scheduler = Scheduler( scheduler_file=scheduler_file, loop=loop, services=services @@ -107,7 +107,7 @@ def main( name=rank if scheduler 
else None, ncores=nthreads, local_dir=local_directory, - services={("bokeh", bokeh_worker_port): BokehWorker}, + services={("dashboard", bokeh_worker_port): BokehWorker}, memory_limit=memory_limit, ) addr = uri_from_host_port(host, None, 0) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 3668be684d0..1f78426f635 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -68,27 +68,23 @@ help="Address on which to listen for diagnostics dashboard", ) @click.option( - "--bokeh/--no-bokeh", - "_bokeh", + "--dashboard/--no-dashboard", + "dashboard", default=True, show_default=True, required=False, - help="Launch Bokeh Web UI", + help="Launch the Dashboard", ) @click.option("--show/--no-show", default=False, help="Show web UI") @click.option( - "--bokeh-whitelist", - default=None, - multiple=True, - help="IP addresses to whitelist for bokeh.", + "--dashboard-prefix", type=str, default=None, help="Prefix for the dashboard app" ) -@click.option("--bokeh-prefix", type=str, default=None, help="Prefix for the bokeh app") @click.option( "--use-xheaders", type=bool, default=False, show_default=True, - help="User xheaders in bokeh app for ssl termination in header", + help="User xheaders in dashboard app for ssl termination in header", ) @click.option("--pid-file", type=str, default="", help="File to write the process PID") @click.option( @@ -119,9 +115,8 @@ def main( port, bokeh_port, show, - _bokeh, - bokeh_whitelist, - bokeh_prefix, + dashboard, + dashboard_prefix, use_xheaders, pid_file, scheduler_file, @@ -195,8 +190,8 @@ def del_pid_file(): host=host, port=port, interface=interface, - dashboard_address=dashboard_address if _bokeh else None, - service_kwargs={"bokeh": {"prefix": bokeh_prefix}}, + dashboard_address=dashboard_address if dashboard else None, + service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, ) scheduler.start() if not preload: diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 439bdaf4a62..c4f83f61405 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -70,12 +70,12 @@ help="Address on which to listen for diagnostics dashboard", ) @click.option( - "--bokeh/--no-bokeh", - "bokeh", + "--dashboard/--no-dashboard", + "dashboard", default=True, show_default=True, required=False, - help="Launch Bokeh Web UI", + help="Launch the Dashboard", ) @click.option( "--listen-address", @@ -163,7 +163,9 @@ default=None, help="Seconds to wait for a scheduler before closing", ) -@click.option("--bokeh-prefix", type=str, default="", help="Prefix for the bokeh app") +@click.option( + "--dashboard-prefix", type=str, default="", help="Prefix for the dashboard" +) @click.option( "--preload", type=str, @@ -190,7 +192,7 @@ def main( pid_file, reconnect, resources, - bokeh, + dashboard, bokeh_port, local_directory, scheduler_file, @@ -198,7 +200,7 @@ def main( death_timeout, preload, preload_argv, - bokeh_prefix, + dashboard_prefix, tls_ca_file, tls_cert, tls_key, @@ -338,8 +340,8 @@ def del_pid_file(): interface=interface, host=host, port=port, - dashboard_address=dashboard_address if bokeh else None, - service_kwargs={"bokhe": {"prefix": bokeh_prefix}}, + dashboard_address=dashboard_address if dashboard else None, + service_kwargs={"bokhe": {"prefix": dashboard_prefix}}, name=name if nprocs == 1 or not name else name + "-" + str(i), **kwargs ) diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 
26fe607b901..754082f35eb 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -26,7 +26,7 @@ def test_defaults(loop): - with popen(["dask-scheduler", "--no-bokeh"]) as proc: + with popen(["dask-scheduler", "--no-dashboard"]) as proc: @gen.coroutine def f(): @@ -43,7 +43,7 @@ def f(): def test_hostport(loop): - with popen(["dask-scheduler", "--no-bokeh", "--host", "127.0.0.1:8978"]): + with popen(["dask-scheduler", "--no-dashboard", "--host", "127.0.0.1:8978"]): @gen.coroutine def f(): @@ -57,18 +57,18 @@ def f(): c.sync(f) -def test_no_bokeh(loop): +def test_no_dashboard(loop): pytest.importorskip("bokeh") - with popen(["dask-scheduler", "--no-bokeh"]) as proc: + with popen(["dask-scheduler", "--no-dashboard"]) as proc: with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: for i in range(3): line = proc.stderr.readline() - assert b"bokeh" not in line.lower() + assert b"dashboard" not in line.lower() with pytest.raises(Exception): requests.get("http://127.0.0.1:8787/status/") -def test_bokeh(loop): +def test_dashboard(loop): pytest.importorskip("bokeh") with popen(["dask-scheduler"]) as proc: @@ -97,7 +97,7 @@ def test_bokeh(loop): requests.get("http://127.0.0.1:8787/status/") -def test_bokeh_non_standard_ports(loop): +def test_dashboard_non_standard_ports(loop): pytest.importorskip("bokeh") with popen( @@ -122,20 +122,12 @@ def test_bokeh_non_standard_ports(loop): @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) -def test_bokeh_whitelist(loop): +def test_dashboard_whitelist(loop): pytest.importorskip("bokeh") with pytest.raises(Exception): requests.get("http://localhost:8787/status/").ok - with popen( - [ - "dask-scheduler", - "--bokeh-whitelist", - "127.0.0.2:8787", - "--bokeh-whitelist", - "127.0.0.3:8787", - ] - ) as proc: + with popen(["dask-scheduler"]) as proc: with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: pass @@ -153,9 +145,9 @@ def test_bokeh_whitelist(loop): def test_multiple_workers(loop): - with popen(["dask-scheduler", "--no-bokeh"]) as s: - with popen(["dask-worker", "localhost:8786", "--no-bokeh"]) as a: - with popen(["dask-worker", "localhost:8786", "--no-bokeh"]) as b: + with popen(["dask-scheduler", "--no-dashboard"]) as s: + with popen(["dask-worker", "localhost:8786", "--no-dashboard"]) as a: + with popen(["dask-worker", "localhost:8786", "--no-dashboard"]) as b: with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: start = time() while len(c.ncores()) < 2: @@ -180,9 +172,9 @@ def test_interface(loop): "Available interfaces are: %s." 
% (if_names,) ) - with popen(["dask-scheduler", "--no-bokeh", "--interface", if_name]) as s: + with popen(["dask-scheduler", "--no-dashboard", "--interface", if_name]) as s: with popen( - ["dask-worker", "127.0.0.1:8786", "--no-bokeh", "--interface", if_name] + ["dask-worker", "127.0.0.1:8786", "--no-dashboard", "--interface", if_name] ) as a: with Client("tcp://127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: start = time() @@ -217,12 +209,12 @@ def check_pidfile(proc, pidfile): assert proc.pid == pid with tmpfile() as s: - with popen(["dask-scheduler", "--pid-file", s, "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--pid-file", s, "--no-dashboard"]) as sched: check_pidfile(sched, s) with tmpfile() as w: with popen( - ["dask-worker", "127.0.0.1:8786", "--pid-file", w, "--no-bokeh"] + ["dask-worker", "127.0.0.1:8786", "--pid-file", w, "--no-dashboard"] ) as worker: check_pidfile(worker, w) @@ -230,21 +222,21 @@ def check_pidfile(proc, pidfile): def test_scheduler_port_zero(loop): with tmpfile() as fn: with popen( - ["dask-scheduler", "--no-bokeh", "--scheduler-file", fn, "--port", "0"] + ["dask-scheduler", "--no-dashboard", "--scheduler-file", fn, "--port", "0"] ) as sched: with Client(scheduler_file=fn, loop=loop) as c: assert c.scheduler.port assert c.scheduler.port != 8786 -def test_bokeh_port_zero(loop): +def test_dashboard_port_zero(loop): pytest.importorskip("bokeh") with tmpfile() as fn: with popen(["dask-scheduler", "--dashboard-address", ":0"]) as proc: count = 0 while count < 1: line = proc.stderr.readline() - if b"bokeh" in line.lower() or b"web" in line.lower(): + if b"dashboard" in line.lower(): sleep(0.01) count += 1 assert b":0" not in line diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index aac27061b21..fa62594a753 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -17,7 +17,7 @@ def test_nanny_worker_ports(loop): - with popen(["dask-scheduler", "--port", "9359", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--port", "9359", "--no-dashboard"]) as sched: with popen( [ "dask-worker", @@ -28,7 +28,7 @@ def test_nanny_worker_ports(loop): "9684", "--nanny-port", "5273", - "--no-bokeh", + "--no-dashboard", ] ) as worker: with Client("127.0.0.1:9359", loop=loop) as c: @@ -47,9 +47,15 @@ def test_nanny_worker_ports(loop): def test_memory_limit(loop): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( - ["dask-worker", "127.0.0.1:8786", "--memory-limit", "2e3MB", "--no-bokeh"] + [ + "dask-worker", + "127.0.0.1:8786", + "--memory-limit", + "2e3MB", + "--no-dashboard", + ] ) as worker: with Client("127.0.0.1:8786", loop=loop) as c: while not c.ncores(): @@ -61,9 +67,9 @@ def test_memory_limit(loop): def test_no_nanny(loop): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( - ["dask-worker", "127.0.0.1:8786", "--no-nanny", "--no-bokeh"] + ["dask-worker", "127.0.0.1:8786", "--no-nanny", "--no-dashboard"] ) as worker: assert any(b"Registered" in worker.stderr.readline() for i in range(15)) @@ -71,7 +77,7 @@ def test_no_nanny(loop): @pytest.mark.slow @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_no_reconnect(nanny, loop): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: wait_for_port(("127.0.0.1", 8786)) 
with popen( [ @@ -79,7 +85,7 @@ def test_no_reconnect(nanny, loop): "tcp://127.0.0.1:8786", "--no-reconnect", nanny, - "--no-bokeh", + "--no-dashboard", ] ) as worker: sleep(2) @@ -91,12 +97,12 @@ def test_no_reconnect(nanny, loop): def test_resources(loop): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( [ "dask-worker", "tcp://127.0.0.1:8786", - "--no-bokeh", + "--no-dashboard", "--resources", "A=1 B=2,C=3", ] @@ -112,13 +118,13 @@ def test_resources(loop): @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_local_directory(loop, nanny): with tmpfile() as fn: - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( [ "dask-worker", "127.0.0.1:8786", nanny, - "--no-bokeh", + "--no-dashboard", "--local-directory", fn, ] @@ -136,8 +142,12 @@ def test_local_directory(loop, nanny): @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_scheduler_file(loop, nanny): with tmpfile() as fn: - with popen(["dask-scheduler", "--no-bokeh", "--scheduler-file", fn]) as sched: - with popen(["dask-worker", "--scheduler-file", fn, nanny, "--no-bokeh"]): + with popen( + ["dask-scheduler", "--no-dashboard", "--scheduler-file", fn] + ) as sched: + with popen( + ["dask-worker", "--scheduler-file", fn, nanny, "--no-dashboard"] + ): with Client(scheduler_file=fn, loop=loop) as c: start = time() while not c.scheduler_info()["workers"]: @@ -147,8 +157,8 @@ def test_scheduler_file(loop, nanny): def test_scheduler_address_env(loop, monkeypatch): monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", "tcp://127.0.0.1:8786") - with popen(["dask-scheduler", "--no-bokeh"]) as sched: - with popen(["dask-worker", "--no-bokeh"]): + with popen(["dask-scheduler", "--no-dashboard"]) as sched: + with popen(["dask-worker", "--no-dashboard"]): with Client(os.environ["DASK_SCHEDULER_ADDRESS"], loop=loop) as c: start = time() while not c.scheduler_info()["workers"]: @@ -157,7 +167,7 @@ def test_scheduler_address_env(loop, monkeypatch): def test_nprocs_requires_nanny(loop): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( ["dask-worker", "127.0.0.1:8786", "--nprocs=2", "--no-nanny"] ) as worker: @@ -168,7 +178,7 @@ def test_nprocs_requires_nanny(loop): def test_nprocs_expands_name(loop): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( ["dask-worker", "127.0.0.1:8786", "--nprocs", "2", "--name", "foo"] ) as worker: @@ -194,13 +204,13 @@ def test_nprocs_expands_name(loop): "listen_address", ["tcp://0.0.0.0:39837", "tcp://127.0.0.2:39837"] ) def test_contact_listen_address(loop, nanny, listen_address): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( [ "dask-worker", "127.0.0.1:8786", nanny, - "--no-bokeh", + "--no-dashboard", "--contact-address", "tcp://127.0.0.2:39837", "--listen-address", @@ -228,9 +238,9 @@ def func(dask_worker): @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) @pytest.mark.parametrize("host", ["127.0.0.2", "0.0.0.0"]) def test_respect_host_listen_address(loop, nanny, host): - with popen(["dask-scheduler", "--no-bokeh"]) as sched: + with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( - ["dask-worker", "127.0.0.1:8786", nanny, "--no-bokeh", "--host", host] + 
["dask-worker", "127.0.0.1:8786", nanny, "--no-dashboard", "--host", host] ) as worker: with Client("127.0.0.1:8786") as client: while not client.ncores(): @@ -247,7 +257,7 @@ def func(dask_worker): assert all(host in v for v in listen_addresses.values()) -def test_bokeh_non_standard_ports(loop): +def test_dashboard_non_standard_ports(loop): pytest.importorskip("bokeh") try: import jupyter_server_proxy # noqa: F401 diff --git a/distributed/cli/tests/test_tls_cli.py b/distributed/cli/tests/test_tls_cli.py index d983039c962..4663a9b38ff 100644 --- a/distributed/cli/tests/test_tls_cli.py +++ b/distributed/cli/tests/test_tls_cli.py @@ -33,9 +33,9 @@ def wait_for_cores(c, ncores=1): def test_basic(loop): - with popen(["dask-scheduler", "--no-bokeh"] + tls_args) as s: + with popen(["dask-scheduler", "--no-dashboard"] + tls_args) as s: with popen( - ["dask-worker", "--no-bokeh", "tls://127.0.0.1:8786"] + tls_args + ["dask-worker", "--no-dashboard", "tls://127.0.0.1:8786"] + tls_args ) as w: with Client( "tls://127.0.0.1:8786", loop=loop, security=tls_security() @@ -44,9 +44,10 @@ def test_basic(loop): def test_nanny(loop): - with popen(["dask-scheduler", "--no-bokeh"] + tls_args) as s: + with popen(["dask-scheduler", "--no-dashboard"] + tls_args) as s: with popen( - ["dask-worker", "--no-bokeh", "--nanny", "tls://127.0.0.1:8786"] + tls_args + ["dask-worker", "--no-dashboard", "--nanny", "tls://127.0.0.1:8786"] + + tls_args ) as w: with Client( "tls://127.0.0.1:8786", loop=loop, security=tls_security() @@ -55,9 +56,9 @@ def test_nanny(loop): def test_separate_key_cert(loop): - with popen(["dask-scheduler", "--no-bokeh"] + tls_args_2) as s: + with popen(["dask-scheduler", "--no-dashboard"] + tls_args_2) as s: with popen( - ["dask-worker", "--no-bokeh", "tls://127.0.0.1:8786"] + tls_args_2 + ["dask-worker", "--no-dashboard", "tls://127.0.0.1:8786"] + tls_args_2 ) as w: with Client( "tls://127.0.0.1:8786", loop=loop, security=tls_security() @@ -67,8 +68,8 @@ def test_separate_key_cert(loop): def test_use_config_file(loop): with new_config_file(tls_only_config()): - with popen(["dask-scheduler", "--no-bokeh", "--host", "tls://"]) as s: - with popen(["dask-worker", "--no-bokeh", "tls://127.0.0.1:8786"]) as w: + with popen(["dask-scheduler", "--no-dashboard", "--host", "tls://"]) as s: + with popen(["dask-worker", "--no-dashboard", "tls://127.0.0.1:8786"]) as w: with Client( "tls://127.0.0.1:8786", loop=loop, security=tls_security() ) as c: diff --git a/distributed/client.py b/distributed/client.py index afe6f6ef39f..d924b608c61 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -812,9 +812,9 @@ def _repr_html_(self): text = ( "
<h3>Client</h3>\n" "<ul>\n" "  <li><b>Scheduler: not connected</b>\n" ) - if info and "bokeh" in info["services"]: + if info and "dashboard" in info["services"]: protocol, rest = scheduler.address.split("://") - port = info["services"]["bokeh"] + port = info["services"]["dashboard"] if protocol == "inproc": host = "localhost" else: @@ -3852,7 +3852,7 @@ def _get_task_stream( from .diagnostics.task_stream import rectangles rects = rectangles(msgs) - from .bokeh.components import task_stream_figure + from .dashboard.components import task_stream_figure source, figure = task_stream_figure(sizing_mode="stretch_both") source.data.update(rects) diff --git a/distributed/dashboard/__init__.py b/distributed/dashboard/__init__.py new file mode 100644 index 00000000000..675963b1463 --- /dev/null +++ b/distributed/dashboard/__init__.py @@ -0,0 +1,2 @@ +from .scheduler import BokehScheduler +from .worker import BokehWorker diff --git a/distributed/bokeh/components.py b/distributed/dashboard/components.py similarity index 100% rename from distributed/bokeh/components.py rename to distributed/dashboard/components.py diff --git a/distributed/bokeh/core.py b/distributed/dashboard/core.py similarity index 100% rename from distributed/bokeh/core.py rename to distributed/dashboard/core.py diff --git a/distributed/bokeh/export_tool.coffee b/distributed/dashboard/export_tool.coffee similarity index 100% rename from distributed/bokeh/export_tool.coffee rename to distributed/dashboard/export_tool.coffee diff --git a/distributed/bokeh/export_tool.js b/distributed/dashboard/export_tool.js similarity index 100% rename from distributed/bokeh/export_tool.js rename to distributed/dashboard/export_tool.js diff --git a/distributed/bokeh/export_tool.py b/distributed/dashboard/export_tool.py similarity index 100% rename from distributed/bokeh/export_tool.py rename to distributed/dashboard/export_tool.py diff --git a/distributed/bokeh/proxy.py b/distributed/dashboard/proxy.py similarity index 98% rename from distributed/bokeh/proxy.py rename to distributed/dashboard/proxy.py index 9353e383112..89f9f87aae6 100644 --- a/distributed/bokeh/proxy.py +++ b/distributed/dashboard/proxy.py @@ -124,7 +124,7 @@ def check_worker_dashboard_exits(scheduler, worker): addr, port = worker.split(":") workers = list(scheduler.workers.values()) for w in workers: - bokeh_port = w.services.get("bokeh", "") + bokeh_port = w.services.get("dashboard", "") if addr == w.host and port == str(bokeh_port): return True return False diff --git a/distributed/bokeh/scheduler.py b/distributed/dashboard/scheduler.py similarity index 99% rename from distributed/bokeh/scheduler.py rename to distributed/dashboard/scheduler.py index e0f5bfffab6..6476d3aa6e4 100644 --- a/distributed/bokeh/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -170,7 +170,7 @@ def update(self): workers = list(self.scheduler.workers.values()) dashboard_host = [ws.host for ws in workers] - dashboard_port = [ws.services.get("bokeh", "") for ws in workers] + dashboard_port = [ws.services.get("dashboard", "") for ws in workers] y = list(range(len(workers))) occupancy = [ws.occupancy for ws in workers] @@ -403,7 +403,7 @@ def update(self): workers = list(self.scheduler.workers.values()) dashboard_host = [ws.host for ws in workers] - dashboard_port = [ws.services.get("bokeh", "") for ws in workers] + dashboard_port = [ws.services.get("dashboard", "") for ws in workers] y = list(range(len(workers))) nprocessing = [len(ws.processing) for ws in workers] diff --git a/distributed/bokeh/scheduler_html.py
b/distributed/dashboard/scheduler_html.py similarity index 99% rename from distributed/bokeh/scheduler_html.py rename to distributed/dashboard/scheduler_html.py index 1d3635c37c5..5f481f783be 100644 --- a/distributed/bokeh/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -177,7 +177,7 @@ def get(self): class IndividualPlots(RequestHandler): def get(self): - bokeh_server = self.server.services["bokeh"] + bokeh_server = self.server.services["dashboard"] result = { uri.strip("/").replace("-", " ").title(): uri for uri in bokeh_server.apps diff --git a/distributed/bokeh/static/css/base.css b/distributed/dashboard/static/css/base.css similarity index 100% rename from distributed/bokeh/static/css/base.css rename to distributed/dashboard/static/css/base.css diff --git a/distributed/bokeh/static/css/status.css b/distributed/dashboard/static/css/status.css similarity index 100% rename from distributed/bokeh/static/css/status.css rename to distributed/dashboard/static/css/status.css diff --git a/distributed/bokeh/static/css/system.css b/distributed/dashboard/static/css/system.css similarity index 100% rename from distributed/bokeh/static/css/system.css rename to distributed/dashboard/static/css/system.css diff --git a/distributed/bokeh/static/images/dask-logo.svg b/distributed/dashboard/static/images/dask-logo.svg similarity index 100% rename from distributed/bokeh/static/images/dask-logo.svg rename to distributed/dashboard/static/images/dask-logo.svg diff --git a/distributed/bokeh/static/images/fa-bars.svg b/distributed/dashboard/static/images/fa-bars.svg similarity index 100% rename from distributed/bokeh/static/images/fa-bars.svg rename to distributed/dashboard/static/images/fa-bars.svg diff --git a/distributed/bokeh/templates/base.html b/distributed/dashboard/templates/base.html similarity index 100% rename from distributed/bokeh/templates/base.html rename to distributed/dashboard/templates/base.html diff --git a/distributed/bokeh/templates/call-stack.html b/distributed/dashboard/templates/call-stack.html similarity index 100% rename from distributed/bokeh/templates/call-stack.html rename to distributed/dashboard/templates/call-stack.html diff --git a/distributed/bokeh/templates/json-index.html b/distributed/dashboard/templates/json-index.html similarity index 100% rename from distributed/bokeh/templates/json-index.html rename to distributed/dashboard/templates/json-index.html diff --git a/distributed/bokeh/templates/logs.html b/distributed/dashboard/templates/logs.html similarity index 100% rename from distributed/bokeh/templates/logs.html rename to distributed/dashboard/templates/logs.html diff --git a/distributed/bokeh/templates/main.html b/distributed/dashboard/templates/main.html similarity index 100% rename from distributed/bokeh/templates/main.html rename to distributed/dashboard/templates/main.html diff --git a/distributed/bokeh/templates/simple.html b/distributed/dashboard/templates/simple.html similarity index 100% rename from distributed/bokeh/templates/simple.html rename to distributed/dashboard/templates/simple.html diff --git a/distributed/bokeh/templates/status.html b/distributed/dashboard/templates/status.html similarity index 100% rename from distributed/bokeh/templates/status.html rename to distributed/dashboard/templates/status.html diff --git a/distributed/bokeh/templates/system.html b/distributed/dashboard/templates/system.html similarity index 100% rename from distributed/bokeh/templates/system.html rename to distributed/dashboard/templates/system.html 
diff --git a/distributed/bokeh/templates/task.html b/distributed/dashboard/templates/task.html similarity index 100% rename from distributed/bokeh/templates/task.html rename to distributed/dashboard/templates/task.html diff --git a/distributed/bokeh/templates/worker-table.html b/distributed/dashboard/templates/worker-table.html similarity index 100% rename from distributed/bokeh/templates/worker-table.html rename to distributed/dashboard/templates/worker-table.html diff --git a/distributed/bokeh/templates/worker.html b/distributed/dashboard/templates/worker.html similarity index 100% rename from distributed/bokeh/templates/worker.html rename to distributed/dashboard/templates/worker.html diff --git a/distributed/bokeh/templates/workers.html b/distributed/dashboard/templates/workers.html similarity index 100% rename from distributed/bokeh/templates/workers.html rename to distributed/dashboard/templates/workers.html diff --git a/distributed/dashboard/tests/test_bokeh.py b/distributed/dashboard/tests/test_bokeh.py new file mode 100644 index 00000000000..363272be5f6 --- /dev/null +++ b/distributed/dashboard/tests/test_bokeh.py @@ -0,0 +1,5 @@ +def test_old_import(): + try: + from distributed.bokeh import BokehScheduler # noqa: F401 + except ImportError as e: + assert "distributed.dashboard" in str(e) diff --git a/distributed/bokeh/tests/test_components.py b/distributed/dashboard/tests/test_components.py similarity index 92% rename from distributed/bokeh/tests/test_components.py rename to distributed/dashboard/tests/test_components.py index 028f209b41a..d441db57aec 100644 --- a/distributed/bokeh/tests/test_components.py +++ b/distributed/dashboard/tests/test_components.py @@ -7,10 +7,9 @@ from bokeh.models import ColumnDataSource, Model from tornado import gen -from distributed.bokeh import messages from distributed.utils_test import slowinc, gen_cluster -from distributed.bokeh.components import ( +from distributed.dashboard.components import ( TaskStream, MemoryUsage, Processing, @@ -24,7 +23,6 @@ def test_basic(Component): c = Component() assert isinstance(c.source, ColumnDataSource) assert isinstance(c.root, Model) - c.update(messages) @gen_cluster(client=True, check_new_threads=False) diff --git a/distributed/bokeh/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py similarity index 95% rename from distributed/bokeh/tests/test_scheduler_bokeh.py rename to distributed/dashboard/tests/test_scheduler_bokeh.py index 057aa679655..f8a813514b4 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -17,8 +17,8 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec, slowinc, div -from distributed.bokeh.worker import Counters, BokehWorker -from distributed.bokeh.scheduler import ( +from distributed.dashboard.worker import Counters, BokehWorker +from distributed.dashboard.scheduler import ( BokehScheduler, SystemMonitor, Occupancy, @@ -36,7 +36,7 @@ ProfileServer, ) -from distributed.bokeh import scheduler +from distributed.dashboard import scheduler scheduler.PROFILING = False @@ -44,10 +44,12 @@ @pytest.mark.skipif( sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" ) -@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) +@gen_cluster( + client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} +) def test_simple(c, s, a, b): - assert 
isinstance(s.services["bokeh"], BokehScheduler) - port = s.services["bokeh"].port + assert isinstance(s.services["dashboard"], BokehScheduler) + port = s.services["dashboard"].port future = c.submit(sleep, 1) yield gen.sleep(0.1) @@ -80,7 +82,7 @@ def test_simple(c, s, a, b): assert response -@gen_cluster(client=True, worker_kwargs=dict(services={"bokeh": BokehWorker})) +@gen_cluster(client=True, worker_kwargs=dict(services={"dashboard": BokehWorker})) def test_basic(c, s, a, b): for component in [SystemMonitor, Occupancy, StealingTimeSeries]: ss = component(s) @@ -573,11 +575,13 @@ def test_profile_server(c, s, a, b): assert time() < start + 2 -@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) +@gen_cluster( + client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} +) def test_root_redirect(c, s, a, b): http_client = AsyncHTTPClient() response = yield http_client.fetch( - "http://localhost:%d/" % s.services["bokeh"].port + "http://localhost:%d/" % s.services["dashboard"].port ) assert response.code == 200 assert "/status" in response.effective_url @@ -585,8 +589,8 @@ def test_root_redirect(c, s, a, b): @gen_cluster( client=True, - scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, - worker_kwargs={"services": {"bokeh": BokehWorker}}, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, + worker_kwargs={"services": {"dashboard": BokehWorker}}, timeout=180, ) def test_proxy_to_workers(c, s, a, b): @@ -597,7 +601,7 @@ def test_proxy_to_workers(c, s, a, b): except ImportError: proxy_exists = False - dashboard_port = s.services["bokeh"].port + dashboard_port = s.services["dashboard"].port http_client = AsyncHTTPClient() response = yield http_client.fetch("http://localhost:%d/" % dashboard_port) assert response.code == 200 @@ -605,7 +609,7 @@ def test_proxy_to_workers(c, s, a, b): for w in [a, b]: host = w.ip - port = w.service_ports["bokeh"] + port = w.service_ports["dashboard"] proxy_url = "http://localhost:%d/proxy/%s/%s/status" % ( dashboard_port, port, diff --git a/distributed/bokeh/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py similarity index 79% rename from distributed/bokeh/tests/test_scheduler_bokeh_html.py rename to distributed/dashboard/tests/test_scheduler_bokeh_html.py index 691121f7514..f872d02dc84 100644 --- a/distributed/bokeh/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -13,14 +13,13 @@ from dask.sizeof import sizeof from distributed.utils_test import gen_cluster, slowinc, inc -from distributed.bokeh.scheduler import BokehScheduler -from distributed.bokeh.worker import BokehWorker +from distributed.dashboard import BokehScheduler, BokehWorker @gen_cluster( client=True, - scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, - worker_kwargs={"services": {"bokeh": BokehWorker}}, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, + worker_kwargs={"services": {"dashboard": BokehWorker}}, ) def test_connect(c, s, a, b): future = c.submit(lambda x: x + 1, 1) @@ -41,7 +40,7 @@ def test_connect(c, s, a, b): "individual-plots.json", ]: response = yield http_client.fetch( - "http://localhost:%d/%s" % (s.services["bokeh"].port, suffix) + "http://localhost:%d/%s" % (s.services["dashboard"].port, suffix) ) assert response.code == 200 body = response.body.decode() @@ -54,13 +53,15 @@ def test_connect(c, s, a, b): @gen_cluster( client=True, - scheduler_kwargs={"services": 
{("bokeh", 0): (BokehScheduler, {"prefix": "/foo"})}}, + scheduler_kwargs={ + "services": {("dashboard", 0): (BokehScheduler, {"prefix": "/foo"})} + }, ) def test_prefix(c, s, a, b): http_client = AsyncHTTPClient() for suffix in ["foo/info/main/workers.html", "foo/json/index.html", "foo/system"]: response = yield http_client.fetch( - "http://localhost:%d/%s" % (s.services["bokeh"].port, suffix) + "http://localhost:%d/%s" % (s.services["dashboard"].port, suffix) ) assert response.code == 200 body = response.body.decode() @@ -73,7 +74,7 @@ def test_prefix(c, s, a, b): @gen_cluster( client=True, check_new_threads=False, - scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, ) def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") @@ -85,7 +86,7 @@ def test_prometheus(c, s, a, b): # prometheus_client errors for _ in range(2): response = yield http_client.fetch( - "http://localhost:%d/metrics" % s.services["bokeh"].port + "http://localhost:%d/metrics" % s.services["dashboard"].port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain; version=0.0.4" @@ -98,13 +99,13 @@ def test_prometheus(c, s, a, b): @gen_cluster( client=True, check_new_threads=False, - scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, ) def test_health(c, s, a, b): http_client = AsyncHTTPClient() response = yield http_client.fetch( - "http://localhost:%d/health" % s.services["bokeh"].port + "http://localhost:%d/health" % s.services["dashboard"].port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain" @@ -113,7 +114,9 @@ def test_health(c, s, a, b): assert txt == "ok" -@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}}) +@gen_cluster( + client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} +) def test_task_page(c, s, a, b): future = c.submit(lambda x: x + 1, 1, workers=a.address) x = c.submit(inc, 1) @@ -122,7 +125,7 @@ def test_task_page(c, s, a, b): "info/task/" + url_escape(future.key) + ".html", response = yield http_client.fetch( - "http://localhost:%d/info/task/" % s.services["bokeh"].port + "http://localhost:%d/info/task/" % s.services["dashboard"].port + url_escape(future.key) + ".html" ) diff --git a/distributed/bokeh/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py similarity index 79% rename from distributed/bokeh/tests/test_worker_bokeh.py rename to distributed/dashboard/tests/test_worker_bokeh.py index 03a7ed3861b..11699d9ac83 100644 --- a/distributed/bokeh/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -14,7 +14,7 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec -from distributed.bokeh.worker import ( +from distributed.dashboard.worker import ( BokehWorker, StateTable, CrossFilter, @@ -29,10 +29,10 @@ @pytest.mark.skipif( sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" ) -@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): BokehWorker}}) +@gen_cluster(client=True, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}) def test_simple(c, s, a, b): - assert s.workers[a.address].services == {"bokeh": a.services["bokeh"].port} - assert s.workers[b.address].services == {"bokeh": b.services["bokeh"].port} + assert 
s.workers[a.address].services == {"dashboard": a.services["dashboard"].port} + assert s.workers[b.address].services == {"dashboard": b.services["dashboard"].port} future = c.submit(sleep, 1) yield gen.sleep(0.1) @@ -40,15 +40,17 @@ def test_simple(c, s, a, b): http_client = AsyncHTTPClient() for suffix in ["main", "crossfilter", "system"]: response = yield http_client.fetch( - "http://localhost:%d/%s" % (a.services["bokeh"].port, suffix) + "http://localhost:%d/%s" % (a.services["dashboard"].port, suffix) ) assert "bokeh" in response.body.decode().lower() -@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): (BokehWorker, {})}}) +@gen_cluster( + client=True, worker_kwargs={"services": {("dashboard", 0): (BokehWorker, {})}} +) def test_services_kwargs(c, s, a, b): - assert s.workers[a.address].services == {"bokeh": a.services["bokeh"].port} - assert isinstance(a.services["bokeh"], BokehWorker) + assert s.workers[a.address].services == {"dashboard": a.services["dashboard"].port} + assert isinstance(a.services["dashboard"], BokehWorker) @gen_cluster(client=True) @@ -139,15 +141,15 @@ def test_CommunicatingStream(c, s, a, b): @gen_cluster( client=True, check_new_threads=False, - worker_kwargs={"services": {("bokeh", 0): BokehWorker}}, + worker_kwargs={"services": {("dashboard", 0): BokehWorker}}, ) def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") - assert s.workers[a.address].services == {"bokeh": a.services["bokeh"].port} + assert s.workers[a.address].services == {"dashboard": a.services["dashboard"].port} http_client = AsyncHTTPClient() for suffix in ["metrics"]: response = yield http_client.fetch( - "http://localhost:%d/%s" % (a.services["bokeh"].port, suffix) + "http://localhost:%d/%s" % (a.services["dashboard"].port, suffix) ) assert response.code == 200 diff --git a/distributed/bokeh/tests/test_worker_bokeh_html.py b/distributed/dashboard/tests/test_worker_bokeh_html.py similarity index 75% rename from distributed/bokeh/tests/test_worker_bokeh_html.py rename to distributed/dashboard/tests/test_worker_bokeh_html.py index d59fec8d2d8..99916b3fdc7 100644 --- a/distributed/bokeh/tests/test_worker_bokeh_html.py +++ b/distributed/dashboard/tests/test_worker_bokeh_html.py @@ -4,10 +4,10 @@ from tornado.httpclient import AsyncHTTPClient from distributed.utils_test import gen_cluster -from distributed.bokeh.worker import BokehWorker +from distributed.dashboard import BokehWorker -@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): BokehWorker}}) +@gen_cluster(client=True, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}) def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families @@ -18,7 +18,7 @@ def test_prometheus(c, s, a, b): # prometheus_client errors for _ in range(2): response = yield http_client.fetch( - "http://localhost:%d/metrics" % a.services["bokeh"].port + "http://localhost:%d/metrics" % a.services["dashboard"].port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain; version=0.0.4" @@ -28,12 +28,12 @@ def test_prometheus(c, s, a, b): assert len(families) > 0 -@gen_cluster(client=True, worker_kwargs={"services": {("bokeh", 0): BokehWorker}}) +@gen_cluster(client=True, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}) def test_health(c, s, a, b): http_client = AsyncHTTPClient() response = yield http_client.fetch( - "http://localhost:%d/health" % a.services["bokeh"].port + 
"http://localhost:%d/health" % a.services["dashboard"].port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain" diff --git a/distributed/bokeh/theme.yaml b/distributed/dashboard/theme.yaml similarity index 100% rename from distributed/bokeh/theme.yaml rename to distributed/dashboard/theme.yaml diff --git a/distributed/bokeh/utils.py b/distributed/dashboard/utils.py similarity index 100% rename from distributed/bokeh/utils.py rename to distributed/dashboard/utils.py diff --git a/distributed/bokeh/worker.py b/distributed/dashboard/worker.py similarity index 100% rename from distributed/bokeh/worker.py rename to distributed/dashboard/worker.py diff --git a/distributed/bokeh/worker_html.py b/distributed/dashboard/worker_html.py similarity index 100% rename from distributed/bokeh/worker_html.py rename to distributed/dashboard/worker_html.py diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 8425b836a4d..69cc5be9fac 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -88,7 +88,7 @@ def scheduler_address(self): def dashboard_link(self): template = dask.config.get("distributed.dashboard.link") host = self.scheduler.address.split("://")[1].split(":")[0] - port = self.scheduler.services["bokeh"].port + port = self.scheduler.services["dashboard"].port return template.format(host=host, port=port, **os.environ) def scale(self, n): @@ -165,7 +165,7 @@ def _widget(self): layout = Layout(width="150px") - if "bokeh" in self.scheduler.services: + if "dashboard" in self.scheduler.services: link = self.dashboard_link link = '
<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>
          \n' % ( link, diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index ad0aea25f6c..d5a954effc8 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -112,11 +112,11 @@ def __init__( self._created = weakref.WeakSet() if scheduler is None: try: - from distributed.bokeh.scheduler import BokehScheduler + from distributed.dashboard import BokehScheduler except ImportError: services = {} else: - services = {("bokeh", 8787): BokehScheduler} + services = {("dashboard", 8787): BokehScheduler} scheduler = {"cls": Scheduler, "options": {"services": services}} self.scheduler_spec = scheduler diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 4498611d7e8..6f9a4a03244 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -198,8 +198,8 @@ def test_duplicate_clients(): with pytest.warns(Exception) as info: c2 = yield Client(processes=False, silence_logs=False, dashboard_address=9876) - assert "bokeh" in c1.cluster.scheduler.services - assert "bokeh" in c2.cluster.scheduler.services + assert "dashboard" in c1.cluster.scheduler.services + assert "dashboard" in c2.cluster.scheduler.services assert any( all( @@ -341,7 +341,7 @@ def test_bokeh(loop, processes): processes=processes, dashboard_address=0, ) as c: - bokeh_port = c.scheduler.services["bokeh"].port + bokeh_port = c.scheduler.services["dashboard"].port url = "http://127.0.0.1:%d/status/" % bokeh_port start = time() while True: @@ -485,10 +485,10 @@ def test_bokeh_kwargs(loop): silence_logs=False, loop=loop, dashboard_address=0, - service_kwargs={"bokeh": {"prefix": "/foo"}}, + service_kwargs={"dashboard": {"prefix": "/foo"}}, ) as c: - bs = c.scheduler.services["bokeh"] + bs = c.scheduler.services["dashboard"] assert bs.prefix == "/foo" diff --git a/distributed/diagnostics/tests/test_eventstream.py b/distributed/diagnostics/tests/test_eventstream.py index 0995d80db26..7ec646d7e91 100644 --- a/distributed/diagnostics/tests/test_eventstream.py +++ b/distributed/diagnostics/tests/test_eventstream.py @@ -1,6 +1,6 @@ from __future__ import print_function, division, absolute_import -from copy import deepcopy +import collections import pytest from tornado import gen @@ -26,10 +26,12 @@ def test_eventstream(c, s, *workers): assert len(es.buffer) == 11 - from distributed.bokeh import messages from distributed.diagnostics.progress_stream import task_stream_append - lists = deepcopy(messages["task-events"]["rectangles"]) + lists = { + name: collections.deque(maxlen=100) + for name in "start duration key name color worker worker_thread y alpha".split() + } workers = dict() for msg in es.buffer: task_stream_append(lists, msg, workers) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 991ff1a2108..9db6477aeb7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -878,13 +878,13 @@ def __init__( if dashboard_address is not None: try: - from distributed.bokeh.scheduler import BokehScheduler + from distributed.dashboard import BokehScheduler except ImportError: logger.debug("To start diagnostics web server please install Bokeh") else: - self.service_specs[("bokeh", dashboard_address)] = ( + self.service_specs[("dashboard", dashboard_address)] = ( BokehScheduler, - (service_kwargs or {}).get("bokeh", {}), + (service_kwargs or {}).get("dashboard", {}), ) # Communication state diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 28c2f939eb7..c731ae6e5ad 
100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3506,7 +3506,7 @@ def test_reconnect(loop): "127.0.0.1", "--port", "9393", - "--no-bokeh", + "--no-dashboard", ] with popen(scheduler_cli) as s: c = Client("127.0.0.1:9393", loop=loop) @@ -5221,12 +5221,11 @@ def test_quiet_scheduler_loss(c, s): @pytest.mark.skipif("USER" not in os.environ, reason="no USER env variable") def test_diagnostics_link_env_variable(loop): pytest.importorskip("bokeh") - from distributed.bokeh.scheduler import BokehScheduler + from distributed.dashboard import BokehScheduler - with cluster(scheduler_kwargs={"services": {("bokeh", 12355): BokehScheduler}}) as ( - s, - [a, b], - ): + with cluster( + scheduler_kwargs={"services": {("dashboard", 12355): BokehScheduler}} + ) as (s, [a, b]): with Client(s["address"], loop=loop) as c: with dask.config.set( {"distributed.dashboard.link": "http://foo-{USER}:{port}/status"} diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 4b3c0ac0ade..f53340d1004 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -324,7 +324,10 @@ def check_rpc_message_lifetime(*listen_args): obj = CountedObject() assert CountedObject.n_instances == 1 del obj - assert CountedObject.n_instances == 0 + start = time() + while CountedObject.n_instances != 0: + yield gen.sleep(0.01) + assert time() < start + 1 with rpc(server.address) as remote: obj = CountedObject() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f5ce276b8d2..6df271ae34e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1174,7 +1174,7 @@ def test_correct_bad_time_estimate(c, s, *workers): @gen_test() def test_service_hosts(): pytest.importorskip("bokeh") - from distributed.bokeh.scheduler import BokehScheduler + from distributed.dashboard import BokehScheduler port = 0 for url, expected in [ @@ -1182,12 +1182,12 @@ def test_service_hosts(): ("tcp://127.0.0.1", "127.0.0.1"), ("tcp://127.0.0.1:38275", "127.0.0.1"), ]: - services = {("bokeh", port): BokehScheduler} + services = {("dashboard", port): BokehScheduler} s = Scheduler(services=services) yield s.start(url) - sock = first(s.services["bokeh"].server._http._sockets.values()) + sock = first(s.services["dashboard"].server._http._sockets.values()) if isinstance(expected, tuple): assert sock.getsockname()[0] in expected else: @@ -1196,12 +1196,12 @@ def test_service_hosts(): port = ("127.0.0.1", 0) for url in ["tcp://0.0.0.0", "tcp://127.0.0.1", "tcp://127.0.0.1:38275"]: - services = {("bokeh", port): BokehScheduler} + services = {("dashboard", port): BokehScheduler} s = Scheduler(services=services) yield s.start(url) - sock = first(s.services["bokeh"].server._http._sockets.values()) + sock = first(s.services["dashboard"].server._http._sockets.values()) assert sock.getsockname()[0] == "127.0.0.1" yield s.close() @@ -1566,9 +1566,9 @@ def test_host_address(): def test_dashboard_address(): pytest.importorskip("bokeh") s = yield Scheduler(dashboard_address="127.0.0.1:8901", port=0) - assert s.services["bokeh"].port == 8901 + assert s.services["dashboard"].port == 8901 yield s.close() s = yield Scheduler(dashboard_address="127.0.0.1", port=0) - assert s.services["bokeh"].port + assert s.services["dashboard"].port yield s.close() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index bf4e483f441..12a6b5ff68f 100644 --- a/distributed/tests/test_worker.py +++ 
b/distributed/tests/test_worker.py @@ -972,25 +972,25 @@ def test_worker_fds(s): @gen_cluster(ncores=[]) def test_service_hosts_match_worker(s): pytest.importorskip("bokeh") - from distributed.bokeh.worker import BokehWorker + from distributed.dashboard import BokehWorker - services = {("bokeh", ":0"): BokehWorker} + services = {("dashboard", ":0"): BokehWorker} - w = Worker(s.address, services={("bokeh", ":0"): BokehWorker}) + w = Worker(s.address, services={("dashboard", ":0"): BokehWorker}) yield w._start("tcp://0.0.0.0") - sock = first(w.services["bokeh"].server._http._sockets.values()) + sock = first(w.services["dashboard"].server._http._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") yield w.close() - w = Worker(s.address, services={("bokeh", ":0"): BokehWorker}) + w = Worker(s.address, services={("dashboard", ":0"): BokehWorker}) yield w._start("tcp://127.0.0.1") - sock = first(w.services["bokeh"].server._http._sockets.values()) + sock = first(w.services["dashboard"].server._http._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") yield w.close() - w = Worker(s.address, services={("bokeh", 0): BokehWorker}) + w = Worker(s.address, services={("dashboard", 0): BokehWorker}) yield w._start("tcp://127.0.0.1") - sock = first(w.services["bokeh"].server._http._sockets.values()) + sock = first(w.services["dashboard"].server._http._sockets.values()) assert sock.getsockname()[0] == "127.0.0.1" yield w.close() @@ -998,14 +998,14 @@ def test_service_hosts_match_worker(s): @gen_cluster(ncores=[]) def test_start_services(s): pytest.importorskip("bokeh") - from distributed.bokeh.worker import BokehWorker + from distributed.dashboard import BokehWorker - services = {("bokeh", ":1234"): BokehWorker} + services = {("dashboard", ":1234"): BokehWorker} w = Worker(s.address, services=services) yield w._start() - assert w.services["bokeh"].server.port == 1234 + assert w.services["dashboard"].server.port == 1234 yield w.close() diff --git a/distributed/worker.py b/distributed/worker.py index 667bd83490f..711dad31651 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -546,13 +546,13 @@ def __init__( if dashboard_address is not None: try: - from distributed.bokeh.worker import BokehWorker + from distributed.dashboard import BokehWorker except ImportError: logger.debug("To start diagnostics web server please install Bokeh") else: - self.service_specs[("bokeh", dashboard_address)] = ( + self.service_specs[("dashboard", dashboard_address)] = ( BokehWorker, - (service_kwargs or {}).get("bokeh", {}), + (service_kwargs or {}).get("dashboard", {}), ) self.metrics = dict(metrics) if metrics else {} diff --git a/setup.py b/setup.py index 3ef26a047dc..0df22f3f911 100755 --- a/setup.py +++ b/setup.py @@ -2,46 +2,51 @@ import os from setuptools import setup -import sys import versioneer -requires = open('requirements.txt').read().strip().split('\n') +requires = open("requirements.txt").read().strip().split("\n") install_requires = [] extras_require = {} for r in requires: - if ';' in r: + if ";" in r: # requirements.txt conditional dependencies need to be reformatted for wheels # to the form: `'[extra_name]:condition' : ['requirements']` - req, cond = r.split(';', 1) - cond = ':' + cond + req, cond = r.split(";", 1) + cond = ":" + cond cond_reqs = extras_require.setdefault(cond, []) cond_reqs.append(req) else: install_requires.append(r) -setup(name='distributed', - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - description='Distributed 
scheduler for Dask', - url='https://distributed.readthedocs.io/en/latest/', - maintainer='Matthew Rocklin', - maintainer_email='mrocklin@gmail.com', - license='BSD', - package_data={'': ['templates/index.html', 'template.html'], - 'distributed': ['bokeh/templates/*.html']}, - include_package_data=True, - install_requires=install_requires, - extras_require=extras_require, - packages=['distributed', - 'distributed.bokeh', - 'distributed.cli', - 'distributed.comm', - 'distributed.deploy', - 'distributed.diagnostics', - 'distributed.protocol'], - long_description=(open('README.rst').read() if os.path.exists('README.rst') - else ''), - classifiers=[ +setup( + name="distributed", + version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), + description="Distributed scheduler for Dask", + url="https://distributed.readthedocs.io/en/latest/", + maintainer="Matthew Rocklin", + maintainer_email="mrocklin@gmail.com", + license="BSD", + package_data={ + "": ["templates/index.html", "template.html"], + "distributed": ["dashboard/templates/*.html"], + }, + include_package_data=True, + install_requires=install_requires, + extras_require=extras_require, + packages=[ + "distributed", + "distributed.dashboard", + "distributed.cli", + "distributed.comm", + "distributed.deploy", + "distributed.diagnostics", + "distributed.protocol", + ], + long_description=( + open("README.rst").read() if os.path.exists("README.rst") else "" + ), + classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Intended Audience :: Science/Research", @@ -54,8 +59,8 @@ "Programming Language :: Python :: 3.7", "Topic :: Scientific/Engineering", "Topic :: System :: Distributed Computing", - ], - entry_points=''' + ], + entry_points=""" [console_scripts] dask-ssh=distributed.cli.dask_ssh:go dask-submit=distributed.cli.dask_submit:go @@ -63,5 +68,6 @@ dask-scheduler=distributed.cli.dask_scheduler:go dask-worker=distributed.cli.dask_worker:go dask-mpi=distributed.cli.dask_mpi:go - ''', - zip_safe=False) + """, + zip_safe=False, +) From 4e3ba76be99ae5d572364e3b8a05a5a7ec42cce5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 29 May 2019 20:50:27 +0200 Subject: [PATCH 0308/1550] Add back LocalCluster.__repr__. (#2732) LocalCluster.__repr__ was removed in #2675. 
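As a rough sketch of what the restored repr prints, assuming a hypothetical in-process
cluster with two single-threaded workers (the scheduler address and counts below are
illustrative, not fixed values):

    from distributed import LocalCluster

    cluster = LocalCluster(n_workers=2, threads_per_worker=1, processes=False)
    print(repr(cluster))
    # e.g. LocalCluster('inproc://192.168.1.10/1234/1', workers=2, ncores=2)
    cluster.close()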
--- distributed/deploy/local.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 17150fdf70f..298c47d7a31 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -197,6 +197,13 @@ def __init__( ) self.scale(n_workers) + def __repr__(self): + return "LocalCluster(%r, workers=%d, ncores=%d)" % ( + self.scheduler_address, + len(self.workers), + sum(w.ncores for w in self.workers.values()), + ) + def nprocesses_nthreads(n): """ From 23b1d93ca7028e0b3dad0b55d6d133559b1d7f35 Mon Sep 17 00:00:00 2001 From: Manuel Garrido Date: Wed, 29 May 2019 22:10:54 +0100 Subject: [PATCH 0309/1550] add kwargs to progressbars (#2638) * add kwargs to progressbars * remove assertion * linting and add kwarg test for progress bar --- distributed/diagnostics/progressbar.py | 25 ++++++++++++++----- .../diagnostics/tests/test_progressbar.py | 8 ++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 08ba8f7da63..8a381562f27 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -118,6 +118,7 @@ def __init__( loop=None, complete=True, start=True, + **kwargs ): super(TextProgressBar, self).__init__(keys, scheduler, interval, complete) self.width = width @@ -154,7 +155,13 @@ class ProgressWidget(ProgressBar): """ def __init__( - self, keys, scheduler=None, interval="100ms", complete=False, loop=None + self, + keys, + scheduler=None, + interval="100ms", + complete=False, + loop=None, + **kwargs ): super(ProgressWidget, self).__init__(keys, scheduler, interval, complete) @@ -207,7 +214,13 @@ def _draw_bar(self, remaining, all, **kwargs): class MultiProgressBar(object): def __init__( - self, keys, scheduler=None, func=key_split, interval="100ms", complete=False + self, + keys, + scheduler=None, + func=key_split, + interval="100ms", + complete=False, + **kwargs ): self.scheduler = get_scheduler(scheduler) @@ -306,6 +319,7 @@ def __init__( interval=0.1, func=key_split, complete=False, + **kwargs ): super(MultiProgressWidget, self).__init__( keys, scheduler, func, interval, complete @@ -425,7 +439,6 @@ def progress(*futures, **kwargs): notebook = kwargs.pop("notebook", None) multi = kwargs.pop("multi", True) complete = kwargs.pop("complete", True) - assert not kwargs futures = futures_of(futures) if not isinstance(futures, (set, list)): @@ -434,9 +447,9 @@ def progress(*futures, **kwargs): notebook = is_kernel() # often but not always correct assumption if notebook: if multi: - bar = MultiProgressWidget(futures, complete=complete) + bar = MultiProgressWidget(futures, complete=complete, **kwargs) else: - bar = ProgressWidget(futures, complete=complete) + bar = ProgressWidget(futures, complete=complete, **kwargs) return bar else: - TextProgressBar(futures, complete=complete) + TextProgressBar(futures, complete=complete, **kwargs) diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index ac21f1637bc..ba42f2ce6ea 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -76,3 +76,11 @@ def test_progress_function(client, capsys): progress(f) check_bar_completed(capsys) + + +def test_progress_function_w_kwargs(client, capsys): + f = client.submit(lambda: 1) + g = client.submit(lambda: 2) + + progress(f, interval="20ms") + check_bar_completed(capsys) From 
d9626a59fa0ee5953293666591a083d2c249ddc1 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 30 May 2019 15:47:16 -0700 Subject: [PATCH 0310/1550] Close nannies gracefully (#2731) Previously a worker process could be stopped before it told its nanny that it was going away. Now we intentionally tell the nanny ahead of time from the scheduler (and the worker for good measure) before we start the shutdown procedure. --- distributed/deploy/tests/test_local.py | 1 + distributed/nanny.py | 11 ++++++++++- distributed/scheduler.py | 1 + distributed/utils_test.py | 6 ++++++ distributed/worker.py | 7 ++++++- 5 files changed, 24 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 6f9a4a03244..8aad6675f8c 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -474,6 +474,7 @@ def test_death_timeout_raises(loop): loop=loop, ) as cluster: pass + LocalCluster._instances.clear() # ignore test hygiene checks @pytest.mark.skipif(sys.version_info < (3, 6), reason="Unknown") diff --git a/distributed/nanny.py b/distributed/nanny.py index a27f713ea6b..8d2a38192d1 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -131,6 +131,7 @@ def __init__( "restart": self.restart, # cannot call it 'close' on the rpc side for naming conflict "terminate": self.close, + "close_gracefully": self.close_gracefully, "run": self.run, } @@ -355,7 +356,7 @@ def _on_exit(self, exitcode): return try: - if self.status not in ("closing", "closed"): + if self.status not in ("closing", "closed", "closing-gracefully"): if self.auto_restart: logger.warning("Restarting worker") yield self.instantiate() @@ -372,6 +373,14 @@ def _close(self, *args, **kwargs): warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) return self.close(*args, **kwargs) + def close_gracefully(self, comm=None): + """ + A signal that we shouldn't try to restart workers if they go away + + This is used as part of the cluster shutdown process. 
+ """ + self.status = "closing-gracefully" + @gen.coroutine def close(self, comm=None, timeout=5, report=None): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9db6477aeb7..ca3c1241ea7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1265,6 +1265,7 @@ def close(self, comm=None, fast=False, close_workers=False): setproctitle("dask-scheduler [closing]") if close_workers: + self.broadcast(msg={"op": "close_gracefully"}, nanny=True) for worker in self.workers: self.worker_send(worker, {"op": "close"}) for i in range(20): # wait a second for send signals to clear diff --git a/distributed/utils_test.py b/distributed/utils_test.py index d61046f2a48..10784c6f759 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -44,6 +44,7 @@ from .comm.utils import offload from .config import initialize_logging from .core import connect, rpc, CommClosedError +from .deploy import SpecCluster from .metrics import time from .process import _cleanup_dangling from .proctitle import enable_proctitle_on_children @@ -1477,6 +1478,7 @@ def check_instances(): Client._instances.clear() Worker._instances.clear() Scheduler._instances.clear() + SpecCluster._instances.clear() # assert all(n.status == "closed" for n in Nanny._instances), { # n: n.status for n in Nanny._instances # } @@ -1514,6 +1516,10 @@ def check_instances(): n: n.status for n in Nanny._instances } + # assert not list(SpecCluster._instances) # TODO + assert all(c.status == "closed" for c in SpecCluster._instances) + SpecCluster._instances.clear() + Nanny._instances.clear() DequeHandler.clear_all_instances() diff --git a/distributed/worker.py b/distributed/worker.py index 711dad31651..d0bc735ec67 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1007,6 +1007,11 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): except ValueError: # address not available if already closed logger.info("Stopping worker") self.status = "closing" + + if nanny and self.nanny: + with self.rpc(self.nanny) as r: + yield r.close_gracefully() + setproctitle("dask-worker [closing]") yield [ @@ -1015,7 +1020,6 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): if hasattr(plugin, "teardown") ] - self.stop() for pc in self.periodic_callbacks.values(): pc.stop() with ignoring(EnvironmentError, gen.TimeoutError): @@ -1047,6 +1051,7 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): with self.rpc(self.nanny) as r: yield r.terminate() + self.stop() self.rpc.close() self._closed.set() From a8504d6d4a007ea5d427c2d17434b3dd22350e0a Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Fri, 31 May 2019 09:51:47 -0400 Subject: [PATCH 0311/1550] Add Experimental UCX Comm (#2591) --- distributed/cli/dask_scheduler.py | 5 + distributed/cli/dask_worker.py | 5 + distributed/comm/__init__.py | 5 + distributed/comm/addressing.py | 4 +- distributed/comm/tests/__init__.py | 0 distributed/comm/tests/test_comms.py | 14 +- distributed/comm/tests/test_ucx.py | 296 +++++++++++++++++++++++ distributed/comm/ucx.py | 308 ++++++++++++++++++++++++ distributed/core.py | 10 +- distributed/deploy/local.py | 8 +- distributed/preloading.py | 2 + distributed/protocol/__init__.py | 19 ++ distributed/protocol/core.py | 1 + distributed/protocol/cuda.py | 33 +++ distributed/protocol/cudf.py | 74 ++++++ distributed/protocol/cupy.py | 42 ++++ distributed/protocol/numba.py | 61 +++++ distributed/protocol/tests/test_cupy.py | 12 + distributed/protocol/utils.py | 5 +- 
19 files changed, 893 insertions(+), 11 deletions(-) create mode 100644 distributed/comm/tests/__init__.py create mode 100644 distributed/comm/tests/test_ucx.py create mode 100644 distributed/comm/ucx.py create mode 100644 distributed/protocol/cuda.py create mode 100644 distributed/protocol/cudf.py create mode 100644 distributed/protocol/cupy.py create mode 100644 distributed/protocol/numba.py create mode 100644 distributed/protocol/tests/test_cupy.py diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 1f78426f635..b27e68eaa9a 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -39,6 +39,9 @@ default=None, help="Preferred network interface like 'eth0' or 'ib0'", ) +@click.option( + "--protocol", type=str, default=None, help="Protocol like tcp, tls, or ucx" +) @click.option( "--tls-ca-file", type=pem_file_option_type, @@ -121,6 +124,7 @@ def main( pid_file, scheduler_file, interface, + protocol, local_directory, preload, preload_argv, @@ -190,6 +194,7 @@ def del_pid_file(): host=host, port=port, interface=interface, + protocol=protocol, dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, ) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index c4f83f61405..2cf570cfc1d 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -104,6 +104,9 @@ @click.option( "--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'" ) +@click.option( + "--protocol", type=str, default=None, help="Protocol like tcp, tls, or ucx" +) @click.option("--nthreads", type=int, default=0, help="Number of threads per process.") @click.option( "--nprocs", @@ -197,6 +200,7 @@ def main( local_directory, scheduler_file, interface, + protocol, death_timeout, preload, preload_argv, @@ -338,6 +342,7 @@ def del_pid_file(): security=sec, contact_address=contact_address, interface=interface, + protocol=protocol, host=host, port=port, dashboard_address=dashboard_address if dashboard else None, diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index 0f7c701847d..e0615b38c7a 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -18,5 +18,10 @@ def _register_transports(): from . import inproc from . import tcp + try: + from . import ucx + except ImportError: + pass + _register_transports() diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 3d79befe0f1..d707adb84ac 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -72,6 +72,8 @@ def _default(): raise ValueError("missing port number in address %r" % (address,)) return default_port + if "://" in address: + _, address = address.split("://") if address.startswith("["): # IPv6 notation: '[addr]:port' or '[addr]'. # The address may contain multiple colons. 
@@ -101,7 +103,7 @@ def unparse_host_port(host, port=None): """ if ":" in host and not host.startswith("["): host = "[%s]" % host - if port: + if port is not None: return "%s:%s" % (host, port) else: return host diff --git a/distributed/comm/tests/__init__.py b/distributed/comm/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 0e8782718a0..e761deeab86 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -26,6 +26,7 @@ from distributed.protocol import to_serialize, Serialized, serialize, deserialize +from distributed.comm.registry import backends from distributed.comm import ( tcp, inproc, @@ -40,7 +41,6 @@ get_local_address_for, ) - EXTERNAL_IP4 = get_ip() if has_ipv6(): with warnings.catch_warnings(record=True): @@ -154,7 +154,6 @@ def test_unparse_host_port(): assert f("[::1]", 123) == "[::1]:123" assert f("127.0.0.1") == "127.0.0.1" - assert f("127.0.0.1", 0) == "127.0.0.1" assert f("127.0.0.1", None) == "127.0.0.1" assert f("127.0.0.1", "*") == "127.0.0.1:*" @@ -488,7 +487,7 @@ def handle_comm(comm): # Check listener properties bound_addr = listener.listen_address bound_scheme, bound_loc = parse_address(bound_addr) - assert bound_scheme in ("inproc", "tcp", "tls") + assert bound_scheme in backends assert bound_scheme == parse_address(addr)[0] if check_listen_addr is not None: @@ -530,6 +529,15 @@ def client_communicate(key, delay=0): listener.stop() +@gen_test() +def test_ucx_client_server(): + pytest.importorskip("distributed.comm.ucx") + import ucp + + addr = ucp.get_address() + yield check_client_server("ucx://" + addr) + + def tcp_eq(expected_host, expected_port=None): def checker(loc): host, port = parse_host_port(loc) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py new file mode 100644 index 00000000000..55a2f4ec82c --- /dev/null +++ b/distributed/comm/tests/test_ucx.py @@ -0,0 +1,296 @@ +import asyncio + +import pytest + +ucp = pytest.importorskip("ucp") + +from distributed import Client +from distributed.comm import ucx, listen, connect +from distributed.comm.registry import backends, get_backend +from distributed.comm import ucx, parse_address +from distributed.protocol import to_serialize +from distributed.deploy.local import LocalCluster +from distributed.utils_test import gen_test, loop, inc # noqa: 401 + +from .test_comms import check_deserialize + + +HOST = ucp.get_address() + + +def test_registered(): + assert "ucx" in backends + backend = get_backend("ucx") + assert isinstance(backend, ucx.UCXBackend) + + +async def get_comm_pair( + listen_addr="ucx://" + HOST, listen_args=None, connect_args=None, **kwargs +): + q = asyncio.queues.Queue() + + async def handle_comm(comm): + await q.put(comm) + + # Workaround for hanging test in + # pytest distributed/comm/tests/test_ucx.py::test_comm_objs -vs --count=2 + # on the second time through. 
+ ucp._libs.ucp_py.reader_added = 0 + + listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) + with listener: + comm = await connect( + listener.contact_address, connection_args=connect_args, **kwargs + ) + serv_com = await q.get() + return comm, serv_com + + +@pytest.mark.asyncio +async def test_ping_pong(): + com, serv_com = await get_comm_pair() + msg = {"op": "ping"} + await com.write(msg) + result = await serv_com.read() + assert result == msg + result["op"] = "pong" + + await serv_com.write(result) + + result = await com.read() + assert result == {"op": "pong"} + + await com.close() + await serv_com.close() + + +@pytest.mark.asyncio +async def test_comm_objs(): + comm, serv_comm = await get_comm_pair() + + scheme, loc = parse_address(comm.peer_address) + assert scheme == "ucx" + + scheme, loc = parse_address(serv_comm.peer_address) + assert scheme == "ucx" + + assert comm.peer_address == serv_comm.local_address + + +def test_ucx_specific(): + """ + Test concrete UCX API. + """ + # TODO: + # 1. ensure exceptions in handle_comm fail the test + # 2. Use dict in read / write, put seralization there. + # 3. Test peer_address + # 4. Test cleanup + async def f(): + address = "ucx://{}:{}".format(HOST, 0) + + async def handle_comm(comm): + msg = await comm.read() + msg["op"] = "pong" + await comm.write(msg) + assert comm.closed() is False + await comm.close() + assert comm.closed + + listener = ucx.UCXListener(address, handle_comm) + listener.start() + host, port = listener.get_host_port() + assert host.count(".") == 3 + assert port > 0 + + connector = ucx.UCXConnector() + l = [] + + async def client_communicate(key, delay=0): + addr = "%s:%d" % (host, port) + comm = await connector.connect(addr) + # TODO: peer_address + # assert comm.peer_address == 'ucx://' + addr + assert comm.extra_info == {} + msg = {"op": "ping", "data": key} + await comm.write(msg) + if delay: + await asyncio.sleep(delay) + msg = await comm.read() + assert msg == {"op": "pong", "data": key} + l.append(key) + return comm + assert comm.closed() is False + await comm.close() + assert comm.closed + + comm = await client_communicate(key=1234, delay=0.5) + + # Many clients at once + N = 2 + futures = [client_communicate(key=i, delay=0.05) for i in range(N)] + await asyncio.gather(*futures) + assert set(l) == {1234} | set(range(N)) + + asyncio.run(f()) + + +@pytest.mark.asyncio +async def test_ping_pong_data(): + np = pytest.importorskip("numpy") + + data = np.ones((10, 10)) + + com, serv_com = await get_comm_pair() + msg = {"op": "ping", "data": to_serialize(data)} + await com.write(msg) + result = await serv_com.read() + result["op"] = "pong" + data2 = result.pop("data") + np.testing.assert_array_equal(data2, data) + + await serv_com.write(result) + + result = await com.read() + assert result == {"op": "pong"} + + await com.close() + await serv_com.close() + + +@gen_test() +def test_ucx_deserialize(): + yield check_deserialize("tcp://") + + +@pytest.mark.asyncio +async def test_ping_pong_cudf(): + # if this test appears after cupy an import error arises + # *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11' + # not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12) + cudf = pytest.importorskip("cudf") + + df = cudf.DataFrame({"A": [1, 2, None], "B": [1.0, 2.0, None]}) + + com, serv_com = await get_comm_pair() + msg = {"op": "ping", "data": to_serialize(df)} + + await com.write(msg) + result = await serv_com.read() + data2 = 
result.pop("data") + assert result["op"] == "ping" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("shape", [(100,), (10, 10), (4947,)]) +async def test_ping_pong_cupy(shape): + cupy = pytest.importorskip("cupy") + com, serv_com = await get_comm_pair() + + arr = cupy.random.random(shape) + msg = {"op": "ping", "data": to_serialize(arr)} + + _, result = await asyncio.gather(com.write(msg), serv_com.read()) + data2 = result.pop("data") + + assert result["op"] == "ping" + cupy.testing.assert_array_equal(arr, data2) + await com.close() + await serv_com.close() + + +@pytest.mark.slow +@pytest.mark.asyncio +@pytest.mark.parametrize( + "n", + [ + int(1e9), + pytest.param( + int(2.5e9), marks=[pytest.mark.xfail(reason="integer type in ucx-py")] + ), + ], +) +async def test_large_cupy(n): + cupy = pytest.importorskip("cupy") + com, serv_com = await get_comm_pair() + + arr = cupy.ones(n, dtype="u1") + msg = {"op": "ping", "data": to_serialize(arr)} + + _, result = await asyncio.gather(com.write(msg), serv_com.read()) + data2 = result.pop("data") + + assert result["op"] == "ping" + assert len(data2) == len(arr) + await com.close() + await serv_com.close() + + +@pytest.mark.asyncio +async def test_ping_pong_numba(): + np = pytest.importorskip("numpy") + numba = pytest.importorskip("numba") + import numba.cuda + + arr = np.arange(10) + arr = numba.cuda.to_device(arr) + + com, serv_com = await get_comm_pair() + msg = {"op": "ping", "data": to_serialize(arr)} + + await com.write(msg) + result = await serv_com.read() + data2 = result.pop("data") + assert result["op"] == "ping" + + +@pytest.mark.skip(reason="hangs") +@pytest.mark.parametrize("processes", [True, False]) +def test_ucx_localcluster(loop, processes): + if processes: + kwargs = {"env": {"UCX_MEMTYPE_CACHE": "n"}} + else: + kwargs = {} + + ucx_addr = ucp.get_address() + with LocalCluster( + protocol="ucx", + interface="ib0", + dashboard_address=None, + n_workers=2, + threads_per_worker=1, + processes=processes, + loop=loop, + **kwargs, + ) as cluster: + with Client(cluster) as client: + x = client.submit(inc, 1) + x.result() + assert x.key in cluster.scheduler.tasks + if not processes: + assert any(w.data == {x.key: 2} for w in cluster.workers.values()) + assert len(cluster.scheduler.workers) == 2 + + +def test_tcp_localcluster(loop): + ucx_addr = "127.0.0.1" + port = 13337 + env = {"UCX_MEMTYPE_CACHE": "n"} + with LocalCluster( + 2, + scheduler_port=port, + ip=ucx_addr, + processes=True, + threads_per_worker=1, + dashboard_address=None, + silence_logs=False, + env=env, + ) as cluster: + pass + # with Client(cluster) as e: + # x = e.submit(inc, 1) + # x.result() + # assert x.key in c.scheduler.tasks + # assert any(w.data == {x.key: 2} for w in c.workers) + # assert e.loop is c.loop + # print(c.scheduler.workers) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py new file mode 100644 index 00000000000..3f3f0bfe943 --- /dev/null +++ b/distributed/comm/ucx.py @@ -0,0 +1,308 @@ +""" +:ref:`UCX`_ based communications for distributed. + +See :ref:`communications` for more. + +.. 
_UCX: https://github.com/openucx/ucx +""" +import asyncio +import logging +import struct + +from .addressing import parse_host_port, unparse_host_port +from .core import Comm, Connector, Listener, CommClosedError +from .registry import Backend, backends +from .utils import ensure_concrete_host, to_frames, from_frames +from ..utils import ensure_ip, get_ip, get_ipv6, nbytes + +import ucp + +import os + +os.environ.setdefault("UCX_RNDV_SCHEME", "put_zcopy") +os.environ.setdefault("UCX_MEMTYPE_CACHE", "n") +os.environ.setdefault("UCX_TLS", "rc,cuda_copy") + +logger = logging.getLogger(__name__) +MAX_MSG_LOG = 23 + + +# ---------------------------------------------------------------------------- +# Comm Interface +# ---------------------------------------------------------------------------- + + +class UCX(Comm): + """Comm object using UCP. + + Parameters + ---------- + ep : ucp.Endpoint + The UCP endpoint. + address : str + The address, prefixed with `ucx://` to use. + deserialize : bool, default True + Whether to deserialize data in :meth:`distributed.protocol.loads` + + Notes + ----- + The read-write cycle uses the following pattern: + + Each msg is serialized into a number of "data" frames. We prepend these + real frames with two additional frames + + 1. is_gpus: Boolean indicator for whether the frame should be + received into GPU memory. Packed in '?' format. Unpack with + ``?`` format. + 2. frame_size : Unsigned int describing the size of frame (in bytes) + to receive. Packed in 'Q' format, so a length-0 frame is equivalent + to an unsized frame. Unpacked with ``Q``. + + The expected read cycle is + + 1. Read the frame describing number of frames + 2. Read the frame describing whether each data frame is gpu-bound + 3. Read the frame describing whether each data frame is sized + 4. Read all the data frames. 
+ """ + + def __init__( + self, ep: ucp.Endpoint, local_addr: str, peer_addr: str, deserialize=True + ): + Comm.__init__(self) + self._ep = ep + if local_addr: + assert local_addr.startswith("ucx") + assert peer_addr.startswith("ucx") + self._local_addr = local_addr + self._peer_addr = peer_addr + self.deserialize = deserialize + self.comm_flag = None + logger.debug("UCX.__init__ %s", self) + + @property + def local_address(self) -> str: + return self._local_addr + + @property + def peer_address(self) -> str: + return self._peer_addr + + async def write( + self, + msg: dict, + serializers=("cuda", "dask", "pickle", "error"), + on_error: str = "message", + ): + if serializers is None: + serializers = ("cuda", "dask", "pickle", "error") + # msg can also be a list of dicts when sending batched messages + frames = await to_frames(msg, serializers=serializers, on_error=on_error) + is_gpus = b"".join( + [ + struct.pack("?", hasattr(frame, "__cuda_array_interface__")) + for frame in frames + ] + ) + sizes = b"".join([struct.pack("Q", nbytes(frame)) for frame in frames]) + + nframes = struct.pack("Q", len(frames)) + + meta = b"".join([nframes, is_gpus, sizes]) + + await self.ep.send_obj(meta) + + for frame in frames: + await self.ep.send_obj(frame) + return sum(map(nbytes, frames)) + + async def read(self, deserializers=("cuda", "dask", "pickle", "error")): + if deserializers is None: + deserializers = ("cuda", "dask", "pickle", "error") + resp = await self.ep.recv_future() + obj = ucp.get_obj_from_msg(resp) + nframes, = struct.unpack("Q", obj[:8]) # first eight bytes for number of frames + + gpu_frame_msg = obj[ + 8 : 8 + nframes + ] # next nframes bytes for if they're GPU frames + is_gpus = struct.unpack("{}?".format(nframes), gpu_frame_msg) + + sized_frame_msg = obj[8 + nframes :] # then the rest for frame sizes + sizes = struct.unpack("{}Q".format(nframes), sized_frame_msg) + + frames = [] + + for i, (is_gpu, size) in enumerate(zip(is_gpus, sizes)): + if size > 0: + resp = await self.ep.recv_obj(size, cuda=is_gpu) + else: + resp = await self.ep.recv_future() + frame = ucp.get_obj_from_msg(resp) + frames.append(frame) + + msg = await from_frames( + frames, deserialize=self.deserialize, deserializers=deserializers + ) + + return msg + + def abort(self): + if self._ep: + ucp.destroy_ep(self._ep) + logger.debug("Destroyed UCX endpoint") + self._ep = None + + @property + def ep(self): + if self._ep: + return self._ep + else: + raise CommClosedError("UCX Endpoint is closed") + + async def close(self): + # TODO: Handle in-flight messages? 
+ # sleep is currently used to help flush buffer + self.abort() + + def closed(self): + return self._ep is None + + +class UCXConnector(Connector): + prefix = "ucx://" + comm_class = UCX + encrypted = False + + async def connect(self, address: str, deserialize=True, **connection_args) -> UCX: + logger.debug("UCXConnector.connect: %s", address) + ucp.init() + ip, port = parse_host_port(address) + ep = await ucp.get_endpoint(ip.encode(), port) + return self.comm_class( + ep, + local_addr=None, + peer_addr=self.prefix + address, + deserialize=deserialize, + ) + + +class UCXListener(Listener): + # MAX_LISTENERS 256 in ucx-py + prefix = UCXConnector.prefix + comm_class = UCXConnector.comm_class + encrypted = UCXConnector.encrypted + + def __init__( + self, address: str, comm_handler: None, deserialize=False, **connection_args + ): + if not address.startswith("ucx"): + address = "ucx://" + address + self.ip, self._input_port = parse_host_port(address, default_port=0) + self.comm_handler = comm_handler + self.deserialize = deserialize + self._ep = None # type: ucp.Endpoint + self.listener_instance = None # type: ucp.ListenerFuture + self.ucp_server = None + self._task = None + + self.connection_args = connection_args + self._task = None + + @property + def port(self): + return self.ucp_server.port + + @property + def address(self): + return "ucx://" + self.ip + ":" + str(self.port) + + def start(self): + async def serve_forever(client_ep, listener_instance): + ucx = UCX( + client_ep, + local_addr=self.address, + peer_addr=self.address, # TODO: https://github.com/Akshay-Venkatesh/ucx-py/issues/111 + deserialize=self.deserialize, + ) + self.listener_instance = listener_instance + if self.comm_handler: + await self.comm_handler(ucx) + + ucp.init() + self.ucp_server = ucp.start_listener( + serve_forever, listener_port=self._input_port, is_coroutine=True + ) + + try: + loop = asyncio.get_running_loop() + except (RuntimeError, AttributeError): + loop = asyncio.get_event_loop() + + t = loop.create_task(self.ucp_server.coroutine) + self._task = t + + def stop(self): + # What all should this do? + if self._task: + self._task.cancel() + + if self._ep: + ucp.destroy_ep(self._ep) + # if self.listener_instance: + # ucp.stop_listener(self.listener_instance) + + def get_host_port(self): + # TODO: TCP raises if this hasn't started yet. + return self.ip, self.port + + @property + def listen_address(self): + return self.prefix + unparse_host_port(*self.get_host_port()) + + @property + def contact_address(self): + host, port = self.get_host_port() + host = ensure_concrete_host(host) # TODO: ensure_concrete_host + return self.prefix + unparse_host_port(host, port) + + @property + def bound_address(self): + # TODO: Does this become part of the base API? Kinda hazy, since + # we exclude in for inproc. 
+ return self.get_host_port() + + +class UCXBackend(Backend): + # I / O + + def get_connector(self): + return UCXConnector() + + def get_listener(self, loc, handle_comm, deserialize, **connection_args): + return UCXListener(loc, handle_comm, deserialize, **connection_args) + + # Address handling + # This duplicates BaseTCPBackend + + def get_address_host(self, loc): + return parse_host_port(loc)[0] + + def get_address_host_port(self, loc): + return parse_host_port(loc) + + def resolve_address(self, loc): + host, port = parse_host_port(loc) + return unparse_host_port(ensure_ip(host), port) + + def get_local_address_for(self, loc): + host, port = parse_host_port(loc) + host = ensure_ip(host) + if ":" in host: + local_host = get_ipv6(host) + else: + local_host = get_ip(host) + return unparse_host_port(local_host, None) + + +backends["ucx"] = UCXBackend() diff --git a/distributed/core.py b/distributed/core.py index 17685c9d2d5..79c726eed6d 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -484,7 +484,7 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): pdb.set_trace() raise finally: - comm.close() # TODO: why do we need this now? + yield comm.close() assert comm.closed() @gen.coroutine @@ -498,7 +498,7 @@ def close(self): break else: yield gen.sleep(0.05) - yield [comm.close() for comm in self._comms] + yield [comm.close() for comm in self._comms] # then forcefully close for cb in self._ongoing_coroutines: cb.cancel() for i in range(10): @@ -901,7 +901,7 @@ def collect(self): ) for addr, comms in self.available.items(): for comm in comms: - comm.close() + IOLoop.current().add_callback(comm.close) comms.clear() if self.open < self.limit: self.event.set() @@ -914,11 +914,11 @@ def remove(self, addr): if addr in self.available: comms = self.available.pop(addr) for comm in comms: - comm.close() + IOLoop.current().add_callback(comm.close) if addr in self.occupied: comms = self.occupied.pop(addr) for comm in comms: - comm.close() + IOLoop.current().add_callback(comm.close) if self.open < self.limit: self.event.set() diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 298c47d7a31..95f178c7c2e 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -195,7 +195,13 @@ def __init__( asynchronous=asynchronous, silence_logs=silence_logs, ) - self.scale(n_workers) + + def __repr__(self): + return "LocalCluster(%r, workers=%d, ncores=%d)" % ( + self.scheduler_address, + len(self.workers), + sum(w.ncores for w in self.workers.values()), + ) def __repr__(self): return "LocalCluster(%r, workers=%d, ncores=%d)" % ( diff --git a/distributed/preloading.py b/distributed/preloading.py index 0f08f60f71c..a5e67c1611a 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -100,6 +100,7 @@ def _import_modules(names, file_dir=None): import_module(name) module = sys.modules[name] + logger.info("Import preload module: %s", name) result_modules[name] = { attrname: getattr(module, attrname, None) for attrname in ("dask_setup", "dask_teardown") @@ -137,6 +138,7 @@ def preload_modules(names, parameter=None, file_dir=None, argv=None): dask_setup.callback(parameter, *context.args, **context.params) else: dask_setup(parameter) + logger.info("Run preload setup function: %s", name) if interface["dask_teardown"]: atexit.register(interface["dask_teardown"], parameter) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 04691ce605d..3f98436f4b9 100644 --- a/distributed/protocol/__init__.py +++ 
b/distributed/protocol/__init__.py @@ -4,6 +4,7 @@ from .compression import compressions, default_compression from .core import dumps, loads, maybe_compress, decompress, msgpack +from .cuda import cuda_serialize, cuda_deserialize from .serialize import ( serialize, deserialize, @@ -66,3 +67,21 @@ def _register_arrow(): @dask_deserialize.register_lazy("torchvision") def _register_torch(): from . import torch + + +@cuda_serialize.register_lazy("cupy") +@cuda_deserialize.register_lazy("cupy") +def _register_cupy(): + from . import cupy + + +@cuda_serialize.register_lazy("numba") +@cuda_deserialize.register_lazy("numba") +def _register_numba(): + from . import numba + + +@cuda_serialize.register_lazy("cudf") +@cuda_deserialize.register_lazy("cudf") +def _register_cudf(): + from . import cudf diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index c1b62b2491e..d54dd2e533e 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -176,6 +176,7 @@ def loads_msgpack(header, payload): See Also: dumps_msgpack """ + header = bytes(header) if header: header = msgpack.loads(header, use_list=False, **msgpack_opts) else: diff --git a/distributed/protocol/cuda.py b/distributed/protocol/cuda.py new file mode 100644 index 00000000000..13be1d75bb8 --- /dev/null +++ b/distributed/protocol/cuda.py @@ -0,0 +1,33 @@ +import dask + +from . import pickle +from .serialize import register_serialization_family +from dask.utils import typename + +cuda_serialize = dask.utils.Dispatch("cuda_serialize") +cuda_deserialize = dask.utils.Dispatch("cuda_deserialize") + + +def cuda_dumps(x): + type_name = typename(type(x)) + try: + dumps = cuda_serialize.dispatch(type(x)) + except TypeError: + raise NotImplementedError(type_name) + + header, frames = dumps(x) + + header["type"] = type_name + header["type-serialized"] = pickle.dumps(type(x)) + header["serializer"] = "cuda" + header["compression"] = (None,) * len(frames) # no compression for gpu data + return header, frames + + +def cuda_loads(header, frames): + typ = pickle.loads(header["type-serialized"]) + loads = cuda_deserialize.dispatch(typ) + return loads(header, frames) + + +register_serialization_family("cuda", cuda_dumps, cuda_loads) diff --git a/distributed/protocol/cudf.py b/distributed/protocol/cudf.py new file mode 100644 index 00000000000..018596b1560 --- /dev/null +++ b/distributed/protocol/cudf.py @@ -0,0 +1,74 @@ +import cudf +from .cuda import cuda_serialize, cuda_deserialize +from .numba import serialize_numba_ndarray, deserialize_numba_ndarray + + +# TODO: +# 1. Just use positions +# a. Fixes duplicate columns +# b. Fixes non-msgpack-serializable names +# 2. cudf.Series +# 3. Serialize the index + + +@cuda_serialize.register(cudf.DataFrame) +def serialize_cudf_dataframe(x): + sub_headers = [] + arrays = [] + null_masks = [] + null_headers = [] + null_counts = {} + + for label, col in x.iteritems(): + header, [frame] = serialize_numba_ndarray(col.data.mem) + header["name"] = label + sub_headers.append(header) + arrays.append(frame) + if col.null_count: + header, [frame] = serialize_numba_ndarray(col.nullmask.mem) + header["name"] = label + null_headers.append(header) + null_masks.append(frame) + null_counts[label] = col.null_count + + arrays.extend(null_masks) + + header = { + "is_cuda": len(arrays), + "subheaders": sub_headers, + # TODO: the header must be msgpack (de)serializable. + # See if we can avoid names, and just use integer positions. 
+ "columns": x.columns.tolist(), + "null_counts": null_counts, + "null_subheaders": null_headers, + } + + return header, arrays + + +@cuda_deserialize.register(cudf.DataFrame) +def serialize_cudf_dataframe(header, frames): + columns = header["columns"] + n_columns = len(header["columns"]) + n_masks = len(header["null_subheaders"]) + + masks = {} + pairs = [] + + for i in range(n_masks): + subheader = header["null_subheaders"][i] + frame = frames[n_columns + i] + mask = deserialize_numba_ndarray(subheader, [frame]) + masks[subheader["name"]] = mask + + for subheader, frame in zip(header["subheaders"], frames[:n_columns]): + name = subheader["name"] + array = deserialize_numba_ndarray(subheader, [frame]) + + if name in masks: + series = cudf.Series.from_masked_array(array, masks[name]) + else: + series = cudf.Series(array) + pairs.append((name, series)) + + return cudf.DataFrame(pairs) diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py new file mode 100644 index 00000000000..13c0348a821 --- /dev/null +++ b/distributed/protocol/cupy.py @@ -0,0 +1,42 @@ +""" +Efficient serialization GPU arrays. +""" +import cupy +from .cuda import cuda_serialize, cuda_deserialize + + +@cuda_serialize.register(cupy.ndarray) +def serialize_cupy_ndarray(x): + # TODO: handle non-contiguous + # TODO: Handle order='K' ravel + # TODO: 0d + + if x.flags.c_contiguous or x.flags.f_contiguous: + strides = x.strides + data = x.ravel() # order='K' + else: + x = cupy.ascontiguousarray(x) + strides = x.strides + data = x.ravel() + + dtype = (0, x.dtype.str) + + # used in the ucx comms for gpu/cpu message passing + # 'lengths' set by dask + header = x.__cuda_array_interface__.copy() + header["is_cuda"] = 1 + header["dtype"] = dtype + return header, [data] + + +@cuda_deserialize.register(cupy.ndarray) +def deserialize_cupy_array(header, frames): + frame, = frames + # TODO: put this in ucx... as a kind of "fixup" + try: + frame.typestr = header["typestr"] + frame.shape = header["shape"] + except AttributeError: + pass + arr = cupy.asarray(frame) + return arr diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py new file mode 100644 index 00000000000..18405ffebe0 --- /dev/null +++ b/distributed/protocol/numba.py @@ -0,0 +1,61 @@ +import numba.cuda +from .cuda import cuda_serialize, cuda_deserialize + + +@cuda_serialize.register(numba.cuda.devicearray.DeviceNDArray) +def serialize_numba_ndarray(x): + # TODO: handle non-contiguous + # TODO: handle 2d + # TODO: 0d + + if x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]: + strides = x.strides + if x.ndim > 1: + data = x.ravel() # order='K' + else: + data = x + else: + raise ValueError("Array must be contiguous") + x = numba.ascontiguousarray(x) + strides = x.strides + if x.ndim > 1: + data = x.ravel() + else: + data = x + + dtype = (0, x.dtype.str) + nbytes = data.dtype.itemsize * data.size + + # used in the ucx comms for gpu/cpu message passing + # 'lengths' set by dask + header = x.__cuda_array_interface__.copy() + header["is_cuda"] = 1 + header["dtype"] = dtype + return header, [data] + + +@cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray) +def deserialize_numba_ndarray(header, frames): + frame, = frames + # TODO: put this in ucx... as a kind of "fixup" + if isinstance(frame, bytes): + import numpy as np + + arr2 = np.frombuffer(frame, header["typestr"]) + return numba.cuda.to_device(arr2) + + frame.typestr = header["typestr"] + frame.shape = header["shape"] + + # numba & cupy don't properly roundtrip length-zero arrays. 
+ if frame.shape[0] == 0: + arr = numba.cuda.device_array( + header["shape"], + header["typestr"] + # strides? + # order? + ) + return arr + + arr = numba.cuda.as_cuda_array(frame) + return arr diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py new file mode 100644 index 00000000000..26940597f81 --- /dev/null +++ b/distributed/protocol/tests/test_cupy.py @@ -0,0 +1,12 @@ +from distributed.protocol import serialize, deserialize +import pytest + +cupy = pytest.importorskip("cupy") + + +def test_serialize_cupy(): + x = cupy.arange(100) + header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) + y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + + assert (x == y).all() diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 208caebb926..caf4bb8833b 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -90,7 +90,10 @@ def merge_frames(header, frames): L.append(mv[:l]) frames.append(mv[l:]) l = 0 - out.append(b"".join(map(ensure_bytes, L))) + if len(L) == 1: # no work necessary + out.extend(L) + else: + out.append(b"".join(map(ensure_bytes, L))) return out From 7c3b4d1c59b74b39ccfc0f579ce8295713bfb15a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 31 May 2019 11:04:13 -0700 Subject: [PATCH 0312/1550] Pin pytest >=4 with pip in appveyor and python 3.5 (#2737) --- README.rst | 1 + continuous_integration/setup_conda_environment.cmd | 2 +- continuous_integration/travis/install.sh | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 2d2d285be2b..b6f0edd604f 100644 --- a/README.rst +++ b/README.rst @@ -3,4 +3,5 @@ Distributed A library for distributed computation. See documentation_ for more details. + .. _documentation: https://distributed.readthedocs.io/en/latest diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index 5748a8cf20c..3df89fa85fe 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -50,7 +50,7 @@ call activate %CONDA_ENV% %PIP_INSTALL% git+https://github.com/joblib/joblib.git --upgrade %PIP_INSTALL% git+https://github.com/dask/zict --upgrade -%PIP_INSTALL% pytest-repeat pytest-timeout pytest-faulthandler sortedcollections pytest-asyncio +%PIP_INSTALL% "pytest>=4" pytest-repeat pytest-timeout pytest-faulthandler sortedcollections pytest-asyncio @rem Display final environment (for reproducing) %CONDA% list diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index f1ff25a9bfa..cb2dbdf5c83 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -59,7 +59,7 @@ conda install -q \ conda install -c defaults -c conda-forge libunwind conda install --no-deps -c defaults -c numba -c conda-forge stacktrace -pip install -q pytest-repeat pytest-faulthandler pytest-asyncio +pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio pip install -q git+https://github.com/dask/dask.git --upgrade --no-deps pip install -q git+https://github.com/joblib/joblib.git --upgrade --no-deps From a16b8ff071938c08f64e78bb04636c3b4d619325 Mon Sep 17 00:00:00 2001 From: Caleb Date: Fri, 31 May 2019 16:03:39 -0700 Subject: [PATCH 0313/1550] Allow user to configure whether workers are daemon. 
(#2739) Closes #2718 --- .gitignore | 3 +++ distributed/distributed.yaml | 1 + distributed/nanny.py | 2 +- distributed/process.py | 2 +- distributed/tests/test_nanny.py | 35 +++++++++++++++++++++++++++++++++ 5 files changed, 41 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index a3a40e19289..2d70b7ebd7f 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ continuous_integration/hdfs-initialized .pytest_cache/ dask-worker-space/ .vscode/ +*.swp +.ycm_extra_conf.py +tags diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 3ae9b7ee690..4d78a698e69 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -32,6 +32,7 @@ distributed: incoming: 10 preload: [] preload-argv: [] + daemon: True profile: interval: 10ms # Time between statistical profiling queries diff --git a/distributed/nanny.py b/distributed/nanny.py index 8d2a38192d1..59a8083e832 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -464,7 +464,7 @@ def start(self): env=self.env, ), ) - self.process.daemon = True + self.process.daemon = dask.config.get("distributed.worker.daemon", default=True) self.process.set_exit_callback(self._on_exit) self.running = Event() self.stopped = Event() diff --git a/distributed/process.py b/distributed/process.py index 5dd9368fdc1..556edae290e 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -330,7 +330,7 @@ def daemon(self, value): @atexit.register def _cleanup_dangling(): for proc in list(_dangling): - if proc.daemon and proc.is_alive(): + if proc.is_alive(): try: logger.warning("reaping stray process %s" % (proc,)) proc.terminate() diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index be0a05afc20..1357a3679e2 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -5,6 +5,7 @@ import os import random import sys +import multiprocessing as mp import numpy as np @@ -344,3 +345,37 @@ def test_data_types(c, s): r = yield c.run(lambda dask_worker: type(dask_worker.data)) assert r[w.worker_address] == dict yield w.close() + + +def _noop(x): + """Define here because closures aren't pickleable.""" + pass + + +@gen_cluster( + ncores=[("127.0.0.1", 1)], + client=True, + Worker=Nanny, + config={"distributed.worker.daemon": False}, +) +def test_mp_process_worker_no_daemon(c, s, a): + def multiprocessing_worker(): + p = mp.Process(target=_noop, args=(None,)) + p.start() + p.join() + + yield c.submit(multiprocessing_worker) + + +@gen_cluster( + ncores=[("127.0.0.1", 1)], + client=True, + Worker=Nanny, + config={"distributed.worker.daemon": False}, +) +def test_mp_pool_worker_no_daemon(c, s, a): + def pool_worker(world_size): + with mp.Pool(processes=world_size) as p: + p.map(_noop, range(world_size)) + + yield c.submit(pool_worker, 4) From 861536ca2cbb6039ae9325c672f1b85f8124bc25 Mon Sep 17 00:00:00 2001 From: Michael Spiegel Date: Mon, 3 Jun 2019 17:37:41 +0200 Subject: [PATCH 0314/1550] Fix the resource key representation before sending graphs (#2716) (#2733) Convert resource key toples to a string representation before they are submitted to the scheduler. The commit is intended to fix #2716. The test case persists the result of a tiny DataFrame operation and checks the resource restrictions. 
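For context, a minimal sketch of the user-facing pattern this fixes. The scheduler address is hypothetical and the worker is assumed to have been started with --resources "MyRes=1"; the point is that collection keys are tuples, which previously did not round-trip as resource restrictions:

    import pandas as pd
    import dask.dataframe as dd
    from distributed import Client

    client = Client("tcp://127.0.0.1:8786")  # hypothetical scheduler address

    df = dd.from_pandas(pd.DataFrame({"A": [1, 2], "B": [3, 4]}), npartitions=1)
    totals = df.apply(lambda row: row.sum(), axis=1, meta=(None, "int64"))

    # Restrict these tasks to workers advertising the "MyRes" resource.
    # The collection's keys are tuples such as ("apply-<token>", 0); the change
    # below converts them with tokey() to their string form before submitting
    # the resource restrictions to the scheduler.
    totals = totals.persist(resources={"MyRes": 1})
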
--- distributed/client.py | 1 + distributed/tests/test_resources.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index d924b608c61..22d89cda4e4 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2351,6 +2351,7 @@ def _graph_to_futures( resources = self._expand_resources( resources, all_keys=itertools.chain(dsk, keys) ) + resources = {tokey(k): v for k, v in resources.items()} if retries: retries = self._expand_retries( diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index d7102ef5301..480532d912e 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -202,6 +202,24 @@ def test_persist_tuple(c, s, a, b): assert not b.data +@gen_cluster(client=True) +def test_resources_str(c, s, a, b): + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") + + yield a.set_resources(MyRes=1) + + x = dd.from_pandas(pd.DataFrame({"A": [1, 2], "B": [3, 4]}), npartitions=1) + y = x.apply(lambda row: row.sum(), axis=1, meta=(None, "int64")) + yy = y.persist(resources={"MyRes": 1}) + yield wait(yy) + + ts_first = s.tasks[tokey(y.__dask_keys__()[0])] + assert ts_first.resource_restrictions == {"MyRes": 1} + ts_last = s.tasks[tokey(y.__dask_keys__()[-1])] + assert ts_last.resource_restrictions == {"MyRes": 1} + + @gen_cluster( client=True, ncores=[ From bcb765de543a0f15ddce5d1a2e86b7f4bdefde4f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 3 Jun 2019 17:26:06 -0700 Subject: [PATCH 0315/1550] Add async context managers to scheduler/worker classes (#2745) --- distributed/node.py | 7 +++++++ distributed/tests/test_scheduler.py | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/distributed/node.py b/distributed/node.py index ff95a621877..4f0b9813a8e 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -130,3 +130,10 @@ def stop_services(self): @property def service_ports(self): return {k: v.port for k, v in self.services.items()} + + async def __aenter__(self): + await self + return self + + async def __aexit__(self, typ, value, traceback): + await self.close() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6df271ae34e..a0cbfabfa29 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1572,3 +1572,13 @@ def test_dashboard_address(): s = yield Scheduler(dashboard_address="127.0.0.1", port=0) assert s.services["dashboard"].port yield s.close() + + +@pytest.mark.asyncio +async def test_async_context_manager(): + async with Scheduler(port=0) as s: + assert s.status == "running" + async with Worker(s.address) as w: + assert w.status == "running" + assert s.workers + assert not s.workers From 2a1e089a9dc541a0a19c0fca575c42c719275bd9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 5 Jun 2019 10:09:19 -0700 Subject: [PATCH 0316/1550] Worker dashboard fixes (#2747) * bokeh -> dashboard in template * Add doc to ProfileTimePlot * Add test for bokeh worker routes * Remove info route from worker To do this we ... 1. Remove the baked in "info" link in the base template 2. Add that to the scheduler's list of links 3. Add a redirect from "info" to the actual page 4. Create a generic redirect route 5. 
Move that and the RequestHandler to utils to avoid code duplication between scheduler and worker Fixes https://github.com/dask/distributed/issues/2722 * Add worker name Fixes https://github.com/dask/dask/issues/4878 --- distributed/dashboard/scheduler.py | 19 +++++++----- distributed/dashboard/scheduler_html.py | 15 ++-------- distributed/dashboard/templates/base.html | 3 -- .../dashboard/templates/worker-table.html | 6 ++-- distributed/dashboard/templates/worker.html | 6 ++-- .../dashboard/tests/test_scheduler_bokeh.py | 6 ++-- .../dashboard/tests/test_worker_bokeh.py | 29 +++++++++++++++++++ distributed/dashboard/utils.py | 20 +++++++++++++ distributed/dashboard/worker.py | 2 +- distributed/dashboard/worker_html.py | 22 ++------------ 10 files changed, 76 insertions(+), 52 deletions(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 6476d3aa6e4..86f56e9eda0 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -77,7 +77,7 @@ ) template_variables = { - "pages": ["status", "workers", "tasks", "system", "profile", "graph"] + "pages": ["status", "workers", "tasks", "system", "profile", "graph", "info"] } BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "theme.yaml")) @@ -449,7 +449,7 @@ def update(self): "nbytes_text": nbytes_text, "dashboard_host": dashboard_host, "dashboard_port": dashboard_port, - "worker": [ws.address for ws in workers], + "address": [ws.address for ws in workers], "y": y, } @@ -1177,7 +1177,8 @@ class WorkerTable(DashboardComponent): def __init__(self, scheduler, width=800, **kwargs): self.scheduler = scheduler self.names = [ - "worker", + "name", + "address", "ncores", "cpu", "memory", @@ -1195,7 +1196,8 @@ def __init__(self, scheduler, width=800, **kwargs): ) table_names = [ - "worker", + "name", + "address", "ncores", "cpu", "memory", @@ -1242,7 +1244,7 @@ def __init__(self, scheduler, width=800, **kwargs): if name in formatters: table.columns[table_names.index(name)].formatter = formatters[name] - extra_names = ["worker"] + self.extra_names + extra_names = ["name", "address"] + self.extra_names extra_columns = { name: TableColumn(field=name, title=name.replace("_percent", "%")) for name in extra_names @@ -1330,10 +1332,13 @@ def __init__(self, scheduler, width=800, **kwargs): @without_property_validation def update(self): data = {name: [] for name in self.names + self.extra_names} - for addr, ws in sorted(self.scheduler.workers.items()): + for i, (addr, ws) in enumerate( + sorted(self.scheduler.workers.items(), key=lambda kv: kv[1].name) + ): for name in self.names + self.extra_names: data[name].append(ws.metrics.get(name, None)) - data["worker"][-1] = ws.address + data["name"][-1] = ws.name if ws.name is not None else i + data["address"][-1] = ws.address if ws.memory_limit: data["memory_percent"][-1] = ws.metrics["memory"] / ws.memory_limit else: diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 5f481f783be..9f2bcd3cbb2 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -1,30 +1,18 @@ from datetime import datetime -import os import toolz from tornado import escape from tornado import gen -from tornado import web from ..utils import log_errors, format_bytes, format_time from .proxy import GlobalProxyHandler - -dirname = os.path.dirname(__file__) +from .utils import RequestHandler, redirect ns = { func.__name__: func for func in [format_bytes, format_time, 
datetime.fromtimestamp] } -class RequestHandler(web.RequestHandler): - def initialize(self, server=None, extra=None): - self.server = server - self.extra = extra or {} - - def get_template_path(self): - return os.path.join(dirname, "templates") - - class Workers(RequestHandler): def get(self): with log_errors(): @@ -238,6 +226,7 @@ def get(self): routes = [ + (r"info", redirect("info/main/workers.html")), (r"info/main/workers.html", Workers), (r"info/worker/(.*).html", Worker), (r"info/task/(.*).html", Task), diff --git a/distributed/dashboard/templates/base.html b/distributed/dashboard/templates/base.html index da15df28b69..83f5e8527c6 100644 --- a/distributed/dashboard/templates/base.html +++ b/distributed/dashboard/templates/base.html @@ -29,9 +29,6 @@ {{ page|title }}
        • {% endfor %} -
        • - Info -
        Worker Cores {{ len(ws.processing) }} {{ len(ws.has_what) }} bokeh bokeh
        + @@ -13,14 +14,15 @@ {% for ws in worker_list %} + - {% if 'bokeh' in ws.services %} - + {% if 'dashboard' in ws.services %} + {% else %} {% end %} diff --git a/distributed/dashboard/templates/worker.html b/distributed/dashboard/templates/worker.html index 8b26d86e956..9c7608cb8c2 100644 --- a/distributed/dashboard/templates/worker.html +++ b/distributed/dashboard/templates/worker.html @@ -1,8 +1,8 @@ {% extends main.html %} {% block content %} -

        Worker: {{Worker}}

        - {% set ws = workers[Worker] %} - {% set worker_list = [ws] %} +{% set ws = workers[Worker] %} +{% set worker_list = [ws] %} +

        Worker: {{ ws.address }}

        {% include "worker-table.html" %}
        diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index f8a813514b4..692a29439c0 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -354,7 +354,7 @@ def metric_address(worker): assert all(data.values()) assert all(len(v) == 2 for v in data.values()) - my_index = data["worker"].index(a.address), data["worker"].index(b.address) + my_index = data["address"].index(a.address), data["address"].index(b.address) assert [data["metric_port"][i] for i in my_index] == [a.port, b.port] assert [data["metric_address"][i] for i in my_index] == [a.address, b.address] @@ -379,7 +379,7 @@ def metric_port(worker): assert "metric_b" in data assert all(data.values()) assert all(len(v) == 2 for v in data.values()) - my_index = data["worker"].index(a.address), data["worker"].index(b.address) + my_index = data["address"].index(a.address), data["address"].index(b.address) assert [data["metric_a"][i] for i in my_index] == [a.port, None] assert [data["metric_b"][i] for i in my_index] == [None, b.port] @@ -399,7 +399,7 @@ def metric_port(worker): assert "metric_a" in data assert all(data.values()) assert all(len(v) == 2 for v in data.values()) - my_index = data["worker"].index(a.address), data["worker"].index(b.address) + my_index = data["address"].index(a.address), data["address"].index(b.address) assert [data["metric_a"][i] for i in my_index] == [a.port, None] diff --git a/distributed/dashboard/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py index 11699d9ac83..ef977127d23 100644 --- a/distributed/dashboard/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -1,6 +1,7 @@ from __future__ import print_function, division, absolute_import from operator import add, sub +import re from time import sleep import pytest @@ -14,6 +15,7 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec +from distributed.dashboard.scheduler import BokehScheduler from distributed.dashboard.worker import ( BokehWorker, StateTable, @@ -26,6 +28,33 @@ ) +@gen_cluster( + client=True, + worker_kwargs={"services": {("dashboard", 0): BokehWorker}}, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, +) +def test_routes(c, s, a, b): + assert isinstance(a.services["dashboard"], BokehWorker) + assert isinstance(b.services["dashboard"], BokehWorker) + port = a.services["dashboard"].port + + future = c.submit(sleep, 1) + yield gen.sleep(0.1) + + http_client = AsyncHTTPClient() + for suffix in ["status", "counters", "system", "profile", "profile-server"]: + response = yield http_client.fetch("http://localhost:%d/%s" % (port, suffix)) + body = response.body.decode() + assert "bokeh" in body.lower() + assert not re.search("href=./", body) # no absolute links + + response = yield http_client.fetch( + "http://localhost:%d/info/main/workers.html" % s.services["dashboard"].port + ) + + assert str(port) in response.body.decode() + + @pytest.mark.skipif( sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" ) diff --git a/distributed/dashboard/utils.py b/distributed/dashboard/utils.py index 516ca5bfb88..a9b31345ca9 100644 --- a/distributed/dashboard/utils.py +++ b/distributed/dashboard/utils.py @@ -1,13 +1,16 @@ from __future__ import print_function, division, absolute_import from distutils.version import LooseVersion +import os 
import bokeh +from tornado import web from toolz import partition from ..compatibility import PY2 BOKEH_VERSION = LooseVersion(bokeh.__version__) +dirname = os.path.dirname(__file__) if BOKEH_VERSION >= "1.0.0" and not PY2: @@ -32,3 +35,20 @@ def parse_args(args): def transpose(lod): keys = list(lod[0].keys()) return {k: [d[k] for d in lod] for k in keys} + + +class RequestHandler(web.RequestHandler): + def initialize(self, server=None, extra=None): + self.server = server + self.extra = extra or {} + + def get_template_path(self): + return os.path.join(dirname, "templates") + + +def redirect(path): + class Redirect(RequestHandler): + def get(self): + self.redirect(path) + + return Redirect diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index ed7b68b76b4..aa85afc4197 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -735,7 +735,7 @@ def counters_doc(server, extra, doc): def profile_doc(server, extra, doc): with log_errors(): doc.title = "Dask Worker Profile" - profile = ProfileTimePlot(server, sizing_mode="scale_width") + profile = ProfileTimePlot(server, sizing_mode="scale_width", doc=doc) profile.trigger_update() doc.add_root(profile.root) diff --git a/distributed/dashboard/worker_html.py b/distributed/dashboard/worker_html.py index c818c8fb1e6..450cce56c8e 100644 --- a/distributed/dashboard/worker_html.py +++ b/distributed/dashboard/worker_html.py @@ -1,17 +1,4 @@ -import os - -from tornado import web - -dirname = os.path.dirname(__file__) - - -class RequestHandler(web.RequestHandler): - def initialize(self, server=None, extra=None): - self.server = server - self.extra = extra or {} - - def get_template_path(self): - return os.path.join(dirname, "templates") +from .utils import RequestHandler, redirect class _PrometheusCollector(object): @@ -67,15 +54,10 @@ def get(self): self.set_header("Content-Type", "text/plain") -class OldRoute(RequestHandler): - def get(self): - self.redirect("/status") - - routes = [ (r"metrics", PrometheusHandler), (r"health", HealthHandler), - (r"main", OldRoute), + (r"main", redirect("/status")), ] From e846991d93054a29e28528224c63db69972ffc9c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 5 Jun 2019 14:26:50 -0700 Subject: [PATCH 0317/1550] Add SpecCluster.new_worker_spec method (#2751) * Add type name to LocalCluster.__repr__ * Add SpecCluster.new_worker_spec method This is helpful for subclassing --- distributed/deploy/local.py | 10 ++------- distributed/deploy/spec.py | 22 ++++++++++++++++--- distributed/deploy/tests/test_spec_cluster.py | 18 ++++++++++++++- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 95f178c7c2e..a56cce8c2b2 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -197,14 +197,8 @@ def __init__( ) def __repr__(self): - return "LocalCluster(%r, workers=%d, ncores=%d)" % ( - self.scheduler_address, - len(self.workers), - sum(w.ncores for w in self.workers.values()), - ) - - def __repr__(self): - return "LocalCluster(%r, workers=%d, ncores=%d)" % ( + return "%s(%r, workers=%d, ncores=%d)" % ( + type(self).__name__, self.scheduler_address, len(self.workers), sum(w.ncores for w in self.workers.values()), diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index d5a954effc8..2558a5df26a 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -275,12 +275,28 @@ def scale(self, n): self.worker_spec.popitem() while 
len(self.worker_spec) < n: - while self._i in self.worker_spec: - self._i += 1 - self.worker_spec[self._i] = self.new_spec + k, spec = self.new_worker_spec() + self.worker_spec[k] = spec self.loop.add_callback(self._correct_state) + def new_worker_spec(self): + """ Return name and spec for the next worker + + Returns + ------- + name: identifier for worker + spec: dict + + See Also + -------- + scale + """ + while self._i in self.worker_spec: + self._i += 1 + + return self._i, self.new_spec + async def scale_down(self, workers): workers = set(workers) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index ac5706afe1c..eb733f2e68f 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -76,7 +76,9 @@ def test_spec_sync(loop): def test_loop_started(): - cluster = SpecCluster(worker_spec) + cluster = SpecCluster( + worker_spec, scheduler={"cls": Scheduler, "options": {"port": 0}} + ) @pytest.mark.asyncio @@ -110,6 +112,7 @@ async def test_broken_worker(): async with SpecCluster( asynchronous=True, workers={"good": {"cls": Worker}, "bad": {"cls": BrokenWorker}}, + scheduler={"cls": Scheduler, "options": {"port": 0}}, ) as cluster: pass @@ -124,3 +127,16 @@ def test_spec_close_clusters(loop): assert cluster in SpecCluster._instances close_clusters() assert cluster.status == "closed" + + +@pytest.mark.asyncio +async def test_new_worker_spec(): + class MyCluster(SpecCluster): + def new_worker_spec(self): + i = len(self.worker_spec) + return i, {"cls": Worker, "options": {"ncores": i + 1}} + + async with MyCluster(asynchronous=True, scheduler=scheduler) as cluster: + cluster.scale(3) + for i in range(3): + assert cluster.worker_spec[i]["options"]["ncores"] == i + 1 From 0696a1f6456b8010b19ac47b14cd2dca0d859246 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 6 Jun 2019 11:15:33 -0700 Subject: [PATCH 0318/1550] Move some of the adaptive logic into the scheduler (#2735) * Move some of the adaptive logic into the scheduler * don't close closed clusters * require pytest >= 4 in CI * use worker_spec if it exists * Don't scale a closed cluster * handle intermittent failures --- continuous_integration/travis/install.sh | 2 +- distributed/deploy/adaptive.py | 117 +++------------------- distributed/deploy/spec.py | 7 +- distributed/deploy/tests/test_adaptive.py | 35 +++---- distributed/scheduler.py | 55 ++++++++++ distributed/tests/test_diskutils.py | 2 + distributed/tests/test_scheduler.py | 25 +++++ distributed/utils.py | 12 +-- 8 files changed, 121 insertions(+), 134 deletions(-) diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index cb2dbdf5c83..2ab9724db25 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -44,7 +44,7 @@ conda install -q \ paramiko \ prometheus_client \ psutil \ - pytest \ + pytest>=4 \ pytest-timeout \ python=$PYTHON \ requests \ diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 793e80d984c..401acc3dc1d 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -4,7 +4,6 @@ import logging import math -import toolz from tornado import gen from ..metrics import time @@ -128,104 +127,6 @@ def stop(self): self._adapt_callback = None del self._adapt_callback - def needs_cpu(self): - """ - Check if the cluster is CPU constrained (too many tasks per core) - - Notes - ----- - Returns ``True`` if the occupancy per 
core is some factor larger - than ``startup_cost`` and the number of tasks exceeds the number of - cores - """ - total_occupancy = self.scheduler.total_occupancy - total_cores = self.scheduler.total_ncores - - if total_occupancy / (total_cores + 1e-9) > self.startup_cost * 2: - logger.info( - "CPU limit exceeded [%d occupancy / %d cores]", - total_occupancy, - total_cores, - ) - - tasks_processing = 0 - - for w in self.scheduler.workers.values(): - tasks_processing += len(w.processing) - - if tasks_processing > total_cores: - logger.info( - "pending tasks exceed number of cores " "[%d tasks / %d cores]", - tasks_processing, - total_cores, - ) - - return True - - return False - - def needs_memory(self): - """ - Check if the cluster is RAM constrained - - Notes - ----- - Returns ``True`` if the required bytes in distributed memory is some - factor larger than the actual distributed memory available. - """ - limit_bytes = { - addr: ws.memory_limit for addr, ws in self.scheduler.workers.items() - } - worker_bytes = [ws.nbytes for ws in self.scheduler.workers.values()] - - limit = sum(limit_bytes.values()) - total = sum(worker_bytes) - if total > 0.6 * limit: - logger.info("Ram limit exceeded [%d/%d]", limit, total) - return True - else: - return False - - def should_scale_up(self): - """ - Determine whether additional workers should be added to the cluster - - Returns - ------- - scale_up : bool - - Notes - ---- - Additional workers are added whenever - - 1. There are unrunnable tasks and no workers - 2. The cluster is CPU constrained - 3. The cluster is RAM constrained - 4. There are fewer workers than our minimum - - See Also - -------- - needs_cpu - needs_memory - """ - with log_errors(): - if len(self.scheduler.workers) < self.minimum: - return True - - if self.maximum is not None and len(self.scheduler.workers) >= self.maximum: - return False - - if self.scheduler.unrunnable and not self.scheduler.workers: - return True - - needs_cpu = self.needs_cpu() - needs_memory = self.needs_memory() - - if needs_cpu or needs_memory: - return True - - return False - def workers_to_close(self, **kwargs): """ Determine which, if any, workers should potentially be removed from @@ -305,9 +206,17 @@ def get_scale_up_kwargs(self): return {"n": instances} def recommendations(self, comm=None): - should_scale_up = self.should_scale_up() + n = self.scheduler.adaptive_target(target_duration=self.target_duration) + if self.maximum is not None: + n = min(self.maximum, n) + if self.minimum is not None: + n = max(self.minimum, n) workers = set(self.workers_to_close(key=self.worker_key, minimum=self.minimum)) - if should_scale_up and workers: + try: + current = len(self.cluster.worker_spec) + except AttributeError: + current = len(self.cluster.workers) + if n > current and workers: logger.info("Attempting to scale up and scale down simultaneously.") self.close_counts.clear() return { @@ -315,9 +224,9 @@ def recommendations(self, comm=None): "msg": "Trying to scale up and down simultaneously", } - elif should_scale_up: + elif n > current: self.close_counts.clear() - return toolz.merge({"status": "up"}, self.get_scale_up_kwargs()) + return {"status": "up", "n": n} elif workers: d = {} @@ -352,7 +261,7 @@ def _adapt(self): return status = recommendations.pop("status") if status == "up": - f = self.cluster.scale_up(**recommendations) + f = self.cluster.scale(**recommendations) self.log.append((time(), "up", recommendations)) if hasattr(f, "__await__"): yield f diff --git a/distributed/deploy/spec.py 
b/distributed/deploy/spec.py index 2558a5df26a..85728a057e4 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -274,6 +274,10 @@ def scale(self, n): while len(self.worker_spec) > n: self.worker_spec.popitem() + if self.status in ("closing", "closed"): + self.loop.add_callback(self._correct_state) + return + while len(self.worker_spec) < n: k, spec = self.new_worker_spec() self.worker_spec[k] = spec @@ -321,4 +325,5 @@ def __repr__(self): def close_clusters(): for cluster in list(SpecCluster._instances): with ignoring(gen.TimeoutError): - cluster.close(timeout=10) + if cluster.status != "closed": + cluster.close(timeout=10) diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 8915c721353..cc860636e55 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -2,13 +2,12 @@ from time import sleep -import pytest from toolz import frequencies, pluck from tornado import gen from tornado.ioloop import IOLoop from distributed import Client, wait, Adaptive, LocalCluster, SpecCluster, Worker -from distributed.utils_test import gen_cluster, gen_test, slowinc, inc, clean +from distributed.utils_test import gen_cluster, gen_test, slowinc, clean from distributed.utils_test import loop, nodebug # noqa: F401 from distributed.metrics import time @@ -116,11 +115,10 @@ def test_adaptive_local_cluster_multi_workers(): yield gen.sleep(0.01) assert time() < start + 15, alc.log - # assert not cluster.workers - assert not cluster.scheduler.workers - yield gen.sleep(0.2) - # assert not cluster.workers - assert not cluster.scheduler.workers + # no workers for a while + for i in range(10): + assert not cluster.scheduler.workers + yield gen.sleep(0.05) futures = c.map(slowinc, range(100), delay=0.01) yield c.gather(futures) @@ -152,6 +150,10 @@ def scale_up(self, n, **kwargs): def scale_down(self, workers): assert False + @property + def workers(self): + return s.workers + assert len(s.workers) == 10 # Assert that adaptive cycle does not reduce cluster below minimum size @@ -163,8 +165,7 @@ def scale_down(self, workers): assert len(s.workers) == 2 -@pytest.mark.xfail(reason="need to rework adaptive") -@gen_test(timeout=30) +@gen_test() def test_min_max(): cluster = yield LocalCluster( 0, @@ -242,7 +243,9 @@ def test_avoid_churn(): yield client.submit(slowinc, i, delay=0.040) yield gen.sleep(0.040) - assert frequencies(pluck(1, adapt.log)) == {"up": 1} + from toolz.curried import pipe, unique, pluck, frequencies + + assert pipe(adapt.log, unique(key=str), pluck(1), frequencies) == {"up": 1} finally: yield client.close() yield cluster.close() @@ -435,15 +438,3 @@ def key(ws): assert names == {"a-1", "a-2"} or names == {"b-1", "b-2"} finally: yield cluster.close() - - -@gen_cluster(client=True, ncores=[]) -def test_without_cluster(c, s): - adapt = Adaptive(scheduler=s) - - future = c.submit(inc, 1) - while not s.tasks: - yield gen.sleep(0.01) - - response = yield c.scheduler.adaptive_recommendations() - assert response["status"] == "up" diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ca3c1241ea7..3cf8de49306 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6,6 +6,7 @@ import itertools import json import logging +import math from numbers import Number import operator import os @@ -1063,6 +1064,7 @@ def __init__( "get_task_status": self.get_task_status, "get_task_stream": self.get_task_stream, "register_worker_plugin": self.register_worker_plugin, + 
"adaptive_target": self.adaptive_target, } self._transitions = { @@ -4740,6 +4742,59 @@ def check_idle(self): if close: self.loop.add_callback(self.close) + def adaptive_target(self, target_duration="5s"): + """ Desired number of workers based on the current workload + + This looks at the current running tasks and memory use, and returns a + number of desired workers. This is often used by adaptive scheduling. + + Parameters + ---------- + target_duration: str + A desired duration of time for computations to take. This affects + how rapidly the scheduler will ask to scale. + + See Also + -------- + distributed.deploy.Adaptive + """ + target_duration = parse_timedelta(target_duration) + + # CPU + cpu = math.ceil( + self.total_occupancy / target_duration + ) # TODO: threads per worker + + # Avoid a few long tasks from asking for many cores + tasks_processing = 0 + for ws in self.workers.values(): + tasks_processing += len(ws.processing) + + if tasks_processing > cpu: + break + else: + cpu = min(tasks_processing, cpu) + + if self.unrunnable and not self.workers: + cpu = max(1, cpu) + + # Memory + limit_bytes = {addr: ws.memory_limit for addr, ws in self.workers.items()} + worker_bytes = [ws.nbytes for ws in self.workers.values()] + limit = sum(limit_bytes.values()) + total = sum(worker_bytes) + if total > 0.6 * limit: + memory = 2 * len(self.workers) + else: + memory = 0 + + target = max(memory, cpu) + if target >= len(self.workers): + return target + else: # Scale down? + to_close = self.workers_to_close() + return len(self.workers) - len(to_close) + def decide_worker(ts, all_workers, valid_workers, objective): """ diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index 1bededf84ab..a6dcf3497a3 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -276,6 +276,8 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): def test_workspace_concurrency(tmpdir): if WINDOWS: raise pytest.xfail.Exception("TODO: unknown failure on windows") + if sys.version_info < (3, 6): + raise pytest.xfail.Exception("TODO: unknown failure on Python 3.5") _test_workspace_concurrency(tmpdir, 2.0, 6) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a0cbfabfa29..1c321a02906 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1574,6 +1574,31 @@ def test_dashboard_address(): yield s.close() +@gen_cluster(client=True) +async def test_adaptive_target(c, s, a, b): + assert s.adaptive_target() == 0 + x = c.submit(inc, 1) + await x + assert s.adaptive_target() == 1 + + # Long task + s.task_duration["slowinc"] = 10 + x = c.submit(slowinc, 1, delay=0.5) + while x.key not in s.tasks: + await gen.sleep(0.01) + assert s.adaptive_target(target_duration=".1s") == 1 # still one + + s.task_duration["slowinc"] = 10 + L = c.map(slowinc, range(100), delay=0.5) + while len(s.tasks) < 100: + await gen.sleep(0.01) + assert 10 < s.adaptive_target(target_duration=".1s") <= 100 + del x, L + while s.tasks: + await gen.sleep(0.01) + assert s.adaptive_target(target_duration=".1s") == 0 + + @pytest.mark.asyncio async def test_async_context_manager(): async with Scheduler(port=0) as s: diff --git a/distributed/utils.py b/distributed/utils.py index 55508a4c574..e8de0bc5108 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -950,13 +950,13 @@ def tmpfile(extension=""): yield filename if os.path.exists(filename): - if os.path.isdir(filename): - shutil.rmtree(filename) - 
else: - try: + try: + if os.path.isdir(filename): + shutil.rmtree(filename) + else: os.remove(filename) - except OSError: # sometimes we can't remove a generated temp file - pass + except OSError: # sometimes we can't remove a generated temp file + pass def ensure_bytes(s): From 587be8d48536f52453594eebd1a23becf864ccf9 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 6 Jun 2019 16:42:32 -0500 Subject: [PATCH 0319/1550] Add nanny logs (#2744) --- distributed/client.py | 20 ++++++++++++++------ distributed/nanny.py | 2 ++ distributed/node.py | 24 +++++++++++++++++++++++- distributed/scheduler.py | 28 +++++----------------------- distributed/tests/test_client.py | 16 ++++++++++++++-- distributed/worker.py | 24 ++---------------------- docs/source/api.rst | 2 +- 7 files changed, 61 insertions(+), 55 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 22d89cda4e4..ac098d7987e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2270,6 +2270,10 @@ def run(self, function, *args, **kwargs): wait: boolean (optional) If the function is asynchronous whether or not to wait until that function finishes. + nanny : bool, defualt False + Whether to run ``function`` on the nanny. By default, the function + is run on the worker process. If specified, the addresses in + ``workers`` should still be the worker addresses, not the nanny addresses. Examples -------- @@ -3354,7 +3358,7 @@ def get_scheduler_logs(self, n=None): Parameters ---------- - n: int + n : int Number of logs to retrive. Maxes out at 10000 by default, confiruable in config.yaml::log-length @@ -3364,23 +3368,27 @@ def get_scheduler_logs(self, n=None): """ return self.sync(self.scheduler.logs, n=n) - def get_worker_logs(self, n=None, workers=None): + def get_worker_logs(self, n=None, workers=None, nanny=False): """ Get logs from workers Parameters ---------- - n: int + n : int Number of logs to retrive. Maxes out at 10000 by default, confiruable in config.yaml::log-length - workers: iterable - List of worker addresses to retrive. Gets all workers by default. + workers : iterable + List of worker addresses to retrieve. Gets all workers by default. + nanny : bool, default False + Whether to get the logs from the workers (False) or the nannies (True). If + specified, the addresses in `workers` should still be the worker addresses, + not the nanny addresses. Returns ------- Dictionary mapping worker address to logs. 
Logs are returned in reversed order (newest first) """ - return self.sync(self.scheduler.worker_logs, n=n, workers=workers) + return self.sync(self.scheduler.worker_logs, n=n, workers=workers, nanny=nanny) def retire_workers(self, workers=None, close_workers=True, **kwargs): """ Retire certain workers on the scheduler diff --git a/distributed/nanny.py b/distributed/nanny.py index 59a8083e832..9cf444fc7c4 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -79,6 +79,7 @@ def __init__( protocol=None, **worker_kwargs ): + self._setup_logging(logger) self.loop = loop or IOLoop.current() self.security = security or Security() assert isinstance(self.security, Security) @@ -130,6 +131,7 @@ def __init__( "kill": self.kill, "restart": self.restart, # cannot call it 'close' on the rpc side for naming conflict + "get_logs": self.get_logs, "terminate": self.close, "close_gracefully": self.close_gracefully, "run": self.run, diff --git a/distributed/node.py b/distributed/node.py index 4f0b9813a8e..8bd81ffe5ae 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -1,12 +1,15 @@ from __future__ import print_function, division, absolute_import import warnings +import logging from tornado.ioloop import IOLoop +import dask -from .compatibility import unicode +from .compatibility import unicode, finalize from .core import Server, ConnectionPool from .versions import get_versions +from .utils import DequeHandler class Node(object): @@ -131,6 +134,25 @@ def stop_services(self): def service_ports(self): return {k: v.port for k, v in self.services.items()} + def _setup_logging(self, logger): + self._deque_handler = DequeHandler( + n=dask.config.get("distributed.admin.log-length") + ) + self._deque_handler.setFormatter( + logging.Formatter(dask.config.get("distributed.admin.log-format")) + ) + logger.addHandler(self._deque_handler) + finalize(self, logger.removeHandler, self._deque_handler) + + def get_logs(self, comm=None, n=None): + deque_handler = self._deque_handler + if n is None: + L = list(deque_handler.deque) + else: + L = deque_handler.deque + L = [L[-i] for i in range(min(n, len(L)))] + return [(msg.levelname, deque_handler.format(msg)) for msg in L] + async def __aenter__(self): await self return self diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3cf8de49306..2705971e155 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -53,7 +53,6 @@ key_split, validate_key, no_default, - DequeHandler, parse_timedelta, parse_bytes, PeriodicCallback, @@ -844,7 +843,7 @@ def __init__( dashboard_address=None, **kwargs ): - self._setup_logging() + self._setup_logging(logger) # Attributes self.allowed_failures = allowed_failures @@ -1329,16 +1328,6 @@ def close_worker(self, stream=None, worker=None, safe=None): self.worker_send(worker, {"op": "close", "report": False}) self.remove_worker(address=worker, safe=safe) - def _setup_logging(self): - self._deque_handler = DequeHandler( - n=dask.config.get("distributed.admin.log-length") - ) - self._deque_handler.setFormatter( - logging.Formatter(dask.config.get("distributed.admin.log-format")) - ) - logger.addHandler(self._deque_handler) - finalize(self, logger.removeHandler, self._deque_handler) - ########### # Stimuli # ########### @@ -4627,18 +4616,11 @@ def get_profile_metadata( raise gen.Return({"counts": counts, "keys": keys}) - def get_logs(self, comm=None, n=None): - deque_handler = self._deque_handler - if n is None: - L = list(deque_handler.deque) - else: - L = deque_handler.deque - L = [L[-i] for i in 
range(min(n, len(L)))] - return [(msg.levelname, deque_handler.format(msg)) for msg in L] - @gen.coroutine - def get_worker_logs(self, comm=None, n=None, workers=None): - results = yield self.broadcast(msg={"op": "get_logs", "n": n}, workers=workers) + def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False): + results = yield self.broadcast( + msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny + ) raise gen.Return(results) ########### diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c731ae6e5ad..dc45b3025e5 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5111,7 +5111,7 @@ def test_task_metadata(c, s, a, b): assert result == {"a": {"c": {"d": 1}}, "b": 2} -@gen_cluster(client=True) +@gen_cluster(client=True, Worker=Nanny) def test_logs(c, s, a, b): yield wait(c.map(inc, range(5))) logs = yield c.get_scheduler_logs(n=5) @@ -5121,11 +5121,23 @@ def test_logs(c, s, a, b): assert "distributed.scheduler" in msg w_logs = yield c.get_worker_logs(n=5) - assert set(w_logs.keys()) == {a.address, b.address} + assert set(w_logs.keys()) == {a.worker_address, b.worker_address} for log in w_logs.values(): for _, msg in log: assert "distributed.worker" in msg + n_logs = yield c.get_worker_logs(nanny=True) + assert set(n_logs.keys()) == {a.worker_address, b.worker_address} + for log in n_logs.values(): + for _, msg in log: + assert "distributed.nanny" in msg + + n_logs = yield c.get_worker_logs(nanny=True, workers=[a.worker_address]) + assert set(n_logs.keys()) == {a.worker_address} + for log in n_logs.values(): + for _, msg in log: + assert "distributed.nanny" in msg + @gen_cluster(client=True) def test_avoid_delayed_finalize(c, s, a, b): diff --git a/distributed/worker.py b/distributed/worker.py index d0bc735ec67..37dcbc2eca1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -33,7 +33,7 @@ from .comm import get_address_host, get_local_address_for, connect from .comm.utils import offload from .comm.addressing import address_from_user_args -from .compatibility import unicode, get_thread_identity, finalize, MutableMapping +from .compatibility import unicode, get_thread_identity, MutableMapping from .core import error_message, CommClosedError, send_recv, pingpong, coerce_to_address from .diskutils import WorkSpace from .metrics import time @@ -60,7 +60,6 @@ json_load_robust, key_split, format_bytes, - DequeHandler, PeriodicCallback, parse_bytes, parse_timedelta, @@ -412,7 +411,7 @@ def __init__( ) profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms") - self._setup_logging() + self._setup_logging(logger) if scheduler_file: cfg = json_load_robust(scheduler_file) @@ -666,16 +665,6 @@ def __repr__(self): ) ) - def _setup_logging(self): - self._deque_handler = DequeHandler( - n=dask.config.get("distributed.admin.log-length") - ) - self._deque_handler.setFormatter( - logging.Formatter(dask.config.get("distributed.admin.log-format")) - ) - logger.addHandler(self._deque_handler) - finalize(self, logger.removeHandler, self._deque_handler) - @property def worker_address(self): """ For API compatibility with Nanny """ @@ -888,15 +877,6 @@ def gather(self, comm=None, who_has=None): self.update_data(data=result, report=False) raise Return({"status": "OK"}) - def get_logs(self, comm=None, n=None): - deque_handler = self._deque_handler - if n is None: - L = list(deque_handler.deque) - else: - L = deque_handler.deque - L = [L[-i] for i in range(min(n, len(L)))] - return 
[(msg.levelname, deque_handler.format(msg)) for msg in L] - ############# # Lifecycle # ############# diff --git a/docs/source/api.rst b/docs/source/api.rst index e91c4ee6ac1..574a70d34b6 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -19,8 +19,8 @@ API Client.get_executor Client.get_metadata Client.get_scheduler_logs - Client.get_task_stream Client.get_worker_logs + Client.get_task_stream Client.has_what Client.list_datasets Client.map From 5042f579b9b77576da319995cf36d0798875b621 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 7 Jun 2019 12:11:49 -0700 Subject: [PATCH 0320/1550] Add stress test for UCX (#2759) This test generated https://github.com/rapidsai/ucx-py/pull/120 --- distributed/comm/tests/test_ucx.py | 31 ++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 55a2f4ec82c..8a0e8927cf6 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -294,3 +294,34 @@ def test_tcp_localcluster(loop): # assert any(w.data == {x.key: 2} for w in c.workers) # assert e.loop is c.loop # print(c.scheduler.workers) + + +@pytest.mark.slow +@pytest.mark.asyncio +async def test_stress(): + from distributed.utils import get_ip_interface + + try: # this check should be removed once UCX + TCP works + get_ip_interface("ib0") + except Exception: + pytest.skip("ib0 interface not found") + + import dask.array as da + from distributed import wait + + chunksize = "10 MB" + + async with LocalCluster( + protocol="ucx", interface="ib0", asynchronous=True + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + rs = da.random.RandomState() + x = rs.random((10000, 10000), chunks=(-1, chunksize)) + x = x.persist() + await wait(x) + + for i in range(10): + x = x.rechunk((chunksize, -1)) + x = x.rechunk((-1, chunksize)) + x = x.persist() + await wait(x) From 756bdd8eb891ee09af6340b7fef4bd883d9fcefb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 7 Jun 2019 13:04:21 -0700 Subject: [PATCH 0321/1550] Remove module state in Prometheus Handlers (#2760) This also fixes an ImportError in prometheus-client=0.7 --- distributed/dashboard/scheduler_html.py | 24 ++++++++++-------------- distributed/dashboard/worker_html.py | 18 +++++++----------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 9f2bcd3cbb2..8b1da2035ea 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -175,17 +175,18 @@ def get(self): class _PrometheusCollector(object): - def __init__(self, server, prometheus_client): + def __init__(self, server): self.server = server - self.prometheus_client = prometheus_client def collect(self): - yield self.prometheus_client.core.GaugeMetricFamily( + from prometheus_client.core import GaugeMetricFamily + + yield GaugeMetricFamily( "dask_scheduler_workers", "Number of workers.", value=len(self.server.workers), ) - yield self.prometheus_client.core.GaugeMetricFamily( + yield GaugeMetricFamily( "dask_scheduler_clients", "Number of clients.", value=len(self.server.clients), @@ -196,26 +197,21 @@ class PrometheusHandler(RequestHandler): _initialized = False def __init__(self, *args, **kwargs): - import prometheus_client # keep out of global namespace - - self.prometheus_client = prometheus_client + import prometheus_client super(PrometheusHandler, self).__init__(*args, **kwargs) - 
self._init() - - def _init(self): if PrometheusHandler._initialized: return - self.prometheus_client.REGISTRY.register( - _PrometheusCollector(self.server, self.prometheus_client) - ) + prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) PrometheusHandler._initialized = True def get(self): - self.write(self.prometheus_client.generate_latest()) + import prometheus_client + + self.write(prometheus_client.generate_latest()) self.set_header("Content-Type", "text/plain; version=0.0.4") diff --git a/distributed/dashboard/worker_html.py b/distributed/dashboard/worker_html.py index 450cce56c8e..e1ae50f3afc 100644 --- a/distributed/dashboard/worker_html.py +++ b/distributed/dashboard/worker_html.py @@ -4,7 +4,6 @@ class _PrometheusCollector(object): def __init__(self, server, prometheus_client): self.server = server - self.prometheus_client = prometheus_client def collect(self): # add your metrics here: @@ -14,7 +13,7 @@ def collect(self): yield None # # 2. yield your metrics - # yield self.prometheus_client.core.GaugeMetricFamily( + # yield prometheus_client.core.GaugeMetricFamily( # 'dask_worker_connections', # 'Number of connections currently open.', # value=???, @@ -25,26 +24,23 @@ class PrometheusHandler(RequestHandler): _initialized = False def __init__(self, *args, **kwargs): - import prometheus_client # keep out of global namespace - - self.prometheus_client = prometheus_client + import prometheus_client super(PrometheusHandler, self).__init__(*args, **kwargs) - self._init() - - def _init(self): if PrometheusHandler._initialized: return - self.prometheus_client.REGISTRY.register( - _PrometheusCollector(self.server, self.prometheus_client) + prometheus_client.REGISTRY.register( + _PrometheusCollector(self.server, prometheus_client) ) PrometheusHandler._initialized = True def get(self): - self.write(self.prometheus_client.generate_latest()) + import prometheus_client + + self.write(prometheus_client.generate_latest()) self.set_header("Content-Type", "text/plain; version=0.0.4") From 309e435cbb383e437bb8af3c571b52fb163a0ac9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 7 Jun 2019 13:25:56 -0700 Subject: [PATCH 0322/1550] Change address -> worker in ColumnDataSource for nbytes plot (#2755) Fixes #2754 --- distributed/dashboard/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 86f56e9eda0..2cb916d0b5d 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -449,7 +449,7 @@ def update(self): "nbytes_text": nbytes_text, "dashboard_host": dashboard_host, "dashboard_port": dashboard_port, - "address": [ws.address for ws in workers], + "worker": [ws.address for ws in workers], "y": y, } From d378b41a89e33e660257522dd4b86d44e6d15fc5 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Sat, 8 Jun 2019 13:29:29 -0500 Subject: [PATCH 0323/1550] Delay lookup of allowed failures. 
(#2761) This allows for setting the config after importing distributed xref https://github.com/dask/dask-examples/pull/75#discussion_r291141404 --- distributed/scheduler.py | 6 +++--- distributed/tests/test_scheduler.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2705971e155..ae449bcfafe 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -73,8 +73,6 @@ logger = logging.getLogger(__name__) -ALLOWED_FAILURES = dask.config.get("distributed.scheduler.allowed-failures") - LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") DEFAULT_DATA_SIZE = dask.config.get("distributed.scheduler.default-data-size") @@ -829,7 +827,7 @@ def __init__( synchronize_worker_interval="60s", services=None, service_kwargs=None, - allowed_failures=ALLOWED_FAILURES, + allowed_failures=None, extensions=None, validate=False, scheduler_file=None, @@ -846,6 +844,8 @@ def __init__( self._setup_logging(logger) # Attributes + if allowed_failures is None: + allowed_failures = dask.config.get("distributed.scheduler.allowed-failures") self.allowed_failures = allowed_failures self.validate = validate self.status = None diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 1c321a02906..66a8088ace5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1607,3 +1607,17 @@ async def test_async_context_manager(): assert w.status == "running" assert s.workers assert not s.workers + + +@pytest.mark.asyncio +async def test_allowed_failures_config(): + async with Scheduler(port=0, allowed_failures=10) as s: + assert s.allowed_failures == 10 + + with dask.config.set({"distributed.scheduler.allowed_failures": 100}): + async with Scheduler(port=0) as s: + assert s.allowed_failures == 100 + + with dask.config.set({"distributed.scheduler.allowed_failures": 0}): + async with Scheduler(port=0) as s: + assert s.allowed_failures == 0 From 2ba70b310dbefc5764ee43079ac1bb783a8cec08 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 10 Jun 2019 16:31:33 -0500 Subject: [PATCH 0324/1550] Add unknown pytest markers (#2764) --- setup.cfg | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index 434b1fd258c..5533437121b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,6 +41,11 @@ universal=1 [tool:pytest] addopts = -rsx -v --durations=10 minversion = 3.2 +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + avoid_travis: marks tests as flaky on TravisCI. 
+ ipython: mark a test as exercising IPython + # filterwarnings = # error # ignore::UserWarning From a511f0ea480d7305d86c6439213c3cf3a6d95dc4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 18 Jun 2019 14:55:26 +0200 Subject: [PATCH 0325/1550] Replace ncores with nthreads throughout codebase (#2758) --- distributed/cli/dask_mpi.py | 2 +- distributed/cli/dask_worker.py | 6 +- distributed/cli/tests/test_dask_scheduler.py | 6 +- distributed/cli/tests/test_dask_worker.py | 6 +- distributed/cli/tests/test_tls_cli.py | 4 +- distributed/client.py | 30 ++--- distributed/dashboard/components.py | 12 +- distributed/dashboard/scheduler.py | 12 +- distributed/dashboard/scheduler_html.py | 6 +- .../dashboard/templates/worker-table.html | 2 +- .../dashboard/tests/test_scheduler_bokeh.py | 4 +- distributed/dashboard/worker.py | 4 +- distributed/deploy/local.py | 23 ++-- distributed/deploy/spec.py | 8 +- distributed/deploy/tests/test_adaptive.py | 16 +-- distributed/deploy/tests/test_local.py | 35 +++--- distributed/deploy/tests/test_spec_cluster.py | 30 ++--- distributed/deploy/utils_test.py | 4 +- .../diagnostics/tests/test_eventstream.py | 2 +- distributed/diagnostics/tests/test_plugin.py | 2 +- .../diagnostics/tests/test_progressbar.py | 2 +- .../diagnostics/tests/test_task_stream.py | 2 +- distributed/nanny.py | 16 ++- distributed/scheduler.py | 60 +++++----- distributed/stealing.py | 4 +- distributed/tests/test_actor.py | 10 +- distributed/tests/test_client.py | 108 +++++++++--------- distributed/tests/test_collections.py | 2 +- distributed/tests/test_failed_workers.py | 34 +++--- distributed/tests/test_ipython.py | 8 +- distributed/tests/test_locks.py | 2 +- distributed/tests/test_nanny.py | 54 ++++----- distributed/tests/test_priorities.py | 4 +- distributed/tests/test_pubsub.py | 2 +- distributed/tests/test_queues.py | 4 +- distributed/tests/test_resources.py | 36 +++--- distributed/tests/test_scheduler.py | 96 ++++++++-------- distributed/tests/test_steal.py | 54 ++++----- distributed/tests/test_stress.py | 16 +-- distributed/tests/test_tls_functional.py | 6 +- distributed/tests/test_utils_test.py | 8 +- distributed/tests/test_variable.py | 2 +- distributed/tests/test_worker.py | 79 ++++++------- distributed/tests/test_worker_client.py | 10 +- distributed/tests/test_worker_plugins.py | 2 +- distributed/utils_comm.py | 8 +- distributed/utils_test.py | 20 ++-- distributed/worker.py | 40 ++++--- docs/source/api.rst | 2 +- docs/source/local-cluster.rst | 2 +- docs/source/protocol.rst | 2 +- docs/source/scheduling-state.rst | 2 +- docs/source/worker.rst | 4 +- 53 files changed, 466 insertions(+), 449 deletions(-) diff --git a/distributed/cli/dask_mpi.py b/distributed/cli/dask_mpi.py index c7669073f79..2a965824662 100644 --- a/distributed/cli/dask_mpi.py +++ b/distributed/cli/dask_mpi.py @@ -105,7 +105,7 @@ def main( scheduler_file=scheduler_file, loop=loop, name=rank if scheduler else None, - ncores=nthreads, + nthreads=nthreads, local_dir=local_directory, services={("dashboard", bokeh_worker_port): BokehWorker}, memory_limit=memory_limit, diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 2cf570cfc1d..a53ddf99f6e 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -2,6 +2,7 @@ import atexit import logging +import multiprocessing import gc import os from sys import exit @@ -11,7 +12,6 @@ import dask from distributed import Nanny, Worker from distributed.utils import parse_timedelta -from distributed.worker import _ncores from 
distributed.security import Security from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.comm import get_address_host_port @@ -280,7 +280,7 @@ def main( port = worker_port if not nthreads: - nthreads = _ncores // nprocs + nthreads = multiprocessing.cpu_count() // nprocs if pid_file: with open(pid_file, "w") as f: @@ -329,7 +329,7 @@ def del_pid_file(): t( scheduler, scheduler_file=scheduler_file, - ncores=nthreads, + nthreads=nthreads, services=services, loop=loop, resources=resources, diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 754082f35eb..e04fa24bad1 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -53,7 +53,7 @@ def f(): ] with Client("127.0.0.1:8978", loop=loop) as c: - assert len(c.ncores()) == 0 + assert len(c.nthreads()) == 0 c.sync(f) @@ -150,7 +150,7 @@ def test_multiple_workers(loop): with popen(["dask-worker", "localhost:8786", "--no-dashboard"]) as b: with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: start = time() - while len(c.ncores()) < 2: + while len(c.nthreads()) < 2: sleep(0.1) assert time() < start + 10 @@ -178,7 +178,7 @@ def test_interface(loop): ) as a: with Client("tcp://127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: start = time() - while not len(c.ncores()): + while not len(c.nthreads()): sleep(0.1) assert time() - start < 5 info = c.scheduler_info() diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index fa62594a753..dc7c761fdf1 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -58,7 +58,7 @@ def test_memory_limit(loop): ] ) as worker: with Client("127.0.0.1:8786", loop=loop) as c: - while not c.ncores(): + while not c.nthreads(): sleep(0.1) info = c.scheduler_info() [d] = info["workers"].values() @@ -218,7 +218,7 @@ def test_contact_listen_address(loop, nanny, listen_address): ] ) as worker: with Client("127.0.0.1:8786") as client: - while not client.ncores(): + while not client.nthreads(): sleep(0.1) info = client.scheduler_info() assert "tcp://127.0.0.2:39837" in info["workers"] @@ -243,7 +243,7 @@ def test_respect_host_listen_address(loop, nanny, host): ["dask-worker", "127.0.0.1:8786", nanny, "--no-dashboard", "--host", host] ) as worker: with Client("127.0.0.1:8786") as client: - while not client.ncores(): + while not client.nthreads(): sleep(0.1) info = client.scheduler_info() diff --git a/distributed/cli/tests/test_tls_cli.py b/distributed/cli/tests/test_tls_cli.py index 4663a9b38ff..37fdc9bb00f 100644 --- a/distributed/cli/tests/test_tls_cli.py +++ b/distributed/cli/tests/test_tls_cli.py @@ -25,9 +25,9 @@ tls_args_2 = ["--tls-ca-file", ca_file, "--tls-cert", cert, "--tls-key", key] -def wait_for_cores(c, ncores=1): +def wait_for_cores(c, nthreads=1): start = time() - while len(c.ncores()) < 1: + while len(c.nthreads()) < 1: sleep(0.1) assert time() < start + 10 diff --git a/distributed/client.py b/distributed/client.py index ac098d7987e..7ad897bf616 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -770,12 +770,12 @@ def __repr__(self): if addr: workers = info.get("workers", {}) nworkers = len(workers) - ncores = sum(w["ncores"] for w in workers.values()) + nthreads = sum(w["nthreads"] for w in workers.values()) return "<%s: scheduler=%r processes=%d cores=%d>" % ( self.__class__.__name__, addr, nworkers, - ncores, + nthreads, ) elif 
self.scheduler is not None: return "<%s: scheduler=%r>" % ( @@ -830,7 +830,7 @@ def _repr_html_(self): if info: workers = len(info["workers"]) - cores = sum(w["ncores"] for w in info["workers"].values()) + cores = sum(w["nthreads"] for w in info["workers"].values()) memory = sum(w["memory_limit"] for w in info["workers"].values()) memory = format_bytes(memory) text2 = ( @@ -1868,19 +1868,19 @@ def _scatter( else: data2 = valmap(to_serialize, data) if direct: - ncores = None + nthreads = None start = time() - while not ncores: - if ncores is not None: + while not nthreads: + if nthreads is not None: yield gen.sleep(0.1) if time() > start + timeout: raise gen.TimeoutError("No valid workers found") - ncores = yield self.scheduler.ncores(workers=workers) - if not ncores: + nthreads = yield self.scheduler.ncores(workers=workers) + if not nthreads: raise ValueError("No valid workers") _, who_has, nbytes = yield scatter_to_workers( - ncores, data2, report=False, rpc=self.rpc + nthreads, data2, report=False, rpc=self.rpc ) yield self.scheduler.update_data( @@ -3013,7 +3013,7 @@ def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs) **kwargs ) - def ncores(self, workers=None, **kwargs): + def nthreads(self, workers=None, **kwargs): """ The number of threads/cores available on each worker node Parameters @@ -3024,7 +3024,7 @@ def ncores(self, workers=None, **kwargs): Examples -------- - >>> c.ncores() # doctest: +SKIP + >>> c.threads() # doctest: +SKIP {'192.168.1.141:46784': 8, '192.167.1.142:47548': 8, '192.167.1.143:47329': 8, @@ -3043,6 +3043,8 @@ def ncores(self, workers=None, **kwargs): workers = [workers] return self.sync(self.scheduler.ncores, workers=workers, **kwargs) + ncores = nthreads + def who_has(self, futures=None, **kwargs): """ The workers storing each future's data @@ -3067,7 +3069,7 @@ def who_has(self, futures=None, **kwargs): See Also -------- Client.has_what - Client.ncores + Client.nthreads """ if futures is not None: futures = self.futures_of(futures) @@ -3099,7 +3101,7 @@ def has_what(self, workers=None, **kwargs): See Also -------- Client.who_has - Client.ncores + Client.nthreads Client.processing """ if isinstance(workers, tuple) and all( @@ -3130,7 +3132,7 @@ def processing(self, workers=None): -------- Client.who_has Client.has_what - Client.ncores + Client.nthreads """ if isinstance(workers, tuple) and all( isinstance(i, (str, tuple)) for i in workers diff --git a/distributed/dashboard/components.py b/distributed/dashboard/components.py index 16efa1d2eb0..e7234e2e6f7 100644 --- a/distributed/dashboard/components.py +++ b/distributed/dashboard/components.py @@ -276,7 +276,7 @@ class Processing(DashboardComponent): """ def __init__(self, **kwargs): - data = self.processing_update({"processing": {}, "ncores": {}}) + data = self.processing_update({"processing": {}, "nthreads": {}}) self.source = ColumnDataSource(data) x_range = Range1d(-1, 1) @@ -321,12 +321,12 @@ def __init__(self, **kwargs): def update(self, messages): with log_errors(): msg = messages["processing"] - if not msg.get("ncores"): + if not msg.get("nthreads"): return data = self.processing_update(msg) x_range = self.root.x_range max_right = max(data["right"]) - cores = max(data["ncores"]) + cores = max(data["nthreads"]) if x_range.end < max_right: x_range.end = max_right + 2 elif x_range.end > 2 * max_right + cores: # way out there, walk back @@ -341,8 +341,8 @@ def processing_update(msg): names = sorted(names) processing = msg["processing"] processing = [processing[name] for name 
in names] - ncores = msg["ncores"] - ncores = [ncores[name] for name in names] + nthreads = msg["nthreads"] + nthreads = [nthreads[name] for name in names] n = len(names) d = { "name": list(names), @@ -350,7 +350,7 @@ def processing_update(msg): "right": list(processing), "top": list(range(n, 0, -1)), "bottom": list(range(n - 1, -1, -1)), - "ncores": ncores, + "nthreads": nthreads, } d["alpha"] = [0.7] * n diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 2cb916d0b5d..f6f1fef7590 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -189,7 +189,7 @@ def update(self): if total: self.root.title.text = "Occupancy -- total time: %s wall time: %s" % ( format_time(total), - format_time(total / self.scheduler.total_ncores), + format_time(total / self.scheduler.total_nthreads), ) else: self.root.title.text = "Occupancy" @@ -1179,7 +1179,7 @@ def __init__(self, scheduler, width=800, **kwargs): self.names = [ "name", "address", - "ncores", + "nthreads", "cpu", "memory", "memory_limit", @@ -1198,7 +1198,7 @@ def __init__(self, scheduler, width=800, **kwargs): table_names = [ "name", "address", - "ncores", + "nthreads", "cpu", "memory", "memory_limit", @@ -1223,7 +1223,7 @@ def __init__(self, scheduler, width=800, **kwargs): "read_bytes": NumberFormatter(format="0 b"), "write_bytes": NumberFormatter(format="0 b"), "num_fds": NumberFormatter(format="0"), - "ncores": NumberFormatter(format="0"), + "nthreads": NumberFormatter(format="0"), } if BOKEH_VERSION < "0.12.15": @@ -1345,8 +1345,8 @@ def update(self): data["memory_percent"][-1] = "" data["memory_limit"][-1] = ws.memory_limit data["cpu"][-1] = ws.metrics["cpu"] / 100.0 - data["cpu_fraction"][-1] = ws.metrics["cpu"] / 100.0 / ws.ncores - data["ncores"][-1] = ws.ncores + data["cpu_fraction"][-1] = ws.metrics["cpu"] / 100.0 / ws.nthreads + data["nthreads"][-1] = ws.nthreads self.source.data.update(data) diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 8b1da2035ea..65a89b33fbb 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -107,7 +107,7 @@ def get(self): scheduler = self.server erred = 0 nbytes = 0 - ncores = 0 + nthreads = 0 memory = 0 processing = 0 released = 0 @@ -124,7 +124,7 @@ def get(self): if ts.waiters: waiting_data += 1 for ws in scheduler.workers.values(): - ncores += ws.ncores + nthreads += ws.nthreads memory += len(ws.has_what) nbytes += ws.nbytes processing += len(ws.processing) @@ -132,7 +132,7 @@ def get(self): response = { "bytes": nbytes, "clients": len(scheduler.clients), - "cores": ncores, + "cores": nthreads, "erred": erred, "hosts": len(scheduler.host_info), "idle": len(scheduler.idle), diff --git a/distributed/dashboard/templates/worker-table.html b/distributed/dashboard/templates/worker-table.html index 4835849daad..a3566f90c3f 100644 --- a/distributed/dashboard/templates/worker-table.html +++ b/distributed/dashboard/templates/worker-table.html @@ -15,7 +15,7 @@
        - + diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 692a29439c0..d9a83caf00b 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -321,8 +321,8 @@ def test_WorkerTable(c, s, a, b): assert all(wt.source.data.values()) assert all(len(v) == 2 for v in wt.source.data.values()) - ncores = wt.source.data["ncores"] - assert all(ncores) + nthreads = wt.source.data["nthreads"] + assert all(nthreads) @gen_cluster(client=True) diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index aa85afc4197..c6633a170aa 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -76,7 +76,7 @@ def update(self): w = self.worker d = { "Stored": [len(w.data)], - "Executing": ["%d / %d" % (len(w.executing), w.ncores)], + "Executing": ["%d / %d" % (len(w.executing), w.nthreads)], "Ready": [len(w.ready)], "Waiting": [len(w.waiting_for_data)], "Connections": [len(w.in_flight_workers)], @@ -251,7 +251,7 @@ def __init__(self, worker, **kwargs): fig = figure( title="Executing History", x_axis_type="datetime", - y_range=[-0.1, worker.ncores + 0.1], + y_range=[-0.1, worker.nthreads + 0.1], height=150, tools="", x_range=x_range, diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index a56cce8c2b2..ffb06b0a4bf 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -3,6 +3,7 @@ import atexit import logging import math +import multiprocessing import warnings import weakref @@ -11,7 +12,7 @@ from .spec import SpecCluster from ..nanny import Nanny from ..scheduler import Scheduler -from ..worker import Worker, parse_memory_limit, _ncores +from ..worker import Worker, parse_memory_limit logger = logging.getLogger(__name__) @@ -70,7 +71,7 @@ class LocalCluster(SpecCluster): -------- >>> cluster = LocalCluster() # Create a local cluster with as many workers as cores # doctest: +SKIP >>> cluster # doctest: +SKIP - LocalCluster("127.0.0.1:8786", workers=8, ncores=8) + LocalCluster("127.0.0.1:8786", workers=8, threads=8) >>> c = Client(cluster) # connect to local cluster # doctest: +SKIP @@ -141,21 +142,23 @@ def __init__( worker_services = worker_services or {} if n_workers is None and threads_per_worker is None: if processes: - n_workers, threads_per_worker = nprocesses_nthreads(_ncores) + n_workers, threads_per_worker = nprocesses_nthreads() else: n_workers = 1 - threads_per_worker = _ncores + threads_per_worker = multiprocessing.cpu_count() if n_workers is None and threads_per_worker is not None: - n_workers = max(1, _ncores // threads_per_worker) + n_workers = max(1, multiprocessing.cpu_count() // threads_per_worker) if n_workers and threads_per_worker is None: # Overcommit threads per worker, rather than undercommit - threads_per_worker = max(1, int(math.ceil(_ncores / n_workers))) + threads_per_worker = max( + 1, int(math.ceil(multiprocessing.cpu_count() / n_workers)) + ) if n_workers and "memory_limit" not in worker_kwargs: worker_kwargs["memory_limit"] = parse_memory_limit("auto", 1, n_workers) worker_kwargs.update( { - "ncores": threads_per_worker, + "nthreads": threads_per_worker, "services": worker_services, "dashboard_address": worker_dashboard_address, "interface": interface, @@ -197,15 +200,15 @@ def __init__( ) def __repr__(self): - return "%s(%r, workers=%d, ncores=%d)" % ( + return "%s(%r, workers=%d, nthreads=%d)" % ( type(self).__name__, self.scheduler_address, 
len(self.workers), - sum(w.ncores for w in self.workers.values()), + sum(w.nthreads for w in self.workers.values()), ) -def nprocesses_nthreads(n): +def nprocesses_nthreads(n=multiprocessing.cpu_count()): """ The default breakdown of processes and threads for a given number of cores diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 85728a057e4..bb46f81db88 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -44,8 +44,8 @@ class does handle all of the logic around asynchronously cleanly setting up >>> from dask.distributed import Scheduler, Worker, Nanny >>> scheduler = {'cls': Scheduler, 'options': {"dashboard_address": ':8787'}} >>> workers = { - ... 'my-worker': {"cls": Worker, "options": {"ncores": 1}}, - ... 'my-nanny': {"cls": Nanny, "options": {"ncores": 2}}, + ... 'my-worker': {"cls": Worker, "options": {"nthreads": 1}}, + ... 'my-nanny': {"cls": Nanny, "options": {"nthreads": 2}}, ... } >>> cluster = SpecCluster(scheduler=scheduler, workers=workers) @@ -53,8 +53,8 @@ class does handle all of the logic around asynchronously cleanly setting up >>> cluster.worker_spec { - 'my-worker': {"cls": Worker, "options": {"ncores": 1}}, - 'my-nanny': {"cls": Nanny, "options": {"ncores": 2}}, + 'my-worker': {"cls": Worker, "options": {"nthreads": 1}}, + 'my-nanny': {"cls": Nanny, "options": {"nthreads": 2}}, } While the instantiation of this spec is stored in the ``.workers`` diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index cc860636e55..146d7b95dbb 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -23,11 +23,11 @@ def test_get_scale_up_kwargs(loop): with Client(cluster, loop=loop) as c: future = c.submit(lambda x: x + 1, 1) assert future.result() == 2 - assert c.ncores() + assert c.nthreads() assert alc.get_scale_up_kwargs() == {"n": 3} -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_simultaneous_scale_up_and_down(c, s, *workers): class TestAdaptive(Adaptive): def get_scale_up_kwargs(self): @@ -65,22 +65,22 @@ def test_adaptive_local_cluster(loop): ) as cluster: alc = Adaptive(cluster.scheduler, cluster, interval=100) with Client(cluster, loop=loop) as c: - assert not c.ncores() + assert not c.nthreads() future = c.submit(lambda x: x + 1, 1) assert future.result() == 2 - assert c.ncores() + assert c.nthreads() sleep(0.1) - assert c.ncores() # still there after some time + assert c.nthreads() # still there after some time del future start = time() - while cluster.scheduler.ncores: + while cluster.scheduler.nthreads: sleep(0.01) assert time() < start + 5 - assert not c.ncores() + assert not c.nthreads() @nodebug @@ -128,7 +128,7 @@ def test_adaptive_local_cluster_multi_workers(): yield cluster.close() -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10, active_rpc_timeout=10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, active_rpc_timeout=10) def test_adaptive_scale_down_override(c, s, *workers): class TestAdaptive(Adaptive): def __init__(self, *args, **kwargs): diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 8aad6675f8c..520996f64a0 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -2,6 +2,7 @@ from functools import partial import gc +import multiprocessing import subprocess import sys from time import sleep @@ -90,7 +91,7 @@ def 
test_procs(): assert len(c.workers) == 2 assert all(isinstance(w, Worker) for w in c.workers.values()) with Client(c.scheduler.address) as e: - assert all(w.ncores == 3 for w in c.workers.values()) + assert all(w.nthreads == 3 for w in c.workers.values()) assert all(isinstance(w, Worker) for w in c.workers.values()) repr(c) @@ -105,7 +106,7 @@ def test_procs(): assert len(c.workers) == 2 assert all(isinstance(w, Nanny) for w in c.workers.values()) with Client(c.scheduler.address) as e: - assert all(v == 3 for v in e.ncores().values()) + assert all(v == 3 for v in e.nthreads().values()) c.scale(3) assert all(isinstance(w, Nanny) for w in c.workers.values()) @@ -181,7 +182,7 @@ def test_Client_with_local(loop): 1, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop ) as c: with Client(c) as e: - assert len(e.ncores()) == len(c.workers) + assert len(e.nthreads()) == len(c.workers) assert c.scheduler_address in repr(c) @@ -227,33 +228,33 @@ def test_Client_twice(loop): @pytest.mark.skipif("sys.version_info[0] == 2", reason="fork issues") def test_defaults(): - from distributed.worker import _ncores + _nthreads = multiprocessing.cpu_count() with LocalCluster( scheduler_port=0, silence_logs=False, dashboard_address=None ) as c: - assert sum(w.ncores for w in c.workers.values()) == _ncores + assert sum(w.nthreads for w in c.workers.values()) == _nthreads assert all(isinstance(w, Nanny) for w in c.workers.values()) with LocalCluster( processes=False, scheduler_port=0, silence_logs=False, dashboard_address=None ) as c: - assert sum(w.ncores for w in c.workers.values()) == _ncores + assert sum(w.nthreads for w in c.workers.values()) == _nthreads assert all(isinstance(w, Worker) for w in c.workers.values()) assert len(c.workers) == 1 with LocalCluster( n_workers=2, scheduler_port=0, silence_logs=False, dashboard_address=None ) as c: - if _ncores % 2 == 0: - expected_total_threads = max(2, _ncores) + if _nthreads % 2 == 0: + expected_total_threads = max(2, _nthreads) else: - # n_workers not a divisor of _ncores => threads are overcommitted - expected_total_threads = max(2, _ncores + 1) - assert sum(w.ncores for w in c.workers.values()) == expected_total_threads + # n_workers not a divisor of _nthreads => threads are overcommitted + expected_total_threads = max(2, _nthreads + 1) + assert sum(w.nthreads for w in c.workers.values()) == expected_total_threads with LocalCluster( - threads_per_worker=_ncores * 2, + threads_per_worker=_nthreads * 2, scheduler_port=0, silence_logs=False, dashboard_address=None, @@ -261,12 +262,12 @@ def test_defaults(): assert len(c.workers) == 1 with LocalCluster( - n_workers=_ncores * 2, + n_workers=_nthreads * 2, scheduler_port=0, silence_logs=False, dashboard_address=None, ) as c: - assert all(w.ncores == 1 for w in c.workers.values()) + assert all(w.nthreads == 1 for w in c.workers.values()) with LocalCluster( threads_per_worker=2, n_workers=3, @@ -275,7 +276,7 @@ def test_defaults(): dashboard_address=None, ) as c: assert len(c.workers) == 3 - assert all(w.ncores == 2 for w in c.workers.values()) + assert all(w.nthreads == 2 for w in c.workers.values()) def test_worker_params(): @@ -361,7 +362,7 @@ def test_bokeh(loop, processes): @pytest.mark.skipif(sys.version_info < (3, 6), reason="Unknown") def test_blocks_until_full(loop): with Client(loop=loop) as c: - assert len(c.ncores()) > 0 + assert len(c.nthreads()) > 0 @gen_test() @@ -383,7 +384,7 @@ def test_scale_up_and_down(): cluster.scale(2) yield cluster assert len(cluster.workers) == 2 - assert 
len(cluster.scheduler.ncores) == 2 + assert len(cluster.scheduler.nthreads) == 2 cluster.scale(1) yield cluster diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index eb733f2e68f..0c062d3d3e0 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -17,9 +17,9 @@ async def _(): worker_spec = { - 0: {"cls": Worker, "options": {"ncores": 1}}, - 1: {"cls": Worker, "options": {"ncores": 2}}, - "my-worker": {"cls": MyWorker, "options": {"ncores": 3}}, + 0: {"cls": Worker, "options": {"nthreads": 1}}, + 1: {"cls": Worker, "options": {"nthreads": 2}}, + "my-worker": {"cls": MyWorker, "options": {"nthreads": 3}}, } scheduler = {"cls": Scheduler, "options": {"port": 0}} @@ -37,9 +37,9 @@ async def test_specification(): assert isinstance(cluster.workers[1], Worker) assert isinstance(cluster.workers["my-worker"], MyWorker) - assert cluster.workers[0].ncores == 1 - assert cluster.workers[1].ncores == 2 - assert cluster.workers["my-worker"].ncores == 3 + assert cluster.workers[0].nthreads == 1 + assert cluster.workers[1].nthreads == 2 + assert cluster.workers["my-worker"].nthreads == 3 async with Client(cluster, asynchronous=True) as client: result = await client.submit(lambda x: x + 1, 10) @@ -51,9 +51,9 @@ async def test_specification(): def test_spec_sync(loop): worker_spec = { - 0: {"cls": Worker, "options": {"ncores": 1}}, - 1: {"cls": Worker, "options": {"ncores": 2}}, - "my-worker": {"cls": MyWorker, "options": {"ncores": 3}}, + 0: {"cls": Worker, "options": {"nthreads": 1}}, + 1: {"cls": Worker, "options": {"nthreads": 2}}, + "my-worker": {"cls": MyWorker, "options": {"nthreads": 3}}, } with SpecCluster(workers=worker_spec, scheduler=scheduler, loop=loop) as cluster: assert cluster.worker_spec is worker_spec @@ -64,9 +64,9 @@ def test_spec_sync(loop): assert isinstance(cluster.workers[1], Worker) assert isinstance(cluster.workers["my-worker"], MyWorker) - assert cluster.workers[0].ncores == 1 - assert cluster.workers[1].ncores == 2 - assert cluster.workers["my-worker"].ncores == 3 + assert cluster.workers[0].nthreads == 1 + assert cluster.workers[1].nthreads == 2 + assert cluster.workers["my-worker"].nthreads == 3 with Client(cluster, loop=loop) as client: assert cluster.loop is cluster.scheduler.loop @@ -83,7 +83,7 @@ def test_loop_started(): @pytest.mark.asyncio async def test_scale(): - worker = {"cls": Worker, "options": {"ncores": 1}} + worker = {"cls": Worker, "options": {"nthreads": 1}} async with SpecCluster( asynchronous=True, scheduler=scheduler, worker=worker ) as cluster: @@ -134,9 +134,9 @@ async def test_new_worker_spec(): class MyCluster(SpecCluster): def new_worker_spec(self): i = len(self.worker_spec) - return i, {"cls": Worker, "options": {"ncores": i + 1}} + return i, {"cls": Worker, "options": {"nthreads": i + 1}} async with MyCluster(asynchronous=True, scheduler=scheduler) as cluster: cluster.scale(3) for i in range(3): - assert cluster.worker_spec[i]["options"]["ncores"] == i + 1 + assert cluster.worker_spec[i]["options"]["nthreads"] == i + 1 diff --git a/distributed/deploy/utils_test.py b/distributed/deploy/utils_test.py index 9da8d64cd50..2bb55c7da08 100644 --- a/distributed/deploy/utils_test.py +++ b/distributed/deploy/utils_test.py @@ -18,7 +18,7 @@ def tearDown(self): @pytest.mark.xfail() def test_cores(self): info = self.client.scheduler_info() - assert len(self.client.ncores()) == 2 + assert len(self.client.nthreads()) == 2 def test_submit(self): future 
= self.client.submit(lambda x: x + 1, 1) @@ -27,7 +27,7 @@ def test_submit(self): def test_context_manager(self): with self.Cluster(**self.kwargs) as c: with Client(c) as e: - assert e.ncores() + assert e.nthreads() def test_no_workers(self): with self.Cluster(0, scheduler_port=0, **self.kwargs): diff --git a/distributed/diagnostics/tests/test_eventstream.py b/distributed/diagnostics/tests/test_eventstream.py index 7ec646d7e91..9139f75eab3 100644 --- a/distributed/diagnostics/tests/test_eventstream.py +++ b/distributed/diagnostics/tests/test_eventstream.py @@ -11,7 +11,7 @@ from distributed.utils_test import div, gen_cluster -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_eventstream(c, s, *workers): pytest.importorskip("bokeh") diff --git a/distributed/diagnostics/tests/test_plugin.py b/distributed/diagnostics/tests/test_plugin.py index fa4449c74b7..1c9ebd7a1a8 100644 --- a/distributed/diagnostics/tests/test_plugin.py +++ b/distributed/diagnostics/tests/test_plugin.py @@ -34,7 +34,7 @@ def transition(self, key, start, finish, *args, **kwargs): assert counter not in s.plugins -@gen_cluster(ncores=[], client=False) +@gen_cluster(nthreads=[], client=False) def test_add_remove_worker(s): events = [] diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index ba42f2ce6ea..3e5f0633d49 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -44,7 +44,7 @@ def test_TextProgressBar_empty(capsys): @gen_test() def f(): s = yield Scheduler(port=0) - a, b = yield [Worker(s.address, ncores=1), Worker(s.address, ncores=1)] + a, b = yield [Worker(s.address, nthreads=1), Worker(s.address, nthreads=1)] progress = TextProgressBar([], scheduler=s.address, start=False, interval=0.01) yield progress.listen() diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index 366de8d79d5..ad23ca5ae8c 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -14,7 +14,7 @@ from distributed.metrics import time -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_TaskStreamPlugin(c, s, *workers): es = TaskStreamPlugin(s) assert not es.buffer diff --git a/distributed/nanny.py b/distributed/nanny.py index 9cf444fc7c4..d907c7171a1 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -3,6 +3,7 @@ from datetime import timedelta import logging from multiprocessing.queues import Empty +import multiprocessing import os import psutil import shutil @@ -32,7 +33,7 @@ PeriodicCallback, parse_timedelta, ) -from .worker import _ncores, run, parse_memory_limit, Worker +from .worker import run, parse_memory_limit, Worker logger = logging.getLogger(__name__) @@ -54,6 +55,7 @@ def __init__( scheduler_port=None, scheduler_file=None, worker_port=0, + nthreads=None, ncores=None, loop=None, local_dir="dask-worker-space", @@ -96,8 +98,12 @@ def __init__( else: self.scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port)) + if ncores is not None: + warnings.warn("the ncores= parameter has moved to nthreads=") + nthreads = ncores + self._given_worker_port = worker_port - self.ncores = ncores or _ncores + self.nthreads = nthreads or multiprocessing.cpu_count() self.reconnect = reconnect self.validate = validate self.resources = 
resources @@ -120,7 +126,7 @@ def __init__( self.quiet = quiet self.auto_restart = True - self.memory_limit = parse_memory_limit(memory_limit, self.ncores) + self.memory_limit = parse_memory_limit(memory_limit, self.nthreads) if silence_logs: silence_logging(level=silence_logs) @@ -160,7 +166,7 @@ def __init__( self.status = "init" def __repr__(self): - return "" % (self.worker_address, self.ncores) + return "" % (self.worker_address, self.nthreads) @gen.coroutine def _unregister(self, timeout=10): @@ -263,7 +269,7 @@ def instantiate(self, comm=None): if self.process is None: worker_kwargs = dict( scheduler_ip=self.scheduler_addr, - ncores=self.ncores, + nthreads=self.nthreads, local_dir=self.local_dir, services=self.services, nanny=self.address, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ae449bcfafe..434d118a422 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -162,9 +162,9 @@ class WorkerState(object): The total memory size, in bytes, used by the tasks this worker holds in memory (i.e. the tasks in this worker's :attr:`has_what`). - .. attribute:: ncores: int + .. attribute:: nthreads: int - The number of CPU cores made available on this worker. + The number of CPU threads made available on this worker. .. attribute:: resources: {str: Number} @@ -218,7 +218,7 @@ class WorkerState(object): "name", "nanny", "nbytes", - "ncores", + "nthreads", "occupancy", "pid", "processing", @@ -234,7 +234,7 @@ def __init__( address=None, pid=0, name=None, - ncores=0, + nthreads=0, memory_limit=0, local_directory=None, services=None, @@ -243,7 +243,7 @@ def __init__( self.address = address self.pid = pid self.name = name - self.ncores = ncores + self.nthreads = nthreads self.memory_limit = memory_limit self.local_directory = local_directory self.services = services or {} @@ -272,7 +272,7 @@ def clean(self): address=self.address, pid=self.pid, name=self.name, - ncores=self.ncores, + nthreads=self.nthreads, memory_limit=self.memory_limit, local_directory=self.local_directory, services=self.services, @@ -299,7 +299,7 @@ def identity(self): "resources": self.resources, "local_directory": self.local_directory, "name": self.name, - "ncores": self.ncores, + "nthreads": self.nthreads, "memory_limit": self.memory_limit, "last_seen": self.last_seen, "services": self.services, @@ -963,7 +963,7 @@ def __init__( # Worker state self.workers = sortedcontainers.SortedDict() for old_attr, new_attr, wrap in [ - ("ncores", "ncores", None), + ("nthreads", "nthreads", None), ("worker_bytes", "nbytes", None), ("worker_resources", "resources", None), ("used_resources", "used_resources", None), @@ -980,7 +980,7 @@ def __init__( self.idle = sortedcontainers.SortedSet(key=operator.attrgetter("address")) self.saturated = set() - self.total_ncores = 0 + self.total_nthreads = 0 self.total_occupancy = 0 self.host_info = defaultdict(dict) self.resources = defaultdict(dict) @@ -1128,7 +1128,7 @@ def __repr__(self): return '' % ( self.address, len(self.workers), - self.total_ncores, + self.total_nthreads, ) def identity(self, comm=None): @@ -1394,7 +1394,7 @@ def add_worker( comm=None, address=None, keys=(), - ncores=None, + nthreads=None, name=None, resolve_address=True, nbytes=None, @@ -1422,7 +1422,7 @@ def add_worker( self.workers[address] = ws = WorkerState( address=address, pid=pid, - ncores=ncores, + nthreads=nthreads, memory_limit=memory_limit, name=name, local_directory=local_directory, @@ -1440,12 +1440,12 @@ def add_worker( return if "addresses" not in self.host_info[host]: - 
self.host_info[host].update({"addresses": set(), "cores": 0}) + self.host_info[host].update({"addresses": set(), "nthreads": 0}) self.host_info[host]["addresses"].add(address) - self.host_info[host]["cores"] += ncores + self.host_info[host]["nthreads"] += nthreads - self.total_ncores += ncores + self.total_nthreads += nthreads self.aliases[name] = address response = self.heartbeat_worker( @@ -1465,7 +1465,7 @@ def add_worker( self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) - if ws.ncores > len(ws.processing): + if ws.nthreads > len(ws.processing): self.idle.add(ws) for plugin in self.plugins[:]: @@ -1906,9 +1906,9 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): self.remove_resources(address) - self.host_info[host]["cores"] -= ws.ncores + self.host_info[host]["nthreads"] -= ws.nthreads self.host_info[host]["addresses"].remove(address) - self.total_ncores -= ws.ncores + self.total_nthreads -= ws.nthreads if not self.host_info[host]["addresses"]: del self.host_info[host] @@ -2489,22 +2489,22 @@ def scatter( raise gen.TimeoutError("No workers found") if workers is None: - ncores = {w: ws.ncores for w, ws in self.workers.items()} + nthreads = {w: ws.nthreads for w, ws in self.workers.items()} else: workers = [self.coerce_address(w) for w in workers] - ncores = {w: self.workers[w].ncores for w in workers} + nthreads = {w: self.workers[w].nthreads for w in workers} assert isinstance(data, dict) keys, who_has, nbytes = yield scatter_to_workers( - ncores, data, rpc=self.rpc, report=False + nthreads, data, rpc=self.rpc, report=False ) self.update_data(who_has=who_has, nbytes=nbytes, client=client) if broadcast: if broadcast == True: # noqa: E712 - n = len(ncores) + n = len(nthreads) else: n = broadcast yield self.replicate(keys=keys, workers=workers, n=n) @@ -3283,9 +3283,9 @@ def get_has_what(self, comm=None, workers=None): def get_ncores(self, comm=None, workers=None): if workers is not None: workers = map(self.coerce_address, workers) - return {w: self.workers[w].ncores for w in workers if w in self.workers} + return {w: self.workers[w].nthreads for w in workers if w in self.workers} else: - return {w: ws.ncores for w, ws in self.workers.items()} + return {w: ws.nthreads for w, ws in self.workers.items()} @gen.coroutine def get_call_stack(self, comm=None, keys=None): @@ -4363,19 +4363,19 @@ def check_idle_saturated(self, ws, occ=None): - Idle: do not have enough work to stay busy They are considered saturated if they both have enough tasks to occupy - all of their cores, and if the expected runtime of those tasks is large - enough. + all of their threads, and if the expected runtime of those tasks is + large enough. This is useful for load balancing and adaptivity. 
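A worked numeric sketch of the idle check in the hunk just below (plain arithmetic mirroring the diff; the worker and cluster numbers are made up, and this is not scheduler code):

    total_occupancy, total_nthreads = 10.0, 20      # cluster-wide totals (made up)
    avg = total_occupancy / total_nthreads          # 0.5 seconds of queued work per thread

    nthreads, processing, occupancy = 4, 6, 0.8     # one worker's state (made up)
    idle = processing < nthreads or occupancy / nthreads < avg / 2
    print(idle)  # True: 0.8 / 4 = 0.2 < avg / 2 = 0.25, so this worker is treated as idle
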
""" - if self.total_ncores == 0 or ws.status == "closed": + if self.total_nthreads == 0 or ws.status == "closed": return if occ is None: occ = ws.occupancy - nc = ws.ncores + nc = ws.nthreads p = len(ws.processing) - avg = self.total_occupancy / self.total_ncores + avg = self.total_occupancy / self.total_nthreads if p < nc or occ / nc < avg / 2: self.idle.add(ws) @@ -4538,7 +4538,7 @@ def worker_objective(self, ts, ws): comm_bytes = sum( [dts.get_nbytes() for dts in ts.dependencies if ws not in dts.who_has] ) - stack_time = ws.occupancy / ws.ncores + stack_time = ws.occupancy / ws.nthreads start_time = comm_bytes / self.bandwidth + stack_time if ts.actor: diff --git a/distributed/stealing.py b/distributed/stealing.py index dc8c989e39d..afcdf2a1cfa 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -333,7 +333,7 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): saturated = [ ws for ws in saturated - if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.ncores + if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads ] elif len(s.saturated) < 20: saturated = sorted(saturated, key=combined_occupancy, reverse=True) @@ -379,7 +379,7 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): continue if combined_occupancy(sat) < 0.2: continue - if len(sat.processing) <= sat.ncores: + if len(sat.processing) <= sat.nthreads: continue i += 1 diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index ec2636ccd50..fd6bf0335e1 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -341,13 +341,13 @@ def add(n, counter): done = c.submit(lambda x: None, futures) while not done.done(): - assert len(s.processing) <= a.ncores + b.ncores + assert len(s.processing) <= a.nthreads + b.nthreads yield gen.sleep(0.01) yield done -@gen_cluster(client=True, ncores=[("127.0.0.1", 5)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2) def test_thread_safety(c, s, a, b): class Unsafe(object): def __init__(self): @@ -394,7 +394,7 @@ def __init__(self, x): assert s.tasks[x.key].who_has != s.tasks[y.key].who_has # second load balanced -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 5) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 5) def test_load_balance_map(c, s, *workers): class Foo(object): def __init__(self, x, y=None): @@ -409,7 +409,7 @@ def __init__(self, x, y=None): assert all(len(w.actors) == 2 for w in workers) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4, Worker=Nanny) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4, Worker=Nanny) def bench_param_server(c, s, *workers): import dask.array as da import numpy as np @@ -506,7 +506,7 @@ def check(dask_worker): @gen_cluster( client=True, - ncores=[("127.0.0.1", 1)], + nthreads=[("127.0.0.1", 1)], config={"distributed.worker.profile.interval": "1ms"}, ) def test_actors_in_profile(c, s, a): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index dc45b3025e5..d18216ef0ef 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -558,7 +558,7 @@ def test_gather_strict(c, s, a, b): assert xx == 2 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_gather_skip(c, s, a): x = c.submit(div, 1, 0, priority=10) y = c.submit(slowinc, 1, delay=0.5) @@ -953,7 +953,7 @@ def test_remove_worker(c, s, a, b): assert result == list(map(inc, range(20))) 
-@gen_cluster(ncores=[("127.0.0.1", 1)], client=True) +@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) def test_errors_dont_block(c, s, w): L = [c.submit(inc, 1), c.submit(throws, 1), c.submit(inc, 2), c.submit(throws, 2)] @@ -1359,13 +1359,13 @@ def test_scatter_direct_broadcast(c, s, a, b): assert not s.counters["op"].components[0]["scatter"] -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_scatter_direct_balanced(c, s, *workers): futures = yield c.scatter([1, 2, 3], direct=True) assert sorted([len(w.data) for w in workers]) == [0, 1, 1, 1] -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_scatter_direct_broadcast_target(c, s, *workers): futures = yield c.scatter([123, 456], direct=True, workers=workers[0].address) assert futures[0].key in workers[0].data @@ -1384,13 +1384,13 @@ def test_scatter_direct_broadcast_target(c, s, *workers): ) -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_scatter_direct_empty(c, s): with pytest.raises((ValueError, gen.TimeoutError)): yield c.scatter(123, direct=True, timeout=0.1) -@gen_cluster(client=True, timeout=None, ncores=[("127.0.0.1", 1)] * 5) +@gen_cluster(client=True, timeout=None, nthreads=[("127.0.0.1", 1)] * 5) def test_scatter_direct_spread_evenly(c, s, *workers): futures = [] for i in range(10): @@ -1724,7 +1724,7 @@ def test_start_is_idempotent(c): @gen_cluster(client=True) def test_client_with_scheduler(c, s, a, b): - assert s.ncores == {a.address: a.ncores, b.address: b.ncores} + assert s.nthreads == {a.address: a.nthreads, b.address: b.nthreads} x = c.submit(inc, 1) y = c.submit(inc, 2) @@ -2171,7 +2171,7 @@ def test__broadcast(c, s, a, b): assert a.data == b.data == {x.key: 1, y.key: 2} -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test__broadcast_integer(c, s, *workers): x, y = yield c.scatter([1, 2], broadcast=2) assert len(s.tasks[x.key].who_has) == 2 @@ -2486,7 +2486,7 @@ def test_futures_of_cancelled_raises(c, s, a, b): @pytest.mark.skip -@gen_cluster(ncores=[("127.0.0.1", 1)], client=True) +@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) def test_dont_delete_recomputed_results(c, s, w): x = c.submit(inc, 1) # compute first time yield wait([x]) @@ -2504,7 +2504,7 @@ def test_dont_delete_recomputed_results(c, s, w): yield gen.sleep(0.01) -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_fatally_serialized_input(c, s): o = FatallySerializedObject() @@ -2613,14 +2613,14 @@ def test_diagnostic_ui(loop): a_addr = a["address"] b_addr = b["address"] with Client(s["address"], loop=loop) as c: - d = c.ncores() + d = c.nthreads() assert d == {a_addr: 1, b_addr: 1} - d = c.ncores([a_addr]) + d = c.nthreads([a_addr]) assert d == {a_addr: 1} - d = c.ncores(a_addr) + d = c.nthreads(a_addr) assert d == {a_addr: 1} - d = c.ncores(a["address"]) + d = c.nthreads(a["address"]) assert d == {a_addr: 1} x = c.submit(inc, 1) @@ -2813,7 +2813,7 @@ def test_rebalance(c, s, a, b): assert aws not in s.tasks[x.key].who_has or aws not in s.tasks[y.key].who_has -@gen_cluster(ncores=[("127.0.0.1", 1)] * 4, client=True) +@gen_cluster(nthreads=[("127.0.0.1", 1)] * 4, client=True) def test_rebalance_workers(e, s, a, b, c, d): w, x, y, z = yield e.scatter([1, 2, 3, 4], workers=[a.address]) assert len(a.data) == 4 @@ -2903,11 +2903,11 @@ def 
test_unrunnable_task_runs(c, s, a, b): yield w.close() -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_add_worker_after_tasks(c, s): futures = c.map(inc, range(10)) - n = yield Nanny(s.address, ncores=2, loop=s.loop, port=0) + n = yield Nanny(s.address, nthreads=2, loop=s.loop, port=0) result = yield c.gather(futures) @@ -2939,7 +2939,7 @@ def test_submit_on_cancelled_future(c, s, a, b): y = c.submit(inc, x) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_replicate(c, s, *workers): [a, b] = yield c.scatter([1, 2]) yield s.replicate(keys=[a.key, b.key], n=5) @@ -2964,7 +2964,7 @@ def test_replicate_tuple_keys(c, s, a, b): s.validate_state() -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_replicate_workers(c, s, *workers): [a, b] = yield c.scatter([1, 2], workers=[workers[0].address]) @@ -3015,7 +3015,7 @@ def __getstate__(self): return self.n -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_replicate_tree_branching(c, s, *workers): obj = CountSerialization() [future] = yield c.scatter([obj]) @@ -3025,7 +3025,7 @@ def test_replicate_tree_branching(c, s, *workers): assert max_count > 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_client_replicate(c, s, *workers): x = c.submit(inc, 1) y = c.submit(inc, 2) @@ -3051,7 +3051,7 @@ def test_client_replicate(c, s, *workers): ) @gen_cluster( client=True, - ncores=[("127.0.0.1", 1), ("127.0.0.2", 1), ("127.0.0.2", 1)], + nthreads=[("127.0.0.1", 1), ("127.0.0.2", 1), ("127.0.0.2", 1)], timeout=None, ) def test_client_replicate_host(client, s, a, b, c): @@ -3087,7 +3087,7 @@ def test_client_replicate_sync(c): @pytest.mark.skipif( sys.platform.startswith("win"), reason="Windows timer too coarse-grained" ) -@gen_cluster(client=True, ncores=[("127.0.0.1", 4)] * 1) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 1) def test_task_load_adapts_quickly(c, s, a): future = c.submit(slowinc, 1, delay=0.2) # slow yield wait(future) @@ -3099,7 +3099,7 @@ def test_task_load_adapts_quickly(c, s, a): assert 0 < s.task_duration["slowinc"] < 0.1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_even_load_after_fast_functions(c, s, a, b): x = c.submit(inc, 1, workers=a.address) # very fast y = c.submit(inc, 2, workers=b.address) # very fast @@ -3113,7 +3113,7 @@ def test_even_load_after_fast_functions(c, s, a, b): # assert abs(len(a.data) - len(b.data)) <= 3 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_even_load_on_startup(c, s, a, b): x, y = c.map(inc, [1, 2]) yield wait([x, y]) @@ -3121,7 +3121,7 @@ def test_even_load_on_startup(c, s, a, b): @pytest.mark.skip -@gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 2) def test_contiguous_load(c, s, a, b): w, x, y, z = c.map(inc, [1, 2, 3, 4]) yield wait([w, x, y, z]) @@ -3131,7 +3131,7 @@ def test_contiguous_load(c, s, a, b): assert {y.key, z.key} in groups -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_balanced_with_submit(c, s, *workers): L = 
[c.submit(slowinc, i) for i in range(4)] yield wait(L) @@ -3139,7 +3139,7 @@ def test_balanced_with_submit(c, s, *workers): assert len(w.data) == 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_balanced_with_submit_and_resident_data(c, s, *workers): [x] = yield c.scatter([10], broadcast=True) L = [c.submit(slowinc, x, pure=False) for i in range(4)] @@ -3148,7 +3148,7 @@ def test_balanced_with_submit_and_resident_data(c, s, *workers): assert len(w.data) == 2 -@gen_cluster(client=True, ncores=[("127.0.0.1", 20)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 20)] * 2) def test_scheduler_saturates_cores(c, s, a, b): for delay in [0, 0.01, 0.1]: futures = c.map(slowinc, range(100), delay=delay) @@ -3163,7 +3163,7 @@ def test_scheduler_saturates_cores(c, s, a, b): yield gen.sleep(0.01) -@gen_cluster(client=True, ncores=[("127.0.0.1", 20)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 20)] * 2) def test_scheduler_saturates_cores_random(c, s, a, b): for delay in [0, 0.01, 0.1]: futures = c.map(randominc, range(100), scale=0.1) @@ -3177,7 +3177,7 @@ def test_scheduler_saturates_cores_random(c, s, a, b): yield gen.sleep(0.01) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_cancel_clears_processing(c, s, *workers): da = pytest.importorskip("dask.array") x = c.submit(slowinc, 1, delay=0.2) @@ -3255,10 +3255,10 @@ def test_get_foo(c, s, a, b): yield wait(futures) x = yield c.scheduler.ncores() - assert x == s.ncores + assert x == s.nthreads x = yield c.scheduler.ncores(workers=[a.address]) - assert x == {a.address: s.ncores[a.address]} + assert x == {a.address: s.nthreads[a.address]} x = yield c.scheduler.has_what() assert valmap(sorted, x) == valmap(sorted, s.has_what) @@ -3287,7 +3287,7 @@ def assert_dict_key_equal(expected, actual): assert list(ev) == list(av) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_get_foo_lost_keys(c, s, u, v, w): x = c.submit(inc, 1, workers=[u.address]) y = yield c.scatter(3, workers=[v.address]) @@ -3489,7 +3489,7 @@ def test_persist_optimize_graph(c, s, a, b): assert not any(tokey(k) in s.tasks for k in b2.__dask_keys__()) -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_scatter_raises_if_no_workers(c, s): with pytest.raises(gen.TimeoutError): yield c.scatter(1, timeout=0.5) @@ -3511,7 +3511,7 @@ def test_reconnect(loop): with popen(scheduler_cli) as s: c = Client("127.0.0.1:9393", loop=loop) start = time() - while len(c.ncores()) != 1: + while len(c.nthreads()) != 1: sleep(0.1) assert time() < start + 3 @@ -3524,7 +3524,7 @@ def test_reconnect(loop): sleep(0.01) with pytest.raises(Exception): - c.ncores() + c.nthreads() assert x.status == "cancelled" with pytest.raises(CancelledError): @@ -3536,7 +3536,7 @@ def test_reconnect(loop): sleep(0.1) assert time() < start + 5 start = time() - while len(c.ncores()) != 1: + while len(c.nthreads()) != 1: sleep(0.05) assert time() < start + 15 @@ -3559,7 +3559,7 @@ def test_reconnect(loop): c.close() -@gen_cluster(client=True, ncores=[], client_kwargs={"timeout": 0.5}) +@gen_cluster(client=True, nthreads=[], client_kwargs={"timeout": 0.5}) def test_reconnect_timeout(c, s): with captured_logger(logging.getLogger("distributed.client")) as logger: yield s.close() @@ -3626,7 +3626,7 @@ def start_worker(sleep, duration, repeat=1): break start 
= time() - while c.ncores(): + while c.nthreads(): sleep(0.2) assert time() < start + 10 @@ -3748,7 +3748,7 @@ def test_lose_scattered_data(c, s, a, b): assert x.key not in s.tasks -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_partially_lose_scattered_data(e, s, a, b, c): x = yield e.scatter(1, workers=a.address) yield e.replicate(x, n=2) @@ -3887,7 +3887,7 @@ def test_temp_client(s, a, b): @nodebug # test timing is fragile -@gen_cluster(ncores=[("127.0.0.1", 1)] * 3, client=True) +@gen_cluster(nthreads=[("127.0.0.1", 1)] * 3, client=True) def test_persist_workers(e, s, a, b, c): L1 = [delayed(inc)(i) for i in range(4)] total = delayed(sum)(L1) @@ -3912,7 +3912,7 @@ def test_persist_workers(e, s, a, b, c): assert s.loose_restrictions == {total2.key} | {v.key for v in L2} -@gen_cluster(ncores=[("127.0.0.1", 1)] * 3, client=True) +@gen_cluster(nthreads=[("127.0.0.1", 1)] * 3, client=True) def test_compute_workers(e, s, a, b, c): L1 = [delayed(inc)(i) for i in range(4)] total = delayed(sum)(L1) @@ -3991,7 +3991,7 @@ def test_retire_workers_2(c, s, a, b): assert a.address not in s.workers -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_retire_many_workers(c, s, *workers): futures = yield c.scatter(list(range(100))) @@ -4000,14 +4000,14 @@ def test_retire_many_workers(c, s, *workers): results = yield c.gather(futures) assert results == list(range(100)) - assert len(s.has_what) == len(s.ncores) == 3 + assert len(s.has_what) == len(s.nthreads) == 3 assert all(future.done() for future in futures) assert all(s.tasks[future.key].state == "memory" for future in futures) for w, keys in s.has_what.items(): assert 15 < len(keys) < 50 -@gen_cluster(client=True, ncores=[("127.0.0.1", 3)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 3)] * 2) def test_weight_occupancy_against_data_movement(c, s, a, b): s.extensions["stealing"]._pc.callback_time = 1000000 s.task_duration["f"] = 0.01 @@ -4027,8 +4027,8 @@ def f(x, y=0, z=0): assert sum(f.key in b.data for f in futures) >= 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1), ("127.0.0.1", 10)]) -def test_distribute_tasks_by_ncores(c, s, a, b): +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 10)]) +def test_distribute_tasks_by_nthreads(c, s, a, b): s.task_duration["f"] = 0.01 s.extensions["stealing"]._pc.callback_time = 1000000 @@ -4664,7 +4664,7 @@ def test_identity(c, s, a, b): assert s.id.lower().startswith("scheduler") -@gen_cluster(client=True, ncores=[("127.0.0.1", 4)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 2) def test_get_client(c, s, a, b): assert get_client() is c assert c.asynchronous @@ -4713,7 +4713,7 @@ def f(x): assert result == sum(range(10)) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 1, timeout=100) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 1, timeout=100) def test_secede_simple(c, s, a): def f(): client = get_client() @@ -4725,7 +4725,7 @@ def f(): @pytest.mark.slow -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2, timeout=60) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, timeout=60) def test_secede_balances(c, s, a, b): count = threading.active_count() @@ -5055,7 +5055,7 @@ def test_client_async_before_loop_starts(): client=True, Worker=Nanny if PY3 else Worker, timeout=60, - ncores=[("127.0.0.1", 3)] * 2, + nthreads=[("127.0.0.1", 3)] * 2, ) def test_nested_compute(c, s, a, b): 
def fib(x): @@ -5220,7 +5220,7 @@ def test_client_doesnt_close_given_loop(loop, s, a, b): assert c.submit(inc, 2).result() == 3 -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_quiet_scheduler_loss(c, s): c._periodic_callbacks["scheduler-info"].interval = 10 with captured_logger(logging.getLogger("distributed.client")) as logger: @@ -5368,7 +5368,7 @@ def test_client_repr_closed_sync(loop): c._repr_html_() -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_nested_prioritization(c, s, w): x = delayed(inc)(1, dask_key_name=("a", 2)) y = delayed(inc)(2, dask_key_name=("a", 10)) diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index dea4296769d..7cb509f6ac7 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -188,7 +188,7 @@ def test_sparse_arrays(c, s, a, b): yield future -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_delayed_none(c, s, w): x = dask.delayed(None) y = dask.delayed(123) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index dde92c6d24c..b39dd3f3ae7 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -39,7 +39,7 @@ def test_submit_after_failed_worker_sync(loop): @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) def test_submit_after_failed_worker_async(c, s, a, b): - n = Nanny(s.address, ncores=2, loop=s.loop) + n = Nanny(s.address, nthreads=2, loop=s.loop) n.start(0) while len(s.workers) < 3: yield gen.sleep(0.1) @@ -79,7 +79,7 @@ def test_gather_after_failed_worker(loop): @gen_cluster( client=True, Worker=Nanny, - ncores=[("127.0.0.1", 1)] * 4, + nthreads=[("127.0.0.1", 1)] * 4, config={"distributed.comm.timeouts.connect": "1s"}, ) def test_gather_then_submit_after_failed_workers(c, s, w, x, y, z): @@ -117,7 +117,7 @@ def test_failed_worker_without_warning(c, s, a, b): yield gen.sleep(0.5) start = time() - while len(s.ncores) < 2: + while len(s.nthreads) < 2: yield gen.sleep(0.01) assert time() - start < 10 @@ -126,7 +126,7 @@ def test_failed_worker_without_warning(c, s, a, b): L2 = c.map(inc, range(10, 20)) yield wait(L2) assert all(len(keys) > 0 for keys in s.has_what.values()) - ncores2 = dict(s.ncores) + nthreads2 = dict(s.nthreads) yield c._restart() @@ -134,12 +134,12 @@ def test_failed_worker_without_warning(c, s, a, b): yield wait(L) assert all(len(keys) > 0 for keys in s.has_what.values()) - assert not (set(ncores2) & set(s.ncores)) # no overlap + assert not (set(nthreads2) & set(s.nthreads)) # no overlap @gen_cluster(Worker=Nanny, client=True, timeout=60) def test_restart(c, s, a, b): - assert s.ncores == {a.worker_address: 1, b.worker_address: 2} + assert s.nthreads == {a.worker_address: 1, b.worker_address: 2} x = c.submit(inc, 1) y = c.submit(inc, x) @@ -185,7 +185,7 @@ def test_restart_sync_no_center(loop): assert x.cancelled() y = c.submit(inc, 2) assert y.result() == 3 - assert len(c.ncores()) == 2 + assert len(c.nthreads()) == 2 def test_restart_sync(loop): @@ -198,7 +198,7 @@ def test_restart_sync(loop): c.restart() assert not sync(loop, c.scheduler.who_has) assert x.cancelled() - assert len(c.ncores()) == 2 + assert len(c.nthreads()) == 2 with pytest.raises(CancelledError): x.result() @@ -214,7 +214,7 @@ def test_restart_fast(c, s, a, b): start = time() yield c._restart() assert time() - 
start < 10 - assert len(s.ncores) == 2 + assert len(s.nthreads) == 2 assert all(x.status == "cancelled" for x in L) @@ -242,7 +242,7 @@ def test_restart_fast_sync(loop): start = time() c.restart() assert time() - start < 10 - assert len(c.ncores()) == 2 + assert len(c.nthreads()) == 2 assert all(x.status == "cancelled" for x in L) @@ -293,7 +293,7 @@ def test_restart_scheduler(s, a, b): gc.collect() addrs = (a.worker_address, b.worker_address) yield s.restart() - assert len(s.ncores) == 2 + assert len(s.nthreads) == 2 addrs2 = (a.worker_address, b.worker_address) assert addrs != addrs2 @@ -315,11 +315,11 @@ def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) def test_broken_worker_during_computation(c, s, a, b): s.allowed_failures = 100 - n = Nanny(s.address, ncores=2, loop=s.loop) + n = Nanny(s.address, nthreads=2, loop=s.loop) n.start(0) start = time() - while len(s.ncores) < 3: + while len(s.nthreads) < 3: yield gen.sleep(0.01) assert time() < start + 5 @@ -368,17 +368,17 @@ def test_restart_during_computation(c, s, a, b): yield c._restart() assert not s.rprocessing - assert len(s.ncores) == 2 + assert len(s.nthreads) == 2 assert not s.tasks @gen_cluster(client=True, timeout=60) def test_worker_who_has_clears_after_failed_connection(c, s, a, b): - n = Nanny(s.address, ncores=2, loop=s.loop) + n = Nanny(s.address, nthreads=2, loop=s.loop) n.start(0) start = time() - while len(s.ncores) < 3: + while len(s.nthreads) < 3: yield gen.sleep(0.01) assert time() < start + 5 @@ -406,7 +406,7 @@ def test_worker_who_has_clears_after_failed_connection(c, s, a, b): @pytest.mark.slow -@gen_cluster(client=True, timeout=60, Worker=Nanny, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, timeout=60, Worker=Nanny, nthreads=[("127.0.0.1", 1)]) def test_restart_timeout_on_long_running_task(c, s, a): with captured_logger("distributed.scheduler") as sio: future = c.submit(sleep, 3600) diff --git a/distributed/tests/test_ipython.py b/distributed/tests/test_ipython.py index 8bb64bb4e0b..a6f88ec5241 100644 --- a/distributed/tests/test_ipython.py +++ b/distributed/tests/test_ipython.py @@ -88,7 +88,7 @@ def test_start_ipython_workers_magic(loop, zmq_ctx): with cluster(2) as (s, [a, b]): with Client(s["address"], loop=loop) as e, mock_ipython() as ip: - workers = list(e.ncores())[:2] + workers = list(e.nthreads())[:2] names = ["magic%i" % i for i in range(len(workers))] info_dict = e.start_ipython_workers(workers, magic_names=names) @@ -116,7 +116,7 @@ def test_start_ipython_workers_magic_asterix(loop, zmq_ctx): with cluster(2) as (s, [a, b]): with Client(s["address"], loop=loop) as e, mock_ipython() as ip: - workers = list(e.ncores())[:2] + workers = list(e.nthreads())[:2] info_dict = e.start_ipython_workers(workers, magic_names="magic_*") expected = [ @@ -144,7 +144,7 @@ def test_start_ipython_remote(loop, zmq_ctx): with cluster(1) as (s, [a]): with Client(s["address"], loop=loop) as e, mock_ipython() as ip: - worker = first(e.ncores()) + worker = first(e.nthreads()) ip.user_ns["info"] = e.start_ipython_workers(worker)[worker] remote_magic("info 1") # line magic remote_magic("info", "worker") # cell magic @@ -165,7 +165,7 @@ def test_start_ipython_qtconsole(loop): with mock.patch("distributed._ipython_utils.Popen", Popen), Client( s["address"], loop=loop ) as e: - worker = first(e.ncores()) + worker = first(e.nthreads()) e.start_ipython_workers(worker, qtconsole=True) e.start_ipython_workers(worker, qtconsole=True, 
qtconsole_args=["--debug"]) assert Popen.call_count == 2 diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index 952d43ceb9b..226feec4faf 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -11,7 +11,7 @@ from distributed.utils_test import client, cluster_fixture, loop # noqa F401 -@gen_cluster(client=True, ncores=[("127.0.0.1", 8)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 8)] * 2) def test_lock(c, s, a, b): c.set_metadata("locked", False) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 1357a3679e2..6b6d5bf939d 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -22,28 +22,28 @@ from distributed.utils_test import gen_cluster, gen_test, inc, captured_logger -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_nanny(s): - n = yield Nanny(s.address, ncores=2, loop=s.loop) + n = yield Nanny(s.address, nthreads=2, loop=s.loop) with rpc(n.address) as nn: assert n.is_alive() - assert s.ncores[n.worker_address] == 2 + assert s.nthreads[n.worker_address] == 2 assert s.workers[n.worker_address].nanny == n.address yield nn.kill() assert not n.is_alive() - assert n.worker_address not in s.ncores + assert n.worker_address not in s.nthreads assert n.worker_address not in s.workers yield nn.kill() assert not n.is_alive() - assert n.worker_address not in s.ncores + assert n.worker_address not in s.nthreads assert n.worker_address not in s.workers yield nn.instantiate() assert n.is_alive() - assert s.ncores[n.worker_address] == 2 + assert s.nthreads[n.worker_address] == 2 assert s.workers[n.worker_address].nanny == n.address yield nn.terminate() @@ -52,9 +52,9 @@ def test_nanny(s): yield n.close() -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_many_kills(s): - n = yield Nanny(s.address, ncores=2, loop=s.loop) + n = yield Nanny(s.address, nthreads=2, loop=s.loop) assert n.is_alive() yield [n.kill() for i in range(5)] yield [n.kill() for i in range(5)] @@ -65,13 +65,13 @@ def test_many_kills(s): def test_str(s, a, b): assert a.worker_address in str(a) assert a.worker_address in repr(a) - assert str(a.ncores) in str(a) - assert str(a.ncores) in repr(a) + assert str(a.nthreads) in str(a) + assert str(a.nthreads) in repr(a) -@gen_cluster(ncores=[], timeout=20, client=True) +@gen_cluster(nthreads=[], timeout=20, client=True) def test_nanny_process_failure(c, s): - n = yield Nanny(s.address, ncores=2, loop=s.loop) + n = yield Nanny(s.address, nthreads=2, loop=s.loop) first_dir = n.worker_dir assert os.path.exists(first_dir) @@ -97,7 +97,7 @@ def test_nanny_process_failure(c, s): # assert n.worker_address != original_address # most likely start = time() - while n.worker_address not in s.ncores or n.worker_dir is None: + while n.worker_address not in s.nthreads or n.worker_dir is None: yield gen.sleep(0.01) assert time() - start < 5 @@ -111,10 +111,10 @@ def test_nanny_process_failure(c, s): s.stop() -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_run(s): pytest.importorskip("psutil") - n = yield Nanny(s.address, ncores=2, loop=s.loop) + n = yield Nanny(s.address, nthreads=2, loop=s.loop) with rpc(n.address) as nn: response = yield nn.run(function=dumps(lambda: 1)) @@ -126,7 +126,7 @@ def test_run(s): @pytest.mark.slow @gen_cluster( - Worker=Nanny, ncores=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False} + Worker=Nanny, nthreads=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False} ) def test_close_on_disconnect(s, w): yield 
s.close() @@ -157,7 +157,7 @@ def test_nanny_alt_worker_class(c, s, w1, w2): @pytest.mark.slow -@gen_cluster(client=False, ncores=[]) +@gen_cluster(client=False, nthreads=[]) def test_nanny_death_timeout(s): yield s.close() w = yield Nanny(s.address, death_timeout=1) @@ -184,7 +184,7 @@ def check_func(func): @pytest.mark.skipif( sys.platform.startswith("win"), reason="num_fds not supported on windows" ) -@gen_cluster(client=False, ncores=[]) +@gen_cluster(client=False, nthreads=[]) def test_num_fds(s): psutil = pytest.importorskip("psutil") proc = psutil.Process() @@ -212,7 +212,7 @@ def test_num_fds(s): @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_worker_uses_same_host_as_nanny(c, s): for host in ["tcp://0.0.0.0", "tcp://127.0.0.2"]: n = Nanny(s.address) @@ -237,7 +237,7 @@ def test_scheduler_file(): s.stop() -@gen_cluster(client=True, Worker=Nanny, ncores=[("127.0.0.1", 2)]) +@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)]) def test_nanny_timeout(c, s, a): x = yield c.scatter(123) with captured_logger( @@ -255,7 +255,7 @@ def test_nanny_timeout(c, s, a): @gen_cluster( - ncores=[("127.0.0.1", 1)], + nthreads=[("127.0.0.1", 1)], client=True, Worker=Nanny, worker_kwargs={"memory_limit": 1e8}, @@ -283,7 +283,7 @@ def leak(): assert "memory" in out.lower() -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_avoid_memory_monitor_if_zero_limit(c, s): nanny = yield Nanny(s.address, loop=s.loop, memory_limit=0) typ = yield c.run(lambda dask_worker: type(dask_worker.data)) @@ -301,7 +301,7 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): yield nanny.close() -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): nanny = yield Nanny(loop=s.loop) @@ -329,7 +329,7 @@ def test_wait_for_scheduler(): assert "restart" not in log.lower(), log -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_environment_variable(c, s): a = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "123"}) b = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "456"}) @@ -339,7 +339,7 @@ def test_environment_variable(c, s): yield [a.close(), b.close()] -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_data_types(c, s): w = yield Nanny(s.address, data=dict) r = yield c.run(lambda dask_worker: type(dask_worker.data)) @@ -353,7 +353,7 @@ def _noop(x): @gen_cluster( - ncores=[("127.0.0.1", 1)], + nthreads=[("127.0.0.1", 1)], client=True, Worker=Nanny, config={"distributed.worker.daemon": False}, @@ -368,7 +368,7 @@ def multiprocessing_worker(): @gen_cluster( - ncores=[("127.0.0.1", 1)], + nthreads=[("127.0.0.1", 1)], client=True, Worker=Nanny, config={"distributed.worker.daemon": False}, diff --git a/distributed/tests/test_priorities.py b/distributed/tests/test_priorities.py index 421bf7e3028..6258c4e16a7 100644 --- a/distributed/tests/test_priorities.py +++ b/distributed/tests/test_priorities.py @@ -83,7 +83,7 @@ def test_expand_persist(c, s, a, b): assert s.tasks[low.key].state == "processing" -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_repeated_persists_same_priority(c, s, w): xs = [delayed(slowinc)(i, delay=0.05, dask_key_name="x-%d" % i) for i in 
range(10)] ys = [ @@ -107,7 +107,7 @@ def test_repeated_persists_same_priority(c, s, w): assert any(s.tasks[z.key].state == "memory" for z in zs) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_last_in_first_out(c, s, w): xs = [c.submit(slowinc, i, delay=0.05) for i in range(5)] ys = [c.submit(slowinc, x, delay=0.05) for x in xs] diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index c44637cf9fd..9d2b30dab6f 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -49,7 +49,7 @@ def pingpong(a, b, start=False, n=1000, msg=1): # print('duration', stop - start) # I get around 3ms/roundtrip on my laptop -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_client(c, s): with pytest.raises(Exception): get_worker() diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index e40d3cd492c..2e7702171ad 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -2,7 +2,6 @@ from datetime import timedelta from time import sleep -import sys import pytest from tornado import gen @@ -113,9 +112,8 @@ def f(x): assert q.get() == 11 -@pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") @pytest.mark.slow -@gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): def f(i): with worker_client() as c: diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 480532d912e..1985d44e2a3 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -14,7 +14,7 @@ from distributed.utils_test import client, cluster_fixture, loop, s, a, b # noqa: F401 -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_resources(c, s): assert not s.worker_resources assert not s.resources @@ -37,7 +37,7 @@ def test_resources(c, s): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 5}}), ("127.0.0.1", 1, {"resources": {"A": 1, "B": 1}}), ], @@ -65,7 +65,7 @@ def test_resource_submit(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -80,7 +80,7 @@ def test_submit_many_non_overlapping(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -96,7 +96,7 @@ def test_move(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -114,7 +114,7 @@ def test_dont_work_steal(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -128,7 +128,7 @@ def test_map(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -147,7 +147,7 @@ def test_persist(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 11}}), ], @@ -170,7 +170,7 @@ def test_compute(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": 
{"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -184,7 +184,7 @@ def test_get(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -222,7 +222,7 @@ def test_resources_str(c, s, a, b): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 4, {"resources": {"A": 2}}), ("127.0.0.1", 4, {"resources": {"A": 1}}), ], @@ -240,7 +240,7 @@ def test_submit_many_non_overlapping(c, s, a, b): assert b.total_resources == b.available_resources -@gen_cluster(client=True, ncores=[("127.0.0.1", 4, {"resources": {"A": 2, "B": 1}})]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 4, {"resources": {"A": 2, "B": 1}})]) def test_minimum_resource(c, s, a): futures = c.map(slowinc, range(30), resources={"A": 1, "B": 1}, delay=0.02) @@ -252,7 +252,7 @@ def test_minimum_resource(c, s, a): assert a.total_resources == a.available_resources -@gen_cluster(client=True, ncores=[("127.0.0.1", 2, {"resources": {"A": 1}})]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})]) def test_prefer_constrained(c, s, a): futures = c.map(slowinc, range(1000), delay=0.1) constrained = c.map(inc, range(10), resources={"A": 1}) @@ -270,7 +270,7 @@ def test_prefer_constrained(c, s, a): @pytest.mark.skip(reason="") @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 2, {"resources": {"A": 1}}), ("127.0.0.1", 2, {"resources": {"A": 1}}), ], @@ -284,7 +284,7 @@ def test_balance_resources(c, s, a, b): assert any(f.key in b.data for f in constrained) -@gen_cluster(client=True, ncores=[("127.0.0.1", 2)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)]) def test_set_resources(c, s, a): yield a.set_resources(A=2) assert a.total_resources["A"] == 2 @@ -303,7 +303,7 @@ def test_set_resources(c, s, a): @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -325,7 +325,7 @@ def test_persist_collections(c, s, a, b): @pytest.mark.skip(reason="Should protect resource keys from optimization") @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], @@ -346,7 +346,7 @@ def test_dont_optimize_out(c, s, a, b): @pytest.mark.xfail(reason="atop fusion seemed to break this") @gen_cluster( client=True, - ncores=[ + nthreads=[ ("127.0.0.1", 1, {"resources": {"A": 1}}), ("127.0.0.1", 1, {"resources": {"B": 1}}), ], diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 66a8088ace5..e8d2a96ee60 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -50,11 +50,11 @@ def test_administration(s, a, b): assert isinstance(s.address, str) assert s.address in str(s) - assert str(sum(s.ncores.values())) in repr(s) - assert str(len(s.ncores)) in repr(s) + assert str(sum(s.nthreads.values())) in repr(s) + assert str(len(s.nthreads)) in repr(s) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_respect_data_in_memory(c, s, a): x = delayed(inc)(1) y = delayed(inc)(x) @@ -106,14 +106,14 @@ def test_decide_worker_with_many_independent_leaves(c, s, a, b): assert nhits > 80 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_decide_worker_with_restrictions(client, s, a, b, c): x = client.submit(inc, 1, workers=[a.address, 
b.address]) yield wait(x) assert x.key in a.data or x.key in b.data -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_move_data_over_break_restrictions(client, s, a, b, c): [x] = yield client.scatter([1], workers=b.address) y = client.submit(inc, x, workers=[a.address, b.address]) @@ -121,7 +121,7 @@ def test_move_data_over_break_restrictions(client, s, a, b, c): assert y.key in a.data or y.key in b.data -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_balance_with_restrictions(client, s, a, b, c): [x], [y] = yield [ client.scatter([[1, 2, 3]], workers=a.address), @@ -133,7 +133,7 @@ def test_balance_with_restrictions(client, s, a, b, c): assert s.tasks[z.key].who_has == {s.workers[c.address]} -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_no_valid_workers(client, s, a, b, c): x = client.submit(inc, 1, workers="127.0.0.5:9999") while not s.tasks: @@ -145,7 +145,7 @@ def test_no_valid_workers(client, s, a, b, c): yield gen.with_timeout(timedelta(milliseconds=50), x) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_no_valid_workers_loose_restrictions(client, s, a, b, c): x = client.submit(inc, 1, workers="127.0.0.5:9999", allow_other_workers=True) @@ -153,7 +153,7 @@ def test_no_valid_workers_loose_restrictions(client, s, a, b, c): assert result == 2 -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_no_workers(client, s): x = client.submit(inc, 1) while not s.tasks: @@ -165,7 +165,7 @@ def test_no_workers(client, s): yield gen.with_timeout(timedelta(milliseconds=50), x) -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_retire_workers_empty(s): yield s.retire_workers(workers=[]) @@ -207,7 +207,7 @@ def test_remove_worker_from_scheduler(s, a, b): assert a.address in s.stream_comms s.remove_worker(address=a.address) - assert a.address not in s.ncores + assert a.address not in s.nthreads assert len(s.workers[b.address].processing) == len(dsk) # b owns everything s.validate_state() @@ -215,14 +215,14 @@ def test_remove_worker_from_scheduler(s, a, b): @gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "10 ms"}) def test_clear_events_worker_removal(s, a, b): assert a.address in s.events - assert a.address in s.ncores + assert a.address in s.nthreads assert b.address in s.events - assert b.address in s.ncores + assert b.address in s.nthreads s.remove_worker(address=a.address) # Shortly after removal, the events should still be there assert a.address in s.events - assert a.address not in s.ncores + assert a.address not in s.nthreads s.validate_state() start = time() @@ -253,7 +253,7 @@ def test_clear_events_client_removal(c, s, a, b): @gen_cluster() def test_add_worker(s, a, b): - w = Worker(s.address, ncores=3) + w = Worker(s.address, nthreads=3) w.data["x-5"] = 6 w.data["y"] = 1 yield w @@ -267,7 +267,7 @@ def test_add_worker(s, a, b): ) s.add_worker( - address=w.address, keys=list(w.data), ncores=w.ncores, services=s.services + address=w.address, keys=list(w.data), nthreads=w.nthreads, services=s.services ) s.validate_state() @@ -389,7 +389,7 @@ def test_delete_data(c, s, a, b): assert time() < start + 5 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def 
test_delete(c, s, a): x = c.submit(inc, 1) yield x @@ -481,12 +481,12 @@ def test_ready_remove_worker(s, a, b): dependencies={"x-%d" % i: [] for i in range(20)}, ) - assert all(len(w.processing) > w.ncores for w in s.workers.values()) + assert all(len(w.processing) > w.nthreads for w in s.workers.values()) s.remove_worker(address=a.address) assert set(s.workers) == {b.address} - assert all(len(w.processing) > w.ncores for w in s.workers.values()) + assert all(len(w.processing) > w.nthreads for w in s.workers.values()) @gen_cluster(client=True, Worker=Nanny) @@ -591,7 +591,7 @@ def test_coerce_address(): @pytest.mark.skipif( sys.platform.startswith("win"), reason="file descriptors not really a thing" ) -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_file_descriptors_dont_leak(s): psutil = pytest.importorskip("psutil") proc = psutil.Process() @@ -624,12 +624,12 @@ def test_update_graph_culls(s, a, b): assert "z" not in s.dependencies -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_add_worker_is_idempotent(s): - s.add_worker(address=alice, ncores=1, resolve_address=False) - ncores = dict(s.ncores) + s.add_worker(address=alice, nthreads=1, resolve_address=False) + nthreads = dict(s.nthreads) s.add_worker(address=alice, resolve_address=False) - assert s.ncores == s.ncores + assert s.nthreads == s.nthreads def test_io_loop(loop): @@ -654,7 +654,7 @@ def test_story(c, s, a, b): assert len(s.story(x.key, y.key)) > len(story) -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_scatter_no_workers(c, s): with pytest.raises(gen.TimeoutError): yield s.scatter(data={"x": 1}, client="alice", timeout=0.1) @@ -664,16 +664,16 @@ def test_scatter_no_workers(c, s): yield c.scatter(123, timeout=0.1) assert time() < start + 1.5 - w = Worker(s.address, ncores=3) + w = Worker(s.address, nthreads=3) yield [c.scatter(data={"y": 2}, timeout=5), w._start()] assert w.data["y"] == 2 yield w.close() -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_scheduler_sees_memory_limits(s): - w = yield Worker(s.address, ncores=3, memory_limit=12345) + w = yield Worker(s.address, nthreads=3, memory_limit=12345) assert s.workers[w.address].memory_limit == 12345 yield w.close() @@ -688,8 +688,8 @@ def test_retire_workers(c, s, a, b): workers = yield s.retire_workers() assert list(workers) == [a.address] - assert workers[a.address]["ncores"] == a.ncores - assert list(s.ncores) == [b.address] + assert workers[a.address]["nthreads"] == a.nthreads + assert list(s.nthreads) == [b.address] assert s.workers_to_close() == [] @@ -717,7 +717,7 @@ def test_retire_workers_n(c, s, a, b): yield gen.sleep(0.01) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_workers_to_close(cl, s, *workers): s.task_duration["a"] = 4 s.task_duration["b"] = 4 @@ -732,7 +732,7 @@ def test_workers_to_close(cl, s, *workers): assert len(wtc) == 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) def test_workers_to_close_grouped(c, s, *workers): groups = { workers[0].address: "a", @@ -782,7 +782,7 @@ def test_retire_workers_no_suspicious_tasks(c, s, a, b): sys.platform.startswith("win"), reason="file descriptors not really a thing" ) @pytest.mark.skipif(sys.version_info < (3, 6), reason="intermittent failure") -@gen_cluster(client=True, ncores=[], timeout=240) +@gen_cluster(client=True, nthreads=[], timeout=240) def test_file_descriptors(c, s): yield 
gen.sleep(0.1) psutil = pytest.importorskip("psutil") @@ -793,7 +793,7 @@ def test_file_descriptors(c, s): N = 20 nannies = yield [Nanny(s.address, loop=s.loop) for i in range(N)] - while len(s.ncores) < N: + while len(s.nthreads) < N: yield gen.sleep(0.1) num_fds_2 = proc.num_fds() @@ -876,7 +876,7 @@ def test_occupancy_cleardown(c, s, a, b): @nodebug -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 30) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30) def test_balance_many_workers(c, s, *workers): futures = c.map(slowinc, range(20), delay=0.2) yield wait(futures) @@ -884,7 +884,7 @@ def test_balance_many_workers(c, s, *workers): @nodebug -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 30) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30) def test_balance_many_workers_2(c, s, *workers): s.extensions["stealing"]._pc.callback_time = 100000000 futures = c.map(slowinc, range(90), delay=0.2) @@ -932,7 +932,7 @@ def test_worker_arrives_with_processing_data(c, s, a, b): while not any(w.processing for w in s.workers.values()): yield gen.sleep(0.01) - w = Worker(s.address, ncores=1) + w = Worker(s.address, nthreads=1) w.put_key_in_memory(y.key, 3) yield w @@ -951,7 +951,7 @@ def test_worker_arrives_with_processing_data(c, s, a, b): yield w.close() -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_worker_breaks_and_returns(c, s, a): future = c.submit(slowinc, 1, delay=0.1) for i in range(10): @@ -972,7 +972,7 @@ def test_worker_breaks_and_returns(c, s, a): assert states == {"memory": 1, "released": 10} -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_no_workers_to_memory(c, s): x = delayed(slowinc)(1, delay=0.4) y = delayed(slowinc)(x, delay=0.4) @@ -983,7 +983,7 @@ def test_no_workers_to_memory(c, s): while not s.tasks: yield gen.sleep(0.01) - w = Worker(s.address, ncores=1) + w = Worker(s.address, nthreads=1) w.put_key_in_memory(y.key, 3) yield w @@ -1013,7 +1013,7 @@ def test_no_worker_to_memory_restrictions(c, s, a, b): while not s.tasks: yield gen.sleep(0.01) - w = Worker(s.address, ncores=1, name="alice") + w = Worker(s.address, nthreads=1, name="alice") w.put_key_in_memory(y.key, 3) yield w @@ -1122,7 +1122,7 @@ def test_retire_nannies_close(c, s, a, b): assert not s.workers -@gen_cluster(client=True, ncores=[("127.0.0.1", 2)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)]) def test_fifo_submission(c, s, w): futures = [] for i in range(20): @@ -1147,17 +1147,17 @@ def test_scheduler_file(): @pytest.mark.xfail(reason="") -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_non_existent_worker(c, s): with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - s.add_worker(address="127.0.0.1:5738", ncores=2, nbytes={}, host_info={}) + s.add_worker(address="127.0.0.1:5738", nthreads=2, nbytes={}, host_info={}) futures = c.map(inc, range(10)) yield gen.sleep(0.300) assert not s.workers assert all(ts.state == "no-worker" for ts in s.tasks.values()) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_correct_bad_time_estimate(c, s, *workers): future = c.submit(slowinc, 1, delay=0) yield wait(future) @@ -1255,7 +1255,7 @@ def test_log_tasks_during_restart(c, s, a, b): assert "exit" in str(s.events) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def 
test_reschedule(c, s, a, b): yield c.submit(slowinc, -1, delay=0.1) # learn cost x = c.map(slowinc, range(4), delay=0.1) @@ -1324,7 +1324,7 @@ def test_retries(c, s, a, b): @pytest.mark.xfail(reason="second worker also errant for some reason") -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3, timeout=5) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3, timeout=5) def test_mising_data_errant_worker(c, s, w1, w2, w3): with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): np = pytest.importorskip("numpy") @@ -1457,7 +1457,7 @@ def test_closing_scheduler_closes_workers(s, a, b): @gen_cluster( - client=True, ncores=[("127.0.0.1", 1)], worker_kwargs={"resources": {"A": 1}} + client=True, nthreads=[("127.0.0.1", 1)], worker_kwargs={"resources": {"A": 1}} ) def test_resources_reset_after_cancelled_task(c, s, w): future = c.submit(sleep, 0.2, resources={"A": 1}) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 7348d164c72..d233fc28388 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -37,7 +37,7 @@ @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) -@gen_cluster(client=True, ncores=[("127.0.0.1", 2), ("127.0.0.2", 2)], timeout=20) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2), ("127.0.0.2", 2)], timeout=20) def test_work_stealing(c, s, a, b): [x] = yield c._scatter([1], workers=a.address) futures = c.map(slowadd, range(50), [x] * 50) @@ -46,7 +46,7 @@ def test_work_stealing(c, s, a, b): assert len(b.data) > 10 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_dont_steal_expensive_data_fast_computation(c, s, a, b): np = pytest.importorskip("numpy") x = c.submit(np.arange, 1000000, workers=a.address) @@ -64,7 +64,7 @@ def test_dont_steal_expensive_data_fast_computation(c, s, a, b): assert len(a.data) == 12 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_steal_cheap_data_slow_computation(c, s, a, b): x = c.submit(slowinc, 100, delay=0.1) # learn that slowinc is slow yield wait(x) @@ -77,7 +77,7 @@ def test_steal_cheap_data_slow_computation(c, s, a, b): @pytest.mark.avoid_travis -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_steal_expensive_data_slow_computation(c, s, a, b): np = pytest.importorskip("numpy") @@ -94,7 +94,7 @@ def test_steal_expensive_data_slow_computation(c, s, a, b): assert b.data # not empty -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_worksteal_many_thieves(c, s, *workers): x = c.submit(slowinc, -1, delay=0.1) yield x @@ -110,7 +110,7 @@ def test_worksteal_many_thieves(c, s, *workers): assert sum(map(len, s.has_what.values())) < 150 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_dont_steal_unknown_functions(c, s, a, b): futures = c.map(inc, [1, 2], workers=a.address, allow_other_workers=True) yield wait(futures) @@ -118,7 +118,7 @@ def test_dont_steal_unknown_functions(c, s, a, b): assert len(b.data) == 0 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_eventually_steal_unknown_functions(c, s, a, b): futures = c.map( slowinc, range(10), delay=0.1, 
workers=a.address, allow_other_workers=True @@ -129,7 +129,7 @@ def test_eventually_steal_unknown_functions(c, s, a, b): @pytest.mark.skip(reason="") -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_steal_related_tasks(e, s, a, b, c): futures = e.map( slowinc, range(20), delay=0.05, workers=a.address, allow_other_workers=True @@ -145,7 +145,7 @@ def test_steal_related_tasks(e, s, a, b, c): assert nearby > 10 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10, timeout=1000) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, timeout=1000) def test_dont_steal_fast_tasks(c, s, *workers): np = pytest.importorskip("numpy") x = c.submit(np.random.random, 10000000, workers=workers[0].address) @@ -163,7 +163,7 @@ def do_nothing(x, y=None): assert len(s.has_what[workers[0].address]) == 1001 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)], timeout=20) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)], timeout=20) def test_new_worker_steals(c, s, a): yield wait(c.submit(slowinc, 1, delay=0.01)) @@ -172,7 +172,7 @@ def test_new_worker_steals(c, s, a): while len(a.task_state) < 10: yield gen.sleep(0.01) - b = yield Worker(s.address, loop=s.loop, ncores=1, memory_limit=TOTAL_MEMORY) + b = yield Worker(s.address, loop=s.loop, nthreads=1, memory_limit=TOTAL_MEMORY) result = yield total assert result == sum(map(inc, range(100))) @@ -204,7 +204,7 @@ def test_work_steal_no_kwargs(c, s, a, b): assert result == sum(map(inc, range(100))) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1), ("127.0.0.1", 2)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)]) def test_dont_steal_worker_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) yield future @@ -228,7 +228,7 @@ def test_dont_steal_worker_restrictions(c, s, a, b): @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1), ("127.0.0.2", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 1)]) def test_dont_steal_host_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) yield future @@ -247,7 +247,7 @@ def test_dont_steal_host_restrictions(c, s, a, b): @gen_cluster( - client=True, ncores=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)] + client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)] ) def test_dont_steal_resource_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) @@ -267,7 +267,9 @@ def test_dont_steal_resource_restrictions(c, s, a, b): @pytest.mark.skip(reason="no stealing of resources") -@gen_cluster(client=True, ncores=[("127.0.0.1", 1, {"resources": {"A": 2}})], timeout=3) +@gen_cluster( + client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}})], timeout=3 +) def test_steal_resource_restrictions(c, s, a): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) yield future @@ -277,7 +279,7 @@ def test_steal_resource_restrictions(c, s, a): yield gen.sleep(0.01) assert len(a.task_state) == 101 - b = yield Worker(s.address, loop=s.loop, ncores=1, resources={"A": 4}) + b = yield Worker(s.address, loop=s.loop, nthreads=1, resources={"A": 4}) start = time() while not b.task_state or len(a.task_state) == 101: @@ -290,7 +292,7 @@ def test_steal_resource_restrictions(c, s, a): yield b.close() -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 5, 
timeout=20) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 5, timeout=20) def test_balance_without_dependencies(c, s, *workers): s.extensions["stealing"]._pc.callback_time = 20 @@ -306,7 +308,7 @@ def slow(x): assert max(durations) / min(durations) < 3 -@gen_cluster(client=True, ncores=[("127.0.0.1", 4)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 2) def test_dont_steal_executing_tasks(c, s, a, b): futures = c.map( slowinc, range(4), delay=0.1, workers=a.address, allow_other_workers=True @@ -317,7 +319,7 @@ def test_dont_steal_executing_tasks(c, s, a, b): assert len(b.data) == 0 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): s.extensions["stealing"]._pc.callback_time = 20 x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB @@ -334,7 +336,7 @@ def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): @gen_cluster( client=True, - ncores=[("127.0.0.1", 1)] * 10, + nthreads=[("127.0.0.1", 1)] * 10, worker_kwargs={"memory_limit": TOTAL_MEMORY}, ) def test_steal_when_more_tasks(c, s, a, *rest): @@ -351,7 +353,7 @@ def test_steal_when_more_tasks(c, s, a, *rest): assert time() < start + 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) def test_steal_more_attractive_tasks(c, s, a, *rest): def slow2(x): sleep(1) @@ -473,11 +475,11 @@ def assert_balanced(inp, expected, c, s, *workers): ) def test_balance(inp, expected): test = lambda *args, **kwargs: assert_balanced(inp, expected, *args, **kwargs) - test = gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * len(inp))(test) + test = gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * len(inp))(test) test() -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=20) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=20) def test_restart(c, s, a, b): futures = c.map( slowinc, range(100), delay=0.1, workers=a.address, allow_other_workers=True @@ -569,7 +571,7 @@ def test_dont_steal_executing_tasks(c, s, a, b): assert not b.executing -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_dont_steal_long_running_tasks(c, s, a, b): def long(delay): with worker_client() as c: @@ -603,7 +605,7 @@ def long(delay): ) <= 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 5)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2) def test_cleanup_repeated_tasks(c, s, a, b): class Foo(object): pass @@ -635,7 +637,7 @@ class Foo(object): assert not list(ws) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_lose_task(c, s, a, b): with captured_logger("distributed.stealing") as log: s.periodic_callbacks["stealing"].interval = 1 diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 8c37f5a82fb..81d7c4360f7 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -64,7 +64,7 @@ def test_stress_gc(loop, func, n): @pytest.mark.skipif( sys.platform.startswith("win"), reason="test can leave dangling RPC objects" ) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 8, timeout=None) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 8, timeout=None) def test_cancel_stress(c, s, *workers): da = pytest.importorskip("dask.array") 
x = da.random.random((50, 50), chunks=(2, 2)) @@ -93,7 +93,7 @@ def test_cancel_stress_sync(loop): c.cancel(f) -@gen_cluster(ncores=[], client=True, timeout=None) +@gen_cluster(nthreads=[], client=True, timeout=None) def test_stress_creation_and_deletion(c, s): # Assertions are handled by the validate mechanism in the scheduler s.allowed_failures = 100000 @@ -108,7 +108,7 @@ def test_stress_creation_and_deletion(c, s): def create_and_destroy_worker(delay): start = time() while time() < start + 5: - n = Nanny(s.address, ncores=2, loop=s.loop) + n = Nanny(s.address, nthreads=2, loop=s.loop) n.start(0) yield gen.sleep(delay) @@ -122,7 +122,7 @@ def create_and_destroy_worker(delay): ) -@gen_cluster(ncores=[("127.0.0.1", 1)] * 10, client=True, timeout=60) +@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=60) def test_stress_scatter_death(c, s, *workers): import random @@ -198,7 +198,7 @@ def vsum(*args): @pytest.mark.avoid_travis @pytest.mark.slow -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 80, timeout=1000) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 80, timeout=1000) def test_stress_communication(c, s, *workers): s.validate = False # very slow otherwise da = pytest.importorskip("dask.array") @@ -218,7 +218,7 @@ def test_stress_communication(c, s, *workers): @pytest.mark.skip -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 10, timeout=60) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, timeout=60) def test_stress_steal(c, s, *workers): s.validate = False for w in workers: @@ -244,7 +244,7 @@ def test_stress_steal(c, s, *workers): @pytest.mark.slow -@gen_cluster(ncores=[("127.0.0.1", 1)] * 10, client=True, timeout=120) +@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=120) def test_close_connections(c, s, *workers): da = pytest.importorskip("dask.array") x = da.random.random(size=(1000, 1000), chunks=(1000, 1)) @@ -269,7 +269,7 @@ def test_close_connections(c, s, *workers): reason="IOStream._handle_write blocks on large write_buffer" " https://github.com/tornadoweb/tornado/issues/2110" ) -@gen_cluster(client=True, timeout=20, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, timeout=20, nthreads=[("127.0.0.1", 1)]) def test_no_delay_during_large_transfer(c, s, w): pytest.importorskip("crick") np = pytest.importorskip("numpy") diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 74a9cf3cbd4..7d097e28112 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -82,7 +82,7 @@ def test_nanny(c, s, a, b): assert isinstance(n, Nanny) assert n.address.startswith("tls://") assert n.worker_address.startswith("tls://") - assert s.ncores == {n.worker_address: n.ncores for n in [a, b]} + assert s.nthreads == {n.worker_address: n.nthreads for n in [a, b]} x = c.submit(inc, 10) result = yield x @@ -101,7 +101,7 @@ def test_rebalance(c, s, a, b): assert len(b.data) == 1 -@gen_tls_cluster(client=True, ncores=[("tls://127.0.0.1", 2)] * 2) +@gen_tls_cluster(client=True, nthreads=[("tls://127.0.0.1", 2)] * 2) def test_work_stealing(c, s, a, b): [x] = yield c._scatter([1], workers=a.address) futures = c.map(slowadd, range(50), [x] * 50, delay=0.1) @@ -127,7 +127,7 @@ def func(x): assert yy == 20 + 1 + (20 + 1) * 2 -@gen_tls_cluster(client=True, ncores=[("tls://127.0.0.1", 1)] * 2) +@gen_tls_cluster(client=True, nthreads=[("tls://127.0.0.1", 1)] * 2) def test_worker_client_gather(c, s, a, b): a_address = a.address b_address = b.address 
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 6f704c23f5b..c0afb9e2c7f 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -50,7 +50,7 @@ def test_gen_cluster(c, s, a, b): assert isinstance(s, Scheduler) for w in [a, b]: assert isinstance(w, Worker) - assert s.ncores == {w.address: w.ncores for w in [a, b]} + assert s.nthreads == {w.address: w.nthreads for w in [a, b]} @pytest.mark.skip(reason="This hangs on travis") @@ -74,13 +74,13 @@ def test_gen_cluster_without_client(s, a, b): assert isinstance(s, Scheduler) for w in [a, b]: assert isinstance(w, Worker) - assert s.ncores == {w.address: w.ncores for w in [a, b]} + assert s.nthreads == {w.address: w.nthreads for w in [a, b]} @gen_cluster( client=True, scheduler="tls://127.0.0.1", - ncores=[("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)], + nthreads=[("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)], security=tls_only_security(), ) def test_gen_cluster_tls(e, s, a, b): @@ -90,7 +90,7 @@ def test_gen_cluster_tls(e, s, a, b): for w in [a, b]: assert isinstance(w, Worker) assert w.address.startswith("tls://") - assert s.ncores == {w.address: w.ncores for w in [a, b]} + assert s.nthreads == {w.address: w.nthreads for w in [a, b]} @gen_test() diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 4d8851668f9..e734cc3094f 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -148,7 +148,7 @@ def test_timeout_get(c, s, a, b): @pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") @pytest.mark.slow -@gen_cluster(client=True, ncores=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): NITERS = 50 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 12a6b5ff68f..fa0cf857a32 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import timedelta import logging +import multiprocessing from numbers import Number from operator import add import os @@ -49,12 +50,10 @@ ) -def test_worker_ncores(): - from distributed.worker import _ncores - +def test_worker_nthreads(): w = Worker("127.0.0.1", 8019) try: - assert w.executor._max_workers == _ncores + assert w.executor._max_workers == multiprocessing.cpu_count() finally: shutil.rmtree(w.local_dir) @@ -63,8 +62,8 @@ def test_worker_ncores(): def test_str(s, a, b): assert a.address in str(a) assert a.address in repr(a) - assert str(a.ncores) in str(a) - assert str(a.ncores) in repr(a) + assert str(a.nthreads) in str(a) + assert str(a.nthreads) in repr(a) assert str(len(a.executing)) in repr(a) @@ -73,7 +72,7 @@ def test_identity(): ident = w.identity(None) assert "Worker" in ident["type"] assert ident["scheduler"] == "tcp://127.0.0.1:8019" - assert isinstance(ident["ncores"], int) + assert isinstance(ident["nthreads"], int) assert isinstance(ident["memory_limit"], Number) @@ -198,7 +197,7 @@ def g(): @pytest.mark.skip(reason="don't yet support uploading pyc files") -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_upload_file_pyc(c, s, w): with tmpfile() as dirname: os.mkdir(dirname) @@ -329,7 +328,7 @@ def f(): loop.run_sync(f) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) 
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_worker_task_data(c, s, w): x = delayed(2) xx = c.persist(x) @@ -370,7 +369,7 @@ async def test_io_loop(): await w.close() -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_spill_to_disk(c, s): np = pytest.importorskip("numpy") w = yield Worker( @@ -441,7 +440,7 @@ def f(dask_worker=None): assert response == {a.address: a.id, b.address: b.id} -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_Executor(c, s): with ThreadPoolExecutor(2) as e: w = Worker(s.address, executor=e) @@ -462,7 +461,7 @@ def test_Executor(c, s): ) @gen_cluster( client=True, - ncores=[("127.0.0.1", 1)], + nthreads=[("127.0.0.1", 1)], timeout=30, worker_kwargs={"memory_limit": 10e6}, ) @@ -475,7 +474,7 @@ def test_spill_by_default(c, s, w): del x, y -@gen_cluster(ncores=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False}) +@gen_cluster(nthreads=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False}) def test_close_on_disconnect(s, w): yield s.close() @@ -486,10 +485,10 @@ def test_close_on_disconnect(s, w): def test_memory_limit_auto(): - a = Worker("127.0.0.1", 8099, ncores=1) - b = Worker("127.0.0.1", 8099, ncores=2) - c = Worker("127.0.0.1", 8099, ncores=100) - d = Worker("127.0.0.1", 8099, ncores=200) + a = Worker("127.0.0.1", 8099, nthreads=1) + b = Worker("127.0.0.1", 8099, nthreads=2) + c = Worker("127.0.0.1", 8099, nthreads=100) + d = Worker("127.0.0.1", 8099, nthreads=200) assert isinstance(a.memory_limit, Number) assert isinstance(b.memory_limit, Number) @@ -585,7 +584,7 @@ def test_system_monitor(s, a, b): @gen_cluster( - client=True, ncores=[("127.0.0.1", 2, {"resources": {"A": 1}}), ("127.0.0.1", 1)] + client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}}), ("127.0.0.1", 1)] ) def test_restrictions(c, s, a, b): # Resource restrictions @@ -615,7 +614,7 @@ def test_clean_nbytes(c, s, a, b): assert len(a.nbytes) + len(b.nbytes) == 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 20) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 20) def test_gather_many_small(c, s, a, *workers): a.total_out_connections = 2 futures = yield c._scatter(list(range(100))) @@ -636,7 +635,7 @@ def f(*args): assert a.comm_nbytes == 0 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_multiple_transfers(c, s, w1, w2, w3): x = c.submit(inc, 1, workers=w1.address) y = c.submit(inc, 2, workers=w2.address) @@ -649,7 +648,7 @@ def test_multiple_transfers(c, s, w1, w2, w3): assert len(transfers) == 2 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_share_communication(c, s, w1, w2, w3): x = c.submit(mul, b"1", int(w3.target_message_size + 1), workers=w1.address) y = c.submit(mul, b"2", int(w3.target_message_size + 1), workers=w2.address) @@ -733,7 +732,7 @@ def test_hold_onto_dependents(c, s, a, b): @pytest.mark.slow -@gen_cluster(client=False, ncores=[]) +@gen_cluster(client=False, nthreads=[]) def test_worker_death_timeout(s): with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): yield s.close() @@ -756,7 +755,7 @@ def test_stop_doing_unnecessary_work(c, s, a, b): assert time() - start < 0.5 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) def test_priorities(c, s, w): values = [] for i in range(10): @@ -842,7 +841,7 @@ def __sizeof__(self): 
@pytest.mark.skip(reason="Our logic here is faulty") @gen_cluster( - ncores=[("127.0.0.1", 2)], client=True, worker_kwargs={"memory_limit": 10e9} + nthreads=[("127.0.0.1", 2)], client=True, worker_kwargs={"memory_limit": 10e9} ) def test_fail_write_many_to_disk(c, s, a): a.validate = False @@ -947,7 +946,7 @@ def test_global_workers(s, a, b): @pytest.mark.skipif(WINDOWS, reason="file descriptors") -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_worker_fds(s): psutil = pytest.importorskip("psutil") yield gen.sleep(0.05) @@ -969,7 +968,7 @@ def test_worker_fds(s): assert time() < start + 0.5 -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_service_hosts_match_worker(s): pytest.importorskip("bokeh") from distributed.dashboard import BokehWorker @@ -995,7 +994,7 @@ def test_service_hosts_match_worker(s): yield w.close() -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_start_services(s): pytest.importorskip("bokeh") from distributed.dashboard import BokehWorker @@ -1051,7 +1050,7 @@ def test_statistical_profiling_2(c, s, a, b): @gen_cluster( - ncores=[("127.0.0.1", 1)], + nthreads=[("127.0.0.1", 1)], client=True, worker_kwargs={"memory_monitor_interval": 10}, ) @@ -1082,7 +1081,7 @@ def f(n): @pytest.mark.slow @gen_cluster( - ncores=[("127.0.0.1", 2)], + nthreads=[("127.0.0.1", 2)], client=True, worker_kwargs={ "memory_monitor_interval": 10, @@ -1151,7 +1150,7 @@ def some_name(): assert result.startswith("some_name") -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_reschedule(c, s, a, b): s.extensions["stealing"]._pc.stop() a_address = a.address @@ -1180,7 +1179,7 @@ def test_deque_handler(): assert any(msg.msg == "foo456" for msg in deque_handler.deque) -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_avoid_memory_monitor_if_zero_limit(c, s): worker = yield Worker( s.address, loop=s.loop, memory_limit=0, memory_monitor_interval=10 @@ -1198,7 +1197,7 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): @gen_cluster( - ncores=[("127.0.0.1", 1)], + nthreads=[("127.0.0.1", 1)], config={ "distributed.worker.memory.spill": False, "distributed.worker.memory.target": False, @@ -1223,12 +1222,12 @@ def func(dask_scheduler): assert time() < start + 10 -@gen_cluster(ncores=[("127.0.0.1", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) +@gen_cluster(nthreads=[("127.0.0.1", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) def test_parse_memory_limit(s, w): assert w.memory_limit == 2e9 -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): worker = yield Worker(loop=s.loop) @@ -1257,7 +1256,9 @@ def test_wait_for_outgoing(c, s, a, b): @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) -@gen_cluster(ncores=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], client=True) +@gen_cluster( + nthreads=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], client=True +) def test_prefer_gather_from_local_address(c, s, w1, w2, w3): x = yield c.scatter(123, workers=[w1.address, w3.address], broadcast=True) @@ -1270,7 +1271,7 @@ def test_prefer_gather_from_local_address(c, s, w1, w2, w3): @gen_cluster( client=True, - ncores=[("127.0.0.1", 1)] * 20, + nthreads=[("127.0.0.1", 1)] * 20, timeout=30, config={"distributed.worker.connections.incoming": 1}, ) @@ -1367,7 +1368,7 @@ def 
test_register_worker_callbacks_err(c, s, a, b): yield c.register_worker_callbacks(setup=lambda: 1 / 0) -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_data_types(s): w = yield Worker(s.address, data=dict) assert isinstance(w.data, dict) @@ -1389,7 +1390,7 @@ def __init__(self, x, y): yield w.close() -@gen_cluster(ncores=[]) +@gen_cluster(nthreads=[]) def test_local_dir(s): with tmpfile() as fn: with dask.config.set(temporary_directory=fn): @@ -1401,7 +1402,7 @@ def test_local_dir(s): @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) -@gen_cluster(ncores=[], client=True) +@gen_cluster(nthreads=[], client=True) def test_host_address(c, s): w = yield Worker(s.address, host="127.0.0.2") assert "127.0.0.2" in w.address diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index 9c4616e9d26..fe1d49def6d 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -42,7 +42,7 @@ def func(x): assert len([id for id in s.wants_what if id.lower().startswith("client")]) == 1 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_scatter_from_worker(c, s, a, b): def func(): with worker_client() as c: @@ -78,12 +78,12 @@ def func(): assert result is True start = time() - while not all(v == 1 for v in s.ncores.values()): + while not all(v == 1 for v in s.nthreads.values()): yield gen.sleep(0.1) assert time() < start + 5 -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_scatter_singleton(c, s, a, b): np = pytest.importorskip("numpy") @@ -96,7 +96,7 @@ def func(): yield c.submit(func) -@gen_cluster(client=True, ncores=[("127.0.0.1", 1)] * 2) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) def test_gather_multi_machine(c, s, a, b): a_address = a.address b_address = b.address @@ -162,7 +162,7 @@ def mysum(): assert time() < start + 3 -@gen_cluster(client=True, ncores=[("127.0.0.1", 3)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 3)]) def test_separate_thread_false(c, s, a): a.count = 0 diff --git a/distributed/tests/test_worker_plugins.py b/distributed/tests/test_worker_plugins.py index 25388459788..425a267923a 100644 --- a/distributed/tests/test_worker_plugins.py +++ b/distributed/tests/test_worker_plugins.py @@ -19,7 +19,7 @@ def teardown(self, worker): self.worker._my_plugin_status = "teardown" -@gen_cluster(client=True, ncores=[]) +@gen_cluster(client=True, nthreads=[]) def test_create_with_client(c, s): yield c.register_worker_plugin(MyPlugin(123)) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index d2bd19908af..e6f5235afe0 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -110,19 +110,19 @@ def __repr__(self): @gen.coroutine -def scatter_to_workers(ncores, data, rpc=rpc, report=True, serializers=None): +def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None): """ Scatter data directly to workers This distributes data in a round-robin fashion to a set of workers based on - how many cores they have. ncores should be a dictionary mapping worker + how many cores they have. nthreads should be a dictionary mapping worker identities to numbers of cores. 
See scatter for parameter docstring """ - assert isinstance(ncores, dict) + assert isinstance(nthreads, dict) assert isinstance(data, dict) - workers = list(concat([w] * nc for w, nc in ncores.items())) + workers = list(concat([w] * nc for w, nc in nthreads.items())) names, data = list(zip(*data.items())) worker_iter = drop(_round_robin_counter[0] % len(workers), cycle(workers)) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 10784c6f759..89e0f3283e1 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -646,7 +646,7 @@ def cluster( q = mp_context.Queue() fn = "_test_worker-%s" % uuid.uuid4() kwargs = merge( - {"ncores": 1, "local_dir": fn, "memory_limit": TOTAL_MEMORY}, + {"nthreads": 1, "local_dir": fn, "memory_limit": TOTAL_MEMORY}, worker_kwargs, ) proc = mp_context.Process( @@ -678,8 +678,8 @@ def cluster( with rpc(saddr, **rpc_kwargs) as s: while True: - ncores = loop.run_sync(s.ncores) - if len(ncores) == nworkers: + nthreads = loop.run_sync(s.ncores) + if len(nthreads) == nworkers: break if time() - start > 5: raise Exception("Timeout on cluster creation") @@ -783,7 +783,7 @@ def test_func(): @gen.coroutine def start_cluster( - ncores, + nthreads, scheduler_addr, loop, security=None, @@ -798,7 +798,7 @@ def start_cluster( workers = [ Worker( s.address, - ncores=ncore[1], + nthreads=ncore[1], name=i, security=security, loop=loop, @@ -806,7 +806,7 @@ def start_cluster( host=ncore[0], **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs) ) - for i, ncore in enumerate(ncores) + for i, ncore in enumerate(nthreads) ] # for w in workers: # w.rpc = workers[0].rpc @@ -814,7 +814,7 @@ def start_cluster( yield workers start = time() - while len(s.workers) < len(ncores) or any( + while len(s.workers) < len(nthreads) or any( comm.comm is None for comm in s.stream_comms.values() ): yield gen.sleep(0.01) @@ -840,7 +840,7 @@ def end_worker(w): def gen_cluster( - ncores=[("127.0.0.1", 1), ("127.0.0.1", 2)], + nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)], scheduler="127.0.0.1", timeout=10, security=None, @@ -885,7 +885,7 @@ def coro(): for i in range(5): try: s, ws = yield start_cluster( - ncores, + nthreads, scheduler, loop, security=security, @@ -1406,7 +1406,7 @@ def bump_rlimit(limit, desired): def gen_tls_cluster(**kwargs): - kwargs.setdefault("ncores", [("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)]) + kwargs.setdefault("nthreads", [("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)]) return gen_cluster( scheduler="tls://127.0.0.1", security=tls_only_security(), **kwargs ) diff --git a/distributed/worker.py b/distributed/worker.py index 37dcbc2eca1..86d078254f0 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -5,6 +5,7 @@ from datetime import timedelta import heapq import logging +import multiprocessing import os from pickle import PicklingError import random @@ -53,7 +54,6 @@ _maybe_complex, log_errors, ignoring, - mp_context, import_file, silence_logging, thread_state, @@ -69,8 +69,6 @@ from .utils_comm import pack_data, gather_from_workers from .utils_perf import ThrottledGC, enable_gc_diagnosis, disable_gc_diagnosis -_ncores = mp_context.cpu_count() - logger = logging.getLogger(__name__) LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") @@ -116,8 +114,8 @@ class Worker(ServerNode): These attributes don't change significantly during execution. 
- * **ncores:** ``int``: - Number of cores used by this worker process + * **nthreads:** ``int``: + Number of nthreads used by this worker process * **executor:** ``concurrent.futures.ThreadPoolExecutor``: Executor used to perform computation * **local_dir:** ``path``: @@ -233,7 +231,7 @@ class Worker(ServerNode): ip: str, optional data: MutableMapping, type, None The object to use for storage, builds a disk-backed LRU dict by default - ncores: int, optional + nthreads: int, optional loop: tornado.ioloop.IOLoop local_dir: str, optional Directory where we place local resources @@ -241,7 +239,7 @@ class Worker(ServerNode): memory_limit: int, float, string Number of bytes of memory that this worker should use. Set to zero for no limit. Set to 'auto' to calculate - as TOTAL_MEMORY * min(1, ncores / total_cores) + as TOTAL_MEMORY * min(1, nthreads / total_cores) Use strings or numbers like 5GB or 5e9 memory_target_fraction: float Fraction of memory to try to stay beneath @@ -281,6 +279,7 @@ def __init__( scheduler_port=None, scheduler_file=None, ncores=None, + nthreads=None, loop=None, local_dir=None, services=None, @@ -432,7 +431,11 @@ def __init__( security=security, ) - self.ncores = ncores or _ncores + if ncores is not None: + warnings.warn("the ncores= parameter has moved to nthreads=") + nthreads = ncores + + self.nthreads = nthreads or multiprocessing.cpu_count() self.total_resources = resources or {} self.available_resources = (resources or {}).copy() self.death_timeout = parse_timedelta(death_timeout) @@ -471,7 +474,7 @@ def __init__( self.connection_args = self.security.get_connection_args("worker") self.listen_args = self.security.get_listen_args("worker") - self.memory_limit = parse_memory_limit(memory_limit, self.ncores) + self.memory_limit = parse_memory_limit(memory_limit, self.nthreads) self.paused = False @@ -526,7 +529,7 @@ def __init__( self._closed = Event() self.reconnect = reconnect self.executor = executor or ThreadPoolExecutor( - self.ncores, thread_name_prefix="Dask-Worker-Threads'" + self.nthreads, thread_name_prefix="Dask-Worker-Threads'" ) self.actor_executor = ThreadPoolExecutor( 1, thread_name_prefix="Dask-Actor-Threads" @@ -658,7 +661,7 @@ def __repr__(self): self.status, len(self.data), len(self.executing), - self.ncores, + self.nthreads, len(self.ready), len(self.in_flight_tasks), len(self.waiting_for_data), @@ -687,7 +690,8 @@ def identity(self, comm=None): "type": type(self).__name__, "id": self.id, "scheduler": self.scheduler.address, - "ncores": self.ncores, + "nthreads": self.nthreads, + "ncores": self.nthreads, # backwards compatibility "memory_limit": self.memory_limit, } @@ -722,7 +726,7 @@ def _register_with_scheduler(self): reply=False, address=self.contact_address, keys=list(self.data), - ncores=self.ncores, + nthreads=self.nthreads, name=self.name, nbytes=self.nbytes, types=types, @@ -941,7 +945,7 @@ def _start(self, addr_or_port=0): logger.info(" %16s at: %26s" % (k, listen_host + ":" + str(v))) logger.info("Waiting to connect to: %26s", self.scheduler.address) logger.info("-" * 49) - logger.info(" Threads: %26d", self.ncores) + logger.info(" Threads: %26d", self.nthreads) if self.memory_limit: logger.info(" Memory: %26s", format_bytes(self.memory_limit)) logger.info(" Local Directory: %26s", self.local_dir) @@ -2283,7 +2287,7 @@ def ensure_computing(self): if self.paused: return try: - while self.constrained and len(self.executing) < self.ncores: + while self.constrained and len(self.executing) < self.nthreads: key = self.constrained[0] if 
self.task_state.get(key) != "constrained": self.constrained.popleft() @@ -2293,7 +2297,7 @@ def ensure_computing(self): self.transition(key, "executing") else: break - while self.ready and len(self.executing) < self.ncores: + while self.ready and len(self.executing) < self.nthreads: _, key = heapq.heappop(self.ready) if self.task_state.get(key) in READY: self.transition(key, "executing") @@ -2955,12 +2959,12 @@ class Reschedule(Exception): pass -def parse_memory_limit(memory_limit, ncores, total_cores=_ncores): +def parse_memory_limit(memory_limit, nthreads, total_cores=multiprocessing.cpu_count()): if memory_limit is None: return None if memory_limit == "auto": - memory_limit = int(TOTAL_MEMORY * min(1, ncores / total_cores)) + memory_limit = int(TOTAL_MEMORY * min(1, nthreads / total_cores)) with ignoring(ValueError, TypeError): memory_limit = float(memory_limit) if isinstance(memory_limit, float) and memory_limit <= 1: diff --git a/docs/source/api.rst b/docs/source/api.rst index 574a70d34b6..adefe5b86c4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -24,7 +24,7 @@ API Client.has_what Client.list_datasets Client.map - Client.ncores + Client.nthreads Client.persist Client.publish_dataset Client.profile diff --git a/docs/source/local-cluster.rst b/docs/source/local-cluster.rst index c415bbfbeba..d596ccaed24 100644 --- a/docs/source/local-cluster.rst +++ b/docs/source/local-cluster.rst @@ -7,7 +7,7 @@ For convenience you can start a local cluster from your Python session. >>> from distributed import Client, LocalCluster >>> cluster = LocalCluster() - LocalCluster("127.0.0.1:8786", workers=8, ncores=8) + LocalCluster("127.0.0.1:8786", workers=8, nthreads=8) >>> client = Client(cluster) diff --git a/docs/source/protocol.rst b/docs/source/protocol.rst index 645ba4aa905..334e2c0e4bd 100644 --- a/docs/source/protocol.rst +++ b/docs/source/protocol.rst @@ -25,7 +25,7 @@ In practice we represent these messages with dictionaries/mappings:: {'op': 'register-worker', 'address': '192.168.1.42', 'name': 'alice', - 'ncores': 4} + 'nthreads': 4} {'x': b'...', 'y': b'...'} diff --git a/docs/source/scheduling-state.rst b/docs/source/scheduling-state.rst index 90db367767f..515bb26cdb0 100644 --- a/docs/source/scheduling-state.rst +++ b/docs/source/scheduling-state.rst @@ -112,7 +112,7 @@ containers to help with scheduling tasks: .. attribute:: Scheduler.saturated: {WorkerState} A set of workers whose computing power (as - measured by :attr:`WorkerState.ncores`) is fully exploited by processing + measured by :attr:`WorkerState.nthreads`) is fully exploited by processing tasks, and whose current :attr:`~WorkerState.occupancy` is a lot greater than the average. diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 4b835d7ba67..530a27b9505 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -100,7 +100,7 @@ are the available options:: --name TEXT Alias --memory-limit TEXT Maximum bytes of memory that this worker should use. Use 0 for unlimited, or 'auto' for - TOTAL_MEMORY * min(1, ncores / total_cores) + TOTAL_MEMORY * min(1, nthreads / total_nthreads) --no-nanny --help Show this message and exit. 
@@ -151,7 +151,7 @@ command line ``--memory-limit`` keyword or the ``memory_limit=`` Python keyword argument, which sets the memory limit per worker processes launched by dask-worker :: - $ dask-worker tcp://scheduler:port --memory-limit=auto # TOTAL_MEMORY * min(1, ncores / total_cores) + $ dask-worker tcp://scheduler:port --memory-limit=auto # TOTAL_MEMORY * min(1, nthreads / total_nthreads) $ dask-worker tcp://scheduler:port --memory-limit=4e9 # four gigabytes per worker process. Workers use a few different heuristics to keep memory use beneath this limit: From 1fb26c7ab0e84b7b678b52cb8ffd0e2948fc63ed Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 19 Jun 2019 09:30:38 +0200 Subject: [PATCH 0326/1550] Clean up lingering ncores->nthreads change in widget code (#2785) --- distributed/deploy/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 69cc5be9fac..74d61a995e1 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -123,7 +123,7 @@ def scale(self, n): def _widget_status(self): workers = len(self.scheduler.workers) - cores = sum(ws.ncores for ws in self.scheduler.workers.values()) + cores = sum(ws.nthreads for ws in self.scheduler.workers.values()) memory = sum(ws.memory_limit for ws in self.scheduler.workers.values()) memory = format_bytes(memory) text = """ From eba954b0ce589b2186e7ea78d697a79ac1faad62 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 19 Jun 2019 10:37:56 -0500 Subject: [PATCH 0327/1550] Raise when workers initialization times out (#2784) This changes Worker / Nanny startup to raise when they timeout. This bubbles up to the `dask-worker` CLI. Closes #2781 --- distributed/cli/dask_worker.py | 5 ++++- distributed/cli/tests/test_dask_worker.py | 14 ++++++++++++++ distributed/nanny.py | 8 +++++++- distributed/tests/test_nanny.py | 5 +++-- distributed/tests/test_worker.py | 6 ++++-- distributed/worker.py | 8 +++++++- 6 files changed, 39 insertions(+), 7 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index a53ddf99f6e..13b699ab0dd 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -373,7 +373,10 @@ def run(): try: loop.run_sync(run) - except (KeyboardInterrupt, TimeoutError): + except TimeoutError: + # We already log the exception in nanny / worker. Don't do it again. 
+ raise TimeoutError("Timed out starting worker.") from None + except KeyboardInterrupt: pass finally: logger.info("End worker") diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index dc7c761fdf1..9191d7aba4d 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -1,6 +1,7 @@ from __future__ import print_function, division, absolute_import import pytest +from click.testing import CliRunner pytest.importorskip("requests") @@ -9,6 +10,7 @@ import os from time import sleep +import distributed.cli.dask_worker from distributed import Client from distributed.metrics import time from distributed.utils import sync, tmpfile @@ -292,3 +294,15 @@ def test_dashboard_non_standard_ports(loop): with pytest.raises(Exception): requests.get("http://localhost:4833/status/") + + +@pytest.mark.slow +@pytest.mark.parametrize("no_nanny", [True, False]) +def test_worker_timeout(no_nanny): + runner = CliRunner() + args = ["192.168.1.100:7777", "--death-timeout=1"] + if no_nanny: + args.append("--no-nanny") + result = runner.invoke(distributed.cli.dask_worker.main, args) + assert result.exit_code != 0 + assert str(result.exception).startswith("Timed out") diff --git a/distributed/nanny.py b/distributed/nanny.py index d907c7171a1..f518b330d7c 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -304,7 +304,13 @@ def instantiate(self, comm=None): ) except gen.TimeoutError: yield self.close(timeout=self.death_timeout) - raise gen.Return("timed out") + logger.exception( + "Timed out connecting Nanny '%s' to scheduler '%s'", + self, + self.scheduler_addr, + ) + raise + else: result = yield self.process.start() raise gen.Return(result) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 6b6d5bf939d..40c8d49012d 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -160,9 +160,10 @@ def test_nanny_alt_worker_class(c, s, w1, w2): @gen_cluster(client=False, nthreads=[]) def test_nanny_death_timeout(s): yield s.close() - w = yield Nanny(s.address, death_timeout=1) + w = Nanny(s.address, death_timeout=1) + with pytest.raises(gen.TimeoutError): + yield w - yield gen.sleep(3) assert w.status == "closed" diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index fa0cf857a32..a0e8244e8bd 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -736,9 +736,11 @@ def test_hold_onto_dependents(c, s, a, b): def test_worker_death_timeout(s): with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): yield s.close() - w = yield Worker(s.address, death_timeout=1) + w = Worker(s.address, death_timeout=1) + + with pytest.raises(gen.TimeoutError): + yield w - yield gen.sleep(2) assert w.status == "closed" diff --git a/distributed/worker.py b/distributed/worker.py index 86d078254f0..63dfaed3114 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -708,8 +708,14 @@ def _register_with_scheduler(self): logger.info("-" * 49) while True: if self.death_timeout and time() > start + self.death_timeout: + logger.exception( + "Timed out when connecting to scheduler '%s'", + self.scheduler.address, + ) yield self.close(timeout=1) - return + raise gen.TimeoutError( + "Timed out connecting to scheduler '%s'" % self.scheduler.address + ) if self.status in ("closed", "closing"): raise gen.Return try: From c5f479ff28e91aed47c9e307ca3e0e65ea9c9150 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 
20 Jun 2019 15:08:46 -0500 Subject: [PATCH 0328/1550] Add version option to scheduler and worker CLI (#2782) * Add version option to scheduler and worker CLI * Add version to other cli commands * Add tests --- distributed/cli/dask_mpi.py | 1 + distributed/cli/dask_remote.py | 1 + distributed/cli/dask_scheduler.py | 1 + distributed/cli/dask_ssh.py | 1 + distributed/cli/dask_submit.py | 1 + distributed/cli/dask_worker.py | 1 + distributed/cli/tests/test_dask_mpi.py | 8 ++++++++ distributed/cli/tests/test_dask_remote.py | 6 ++++++ distributed/cli/tests/test_dask_scheduler.py | 9 +++++++++ distributed/cli/tests/test_dask_ssh.py | 8 ++++++++ distributed/cli/tests/test_dask_submit.py | 6 ++++++ distributed/cli/tests/test_dask_worker.py | 6 ++++++ 12 files changed, 49 insertions(+) create mode 100644 distributed/cli/tests/test_dask_ssh.py diff --git a/distributed/cli/dask_mpi.py b/distributed/cli/dask_mpi.py index 2a965824662..7b9aeaca213 100644 --- a/distributed/cli/dask_mpi.py +++ b/distributed/cli/dask_mpi.py @@ -63,6 +63,7 @@ help="Worker's Bokeh port for visual diagnostics", ) @click.option("--bokeh-prefix", type=str, default=None, help="Prefix for the bokeh app") +@click.version_option() def main( scheduler_file, interface, diff --git a/distributed/cli/dask_remote.py b/distributed/cli/dask_remote.py index 933d8d318b0..29cc5c3c784 100644 --- a/distributed/cli/dask_remote.py +++ b/distributed/cli/dask_remote.py @@ -8,6 +8,7 @@ @click.command() @click.option("--host", type=str, default=None, help="IP or hostname of this server") @click.option("--port", type=int, default=8788, help="Remote Client Port") +@click.version_option() def main(host, port): _remote(host, port) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index b27e68eaa9a..c38f405f04e 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -113,6 +113,7 @@ @click.argument( "preload_argv", nargs=-1, type=click.UNPROCESSED, callback=validate_preload_argv ) +@click.version_option() def main( host, port, diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index 2d98992d969..1d264dc80e5 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -101,6 +101,7 @@ help="Worker to run. 
Defaults to distributed.cli.dask_worker", ) @click.pass_context +@click.version_option() def main( ctx, scheduler, diff --git a/distributed/cli/dask_submit.py b/distributed/cli/dask_submit.py index 1ef759407c6..071dd5bbe32 100644 --- a/distributed/cli/dask_submit.py +++ b/distributed/cli/dask_submit.py @@ -9,6 +9,7 @@ @click.command() @click.argument("remote_client_address", type=str, required=True) @click.argument("filepath", type=str, required=True) +@click.version_option() def main(remote_client_address, filepath): @gen.coroutine def f(): diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 13b699ab0dd..1463c29afd1 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -180,6 +180,7 @@ @click.argument( "preload_argv", nargs=-1, type=click.UNPROCESSED, callback=validate_preload_argv ) +@click.version_option() def main( scheduler, host, diff --git a/distributed/cli/tests/test_dask_mpi.py b/distributed/cli/tests/test_dask_mpi.py index 8bc8dddca2e..89f1140bfab 100644 --- a/distributed/cli/tests/test_dask_mpi.py +++ b/distributed/cli/tests/test_dask_mpi.py @@ -8,12 +8,14 @@ pytest.importorskip("mpi4py") import requests +from click.testing import CliRunner from distributed import Client from distributed.utils import tmpfile from distributed.metrics import time from distributed.utils_test import popen from distributed.utils_test import loop # noqa: F401 +from distributed.cli.dask_remote import main @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) @@ -96,3 +98,9 @@ def test_bokeh(loop): with pytest.raises(Exception): requests.get("http://localhost:59583/status/") + + +def test_version_option(): + runner = CliRunner() + result = runner.invoke(main, ["--version"]) + assert result.exit_code == 0 diff --git a/distributed/cli/tests/test_dask_remote.py b/distributed/cli/tests/test_dask_remote.py index 04d04d62ecf..14da80f949c 100644 --- a/distributed/cli/tests/test_dask_remote.py +++ b/distributed/cli/tests/test_dask_remote.py @@ -7,3 +7,9 @@ def test_dask_remote(): result = runner.invoke(main, ["--help"]) assert "--host TEXT IP or hostname of this server" in result.output assert result.exit_code == 0 + + +def test_version_option(): + runner = CliRunner() + result = runner.invoke(main, ["--version"]) + assert result.exit_code == 0 diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index e04fa24bad1..7de7e881270 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -13,7 +13,9 @@ from time import sleep from tornado import gen +from click.testing import CliRunner +import distributed from distributed import Scheduler, Client from distributed.utils import get_ip, get_ip_interface, tmpfile from distributed.utils_test import ( @@ -23,6 +25,7 @@ ) from distributed.utils_test import loop # noqa: F401 from distributed.metrics import time +import distributed.cli.dask_scheduler def test_defaults(loop): @@ -374,3 +377,9 @@ def check_passthrough(): finally: shutil.rmtree(tmpdir) + + +def test_version_option(): + runner = CliRunner() + result = runner.invoke(distributed.cli.dask_scheduler.main, ["--version"]) + assert result.exit_code == 0 diff --git a/distributed/cli/tests/test_dask_ssh.py b/distributed/cli/tests/test_dask_ssh.py new file mode 100644 index 00000000000..9be8cb06f62 --- /dev/null +++ b/distributed/cli/tests/test_dask_ssh.py @@ -0,0 +1,8 @@ +from click.testing import CliRunner +from distributed.cli.dask_ssh import main 
+ + +def test_version_option(): + runner = CliRunner() + result = runner.invoke(main, ["--version"]) + assert result.exit_code == 0 diff --git a/distributed/cli/tests/test_dask_submit.py b/distributed/cli/tests/test_dask_submit.py index 83c7c1067fa..8f5f961ea96 100644 --- a/distributed/cli/tests/test_dask_submit.py +++ b/distributed/cli/tests/test_dask_submit.py @@ -7,3 +7,9 @@ def test_submit_runs_as_a_cli(): result = runner.invoke(main, ["--help"]) assert result.exit_code == 0 assert "Usage: main [OPTIONS] REMOTE_CLIENT_ADDRESS FILEPATH" in result.output + + +def test_version_option(): + runner = CliRunner() + result = runner.invoke(main, ["--version"]) + assert result.exit_code == 0 diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 9191d7aba4d..edba84d2ef4 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -296,6 +296,12 @@ def test_dashboard_non_standard_ports(loop): requests.get("http://localhost:4833/status/") +def test_version_option(): + runner = CliRunner() + result = runner.invoke(distributed.cli.dask_worker.main, ["--version"]) + assert result.exit_code == 0 + + @pytest.mark.slow @pytest.mark.parametrize("no_nanny", [True, False]) def test_worker_timeout(no_nanny): From 4ba820a1218d468692c36d4551eeba6491440366 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 21 Jun 2019 10:13:08 +0200 Subject: [PATCH 0329/1550] Add warnings around ncores= keywords (#2791) --- distributed/scheduler.py | 6 ++++++ distributed/utils_test.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 434d118a422..d370705e9af 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -13,6 +13,7 @@ import pickle import random import six +import warnings import weakref import psutil @@ -307,6 +308,11 @@ def identity(self): "nanny": self.nanny, } + @property + def ncores(self): + warnings.warn("WorkerState.ncores has moved to WorkerState.nthreads") + return self.nthreads + class TaskState(object): """ diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 89e0f3283e1..293cf5c0737 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -22,6 +22,7 @@ import threading from time import sleep import uuid +import warnings import weakref try: @@ -841,6 +842,7 @@ def end_worker(w): def gen_cluster( nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)], + ncores=None, scheduler="127.0.0.1", timeout=10, security=None, @@ -865,6 +867,10 @@ def test_foo(scheduler, worker1, worker2): start end """ + if ncores is not None: + warnings.warn("ncores= has moved to nthreads=") + nthreads = ncores + worker_kwargs = merge( {"memory_limit": TOTAL_MEMORY, "death_timeout": 5}, worker_kwargs ) From c5e830d4f386ef30664403d1d415f94c751cc8ce Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 21 Jun 2019 12:43:39 -0500 Subject: [PATCH 0330/1550] Remove "experimental" from TLS docs [skip ci] (#2793) --- docs/source/tls.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/tls.rst b/docs/source/tls.rst index 8aaad99c45a..d367dabbf7b 100644 --- a/docs/source/tls.rst +++ b/docs/source/tls.rst @@ -4,9 +4,9 @@ TLS/SSL ======= -Currently dask distributed has experimental support for TLS/SSL communication, +Dask distributed has support for TLS/SSL communication, providing mutual authentication and encryption of communications between cluster -endpoints (Clients, Schedulers and Workers). 
+endpoints (Clients, Schedulers, and Workers). TLS is enabled by using a ``tls`` address such as ``tls://`` (the default being ``tcp``, which sends data unauthenticated and unencrypted). In From 9b4c8fc177a26a428b451677ed12fdf6bb577f09 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Sat, 22 Jun 2019 08:35:55 -0500 Subject: [PATCH 0331/1550] Update command line cli options docs (#2794) --- docs/requirements.txt | 1 + docs/source/conf.py | 230 ++++++++++++------------ docs/source/submitting-applications.rst | 27 ++- 3 files changed, 142 insertions(+), 116 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 329b7d7d23b..61dd185a5b9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,3 +5,4 @@ dask numpydoc sphinx dask_sphinx_theme +sphinx-click diff --git a/docs/source/conf.py b/docs/source/conf.py index 6c79073e3b8..c8ffc0ae50d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,54 +12,51 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys -import os -import shlex - # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.todo', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', - 'sphinx.ext.autosummary', - 'sphinx.ext.extlinks', - 'sphinx.ext.intersphinx', - 'numpydoc', + "sphinx.ext.autodoc", + "sphinx.ext.todo", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + "sphinx.ext.autosummary", + "sphinx.ext.extlinks", + "sphinx.ext.intersphinx", + "numpydoc", + "sphinx_click.ext", ] numpydoc_show_class_members = False # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Dask.distributed' -copyright = u'2016, Anaconda, Inc.' -author = u'Anaconda, Inc.' +project = u"Dask.distributed" +copyright = u"2016, Anaconda, Inc." +author = u"Anaconda, Inc." # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -67,6 +64,7 @@ # # The short X.Y version. import distributed + version = distributed.__version__ # The full version, including alpha/beta/rc tags. release = distributed.__version__ @@ -80,9 +78,9 @@ # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. 
-#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -90,27 +88,27 @@ # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'default' +pygments_style = "default" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -119,144 +117,147 @@ # -- Options for HTML output ---------------------------------------------- import dask_sphinx_theme -html_theme = 'dask_sphinx_theme' + +html_theme = "dask_sphinx_theme" html_theme_path = [dask_sphinx_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. 
-#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -#html_search_language = 'en' +# html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} +# html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'distributeddoc' +htmlhelp_basename = "distributeddoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', - -# Latex figure (float) alignment -#'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', + # Latex figure (float) alignment + #'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'distributed.tex', u'Dask.distributed Documentation', - u'Matthew Rocklin', 'manual'), + ( + master_doc, + "distributed.tex", + u"Dask.distributed Documentation", + u"Matthew Rocklin", + "manual", + ) ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. 
-#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- @@ -264,12 +265,11 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'Dask.distributed', u'Dask.distributed Documentation', - [author], 1) + (master_doc, "Dask.distributed", u"Dask.distributed Documentation", [author], 1) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -278,22 +278,28 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Dask.distributed', u'Dask.distributed Documentation', - author, 'Dask.distributed', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "Dask.distributed", + u"Dask.distributed Documentation", + author, + "Dask.distributed", + "One line description of project.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False # -- Options for Epub output ---------------------------------------------- @@ -305,85 +311,85 @@ epub_copyright = copyright # The basename for the epub file. It defaults to the project name. -#epub_basename = project +# epub_basename = project # The HTML theme for the epub output. Since the default themes are not optimized # for small screen space, using the same theme for HTML and epub output is # usually not wise. This defaults to 'epub', a theme designed to save visual # space. -#epub_theme = 'epub' +# epub_theme = 'epub' # The language of the text. It defaults to the language option # or 'en' if the language is not set. -#epub_language = '' +# epub_language = '' # The scheme of the identifier. Typical schemes are ISBN or URL. -#epub_scheme = '' +# epub_scheme = '' # The unique identifier of the text. This can be a ISBN number # or the project homepage. -#epub_identifier = '' +# epub_identifier = '' # A unique identification for the text. -#epub_uid = '' +# epub_uid = '' # A tuple containing the cover image and cover page html template filenames. -#epub_cover = () +# epub_cover = () # A sequence of (type, uri, title) tuples for the guide element of content.opf. -#epub_guide = () +# epub_guide = () # HTML files that should be inserted before the pages created by sphinx. # The format is a list of tuples containing the path and title. -#epub_pre_files = [] +# epub_pre_files = [] # HTML files shat should be inserted after the pages created by sphinx. # The format is a list of tuples containing the path and title. -#epub_post_files = [] +# epub_post_files = [] # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # The depth of the table of contents in toc.ncx. 
-#epub_tocdepth = 3 +# epub_tocdepth = 3 # Allow duplicate toc entries. -#epub_tocdup = True +# epub_tocdup = True # Choose between 'default' and 'includehidden'. -#epub_tocscope = 'default' +# epub_tocscope = 'default' # Fix unsupported image types using the Pillow. -#epub_fix_images = False +# epub_fix_images = False # Scale large images. -#epub_max_image_width = 0 +# epub_max_image_width = 0 # How to display URL addresses: 'footnote', 'no', or 'inline'. -#epub_show_urls = 'inline' +# epub_show_urls = 'inline' # If false, no index is generated. -#epub_use_index = True +# epub_use_index = True # Link to GitHub issues and pull requests using :pr:`1234` and :issue:`1234` # syntax extlinks = { - 'issue': ('https://github.com/dask/distributed/issues/%s', 'GH#'), - 'pr': ('https://github.com/dask/distributed/pull/%s', 'GH#') + "issue": ("https://github.com/dask/distributed/issues/%s", "GH#"), + "pr": ("https://github.com/dask/distributed/pull/%s", "GH#"), } # Configuration for intersphinx: refer to the Python standard library # and the Numpy documentation. intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('http://docs.scipy.org/doc/numpy', None), - } + "python": ("https://docs.python.org/3", None), + "numpy": ("http://docs.scipy.org/doc/numpy", None), +} # Redirects # https://tech.signavio.com/2017/managing-sphinx-redirects redirect_files = [ # old html, new html - ('joblib.html', 'https://ml.dask.org/joblib.html'), + ("joblib.html", "https://ml.dask.org/joblib.html") ] @@ -400,13 +406,13 @@ def copy_legacy_redirects(app, docname): - if app.builder.name == 'html': + if app.builder.name == "html": for html_src_path, new in redirect_files: page = redirect_template.format(new=new) - target_path = app.outdir + '/' + html_src_path - with open(target_path, 'w') as f: + target_path = app.outdir + "/" + html_src_path + with open(target_path, "w") as f: f.write(page) def setup(app): - app.connect('build-finished', copy_legacy_redirects) + app.connect("build-finished", copy_legacy_redirects) diff --git a/docs/source/submitting-applications.rst b/docs/source/submitting-applications.rst index 5f81f4fd658..8b5ab1d61c8 100644 --- a/docs/source/submitting-applications.rst +++ b/docs/source/submitting-applications.rst @@ -10,8 +10,8 @@ For example, S3 buckets could not be visible from your local machine and hence a attempt to create a dask graph from local machine may not work. 
-Submitting dask Applications with `dask-submit` ------------------------------------------------ +Submitting dask Applications with ``dask-submit`` +------------------------------------------------- In order to remotely submit scripts to the cluster from a local machine or a CI/CD environment, we need to run a remote client on the same machine as the scheduler:: @@ -20,7 +20,7 @@ environment, we need to run a remote client on the same machine as the scheduler dask-remote --port 8788 -After making sure the `dask-remote` is running, you can submit a script by:: +After making sure the ``dask-remote`` is running, you can submit a script by:: #local machine dask-submit : @@ -28,7 +28,7 @@ After making sure the `dask-remote` is running, you can submit a script by:: Some of the commonly used arguments are: -- ``REMOTE_CLIENT_ADDRESS``: host name where dask-remote client is running +- ``REMOTE_CLIENT_ADDRESS``: host name where ``dask-remote`` client is running - ``FILEPATH``: Local path to file containing dask application For example, given the following dask application saved in a file called @@ -36,6 +36,7 @@ For example, given the following dask application saved in a file called .. code-block:: python + # script.py from distributed import Client def inc(x): @@ -50,3 +51,21 @@ For example, given the following dask application saved in a file called We can submit this application from a local machine by running:: dask-submit : script.py + + +CLI Options +----------- + +.. note:: + + The command line documentation here may differ depending on your installed + version. We recommend referring to the output of ``dask-remote --help`` + and ``dask-submit --help``. + +.. click:: distributed.cli.dask_remote:main + :prog: dask-remote + :show-nested: + +.. click:: distributed.cli.dask_submit:main + :prog: dask-submit + :show-nested: \ No newline at end of file From 912c8a38919079b502e934b1a76ba9b201ec21ab Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 24 Jun 2019 08:11:20 -0500 Subject: [PATCH 0332/1550] Typo in bokeh service_kwargs for dask-worker (#2783) --- distributed/cli/dask_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 1463c29afd1..f341a1abf78 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -347,7 +347,7 @@ def del_pid_file(): host=host, port=port, dashboard_address=dashboard_address if dashboard else None, - service_kwargs={"bokhe": {"prefix": dashboard_prefix}}, + service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, name=name if nprocs == 1 or not name else name + "-" + str(i), **kwargs ) From 99444c24ec8d1c8248e273f5de5bd15b86f2ade2 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 25 Jun 2019 01:38:28 -0500 Subject: [PATCH 0333/1550] Deprecate --bokeh/--no-bokeh CLI (#2800) Closes https://github.com/dask/distributed/issues/2799 --- distributed/cli/dask_scheduler.py | 13 +++++++++++++ distributed/cli/dask_worker.py | 13 +++++++++++++ distributed/cli/tests/test_dask_worker.py | 19 +++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index c38f405f04e..f2799164a36 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -78,6 +78,13 @@ required=False, help="Launch the Dashboard", ) +@click.option( + "--bokeh/--no-bokeh", + "bokeh", + default=None, + required=False, + help="Deprecated. 
See --dashboard/--no-dashboard.", +) @click.option("--show/--no-show", default=False, help="Show web UI") @click.option( "--dashboard-prefix", type=str, default=None, help="Prefix for the dashboard app" @@ -120,6 +127,7 @@ def main( bokeh_port, show, dashboard, + bokeh, dashboard_prefix, use_xheaders, pid_file, @@ -146,6 +154,11 @@ def main( "Consider adding ``--dashboard-address :%d`` " % bokeh_port ) dashboard_address = bokeh_port + if bokeh is not None: + warnings.warn( + "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. " + ) + dashboard = bokeh if port is None and (not host or not re.search(r":\d", host)): port = 8786 diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index f341a1abf78..e86cfa41618 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -77,6 +77,13 @@ required=False, help="Launch the Dashboard", ) +@click.option( + "--bokeh/--no-bokeh", + "bokeh", + default=None, + help="Deprecated. See --dashboard/--no-dashboard.", + required=False, +) @click.option( "--listen-address", type=str, @@ -197,6 +204,7 @@ def main( reconnect, resources, dashboard, + bokeh, bokeh_port, local_directory, scheduler_file, @@ -223,6 +231,11 @@ def main( "Consider adding ``--dashboard-address :%d`` " % bokeh_port ) dashboard_address = bokeh_port + if bokeh is not None: + warnings.warn( + "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. " + ) + dashboard = bokeh sec = Security( tls_ca_file=tls_ca_file, tls_worker_cert=tls_cert, tls_worker_key=tls_key diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index edba84d2ef4..b6c7d393e3b 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -312,3 +312,22 @@ def test_worker_timeout(no_nanny): result = runner.invoke(distributed.cli.dask_worker.main, args) assert result.exit_code != 0 assert str(result.exception).startswith("Timed out") + + +def test_bokeh_deprecation(): + pytest.importorskip("bokeh") + + runner = CliRunner() + with pytest.warns(UserWarning, match="dashboard"): + try: + runner.invoke(distributed.cli.dask_worker.main, ["--bokeh"]) + except ValueError: + # didn't pass scheduler + pass + + with pytest.warns(UserWarning, match="dashboard"): + try: + runner.invoke(distributed.cli.dask_worker.main, ["--no-bokeh"]) + except ValueError: + # didn't pass scheduler + pass From 991391cb71492c2ecf10366bddd0dc8d526f212e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 25 Jun 2019 09:01:02 +0200 Subject: [PATCH 0334/1550] Relax warnings before release (#2796) Let's not be too strict about a couple of our warnings and missing functions. --- distributed/deploy/local.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index ffb06b0a4bf..554459e43ac 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -110,7 +110,8 @@ def __init__( **worker_kwargs ): if ip is not None: - warnings.warn("The ip keyword has been moved to host") + # In the future we should warn users about this move + # warnings.warn("The ip keyword has been moved to host") host = ip if diagnostics_port is not None: @@ -207,6 +208,12 @@ def __repr__(self): sum(w.nthreads for w in self.workers.values()), ) + def start_worker(self, *args, **kwargs): + raise NotImplementedError( + "The `cluster.start_worker` function has been removed. 
" + "Please see the `cluster.scale` method instead." + ) + def nprocesses_nthreads(n=multiprocessing.cpu_count()): """ From e13f2984be5c5c388818576f78fcf30412374298 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 24 Jun 2019 17:20:54 +0200 Subject: [PATCH 0335/1550] bump version to 2.0 --- distributed/tests/test_client.py | 20 +++++----- distributed/tests/test_worker.py | 13 ++++--- docs/source/changelog.rst | 63 +++++++++++++++++++++++++++++++- requirements.txt | 5 +-- setup.cfg | 3 -- setup.py | 1 - 6 files changed, 80 insertions(+), 25 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index d18216ef0ef..dca108b57ee 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -51,7 +51,7 @@ futures_of, temp_default_client, ) -from distributed.compatibility import PY3 +from distributed.compatibility import PY3, WINDOWS from distributed.metrics import time from distributed.scheduler import Scheduler, KilledWorker @@ -2732,9 +2732,7 @@ def test_persist_get(c, s, a, b): assert result == ((1 + 1) + (2 + 2)) + 10 -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="num_fds not supported on windows" -) +@pytest.mark.skipif(WINDOWS, reason="num_fds not supported on windows") def test_client_num_fds(loop): psutil = pytest.importorskip("psutil") with cluster() as (s, [a, b]): @@ -3084,9 +3082,7 @@ def test_client_replicate_sync(c): assert y.result() == 3 -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Windows timer too coarse-grained" -) +@pytest.mark.skipif(WINDOWS, reason="Windows timer too coarse-grained") @gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 1) def test_task_load_adapts_quickly(c, s, a): future = c.submit(slowinc, 1, delay=0.2) # slow @@ -3573,9 +3569,7 @@ def test_reconnect_timeout(c, s): @pytest.mark.slow -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="num_fds not supported on windows" -) +@pytest.mark.skipif(WINDOWS, reason="num_fds not supported on windows") @pytest.mark.skipif( sys.version_info[0] == 2, reason="Semaphore.acquire doesn't support timeout option" ) @@ -5522,7 +5516,11 @@ def test_profile_bokeh(c, s, a, b): assert isinstance(figure, Model) with tmpfile("html") as fn: - yield c.profile(filename=fn) + try: + yield c.profile(filename=fn) + except PermissionError: + if WINDOWS: + pytest.xfail() assert os.path.exists(fn) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index a0e8244e8bd..0f462d7de05 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1038,17 +1038,20 @@ def test_statistical_profiling(c, s, a, b): assert profile["count"] +@pytest.mark.slow @nodebug -@gen_cluster(client=True) +@gen_cluster(client=True, timeout=20) def test_statistical_profiling_2(c, s, a, b): da = pytest.importorskip("dask.array") - for i in range(5): + while True: x = da.random.random(1000000, chunks=(10000,)) y = (x + x * 2) - x.sum().persist() yield wait(y) - profile = a.get_profile() - assert profile["count"] - assert "sum" in str(profile) or "random" in str(profile) + + profile = a.get_profile() + text = str(profile) + if profile["count"] and "sum" in text and "random" in text: + break @gen_cluster( diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 613773a0c1c..37a44591ee5 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,12 +1,66 @@ Changelog ========= +2.0.0 - 2019-06-25 +------------------ + +- **Drop support for Python 2** +- Relax 
warnings before release (:pr:`2796`) `Matthew Rocklin`_ +- Deprecate --bokeh/--no-bokeh CLI (:pr:`2800`) `Tom Augspurger`_ +- Typo in bokeh service_kwargs for dask-worker (:pr:`2783`) `Tom Augspurger`_ +- Update command line cli options docs (:pr:`2794`) `James Bourbeau`_ +- Remove "experimental" from TLS docs (:pr:`2793`) `James Bourbeau`_ +- Add warnings around ncores= keywords (:pr:`2791`) `Matthew Rocklin`_ +- Add --version option to scheduler and worker CLI (:pr:`2782`) `Tom Augspurger`_ +- Raise when workers initialization times out (:pr:`2784`) `Tom Augspurger`_ +- Replace ncores with nthreads throughout codebase (:pr:`2758`) `Matthew Rocklin`_ +- Add unknown pytest markers (:pr:`2764`) `Tom Augspurger`_ +- Delay lookup of allowed failures. (:pr:`2761`) `Tom Augspurger`_ +- Change address -> worker in ColumnDataSource for nbytes plot (:pr:`2755`) `Matthew Rocklin`_ +- Remove module state in Prometheus Handlers (:pr:`2760`) `Matthew Rocklin`_ +- Add stress test for UCX (:pr:`2759`) `Matthew Rocklin`_ +- Add nanny logs (:pr:`2744`) `Tom Augspurger`_ +- Move some of the adaptive logic into the scheduler (:pr:`2735`) `Matthew Rocklin`_ +- Add SpecCluster.new_worker_spec method (:pr:`2751`) `Matthew Rocklin`_ +- Worker dashboard fixes (:pr:`2747`) `Matthew Rocklin`_ +- Add async context managers to scheduler/worker classes (:pr:`2745`) `Matthew Rocklin`_ +- Fix the resource key representation before sending graphs (:pr:`2733`) `Michael Spiegel`_ +- Allow user to configure whether workers are daemon. (:pr:`2739`) `Caleb`_ +- Pin pytest >=4 with pip in appveyor and python 3.5 (:pr:`2737`) `Matthew Rocklin`_ +- Add Experimental UCX Comm (:pr:`2591`) `Ben Zaitlen`_ `Tom Augspurger`_ `Matthew Rocklin`_ +- Close nannies gracefully (:pr:`2731`) `Matthew Rocklin`_ +- add kwargs to progressbars (:pr:`2638`) `Manuel Garrido`_ +- Add back LocalCluster.__repr__. 
(:pr:`2732`) `Loïc Estève`_ +- Move bokeh module to dashboard (:pr:`2724`) `Matthew Rocklin`_ +- Close clusters at exit (:pr:`2730`) `Matthew Rocklin`_ +- Add SchedulerPlugin TaskState example (:pr:`2622`) `Matt Nicolls`_ +- Add SpecificationCluster (:pr:`2675`) `Matthew Rocklin`_ +- Replace register_worker_callbacks with worker plugins (:pr:`2453`) `Matthew Rocklin`_ +- Proxy worker dashboards from scheduler dashboard (:pr:`2715`) `Ben Zaitlen`_ +- Add docstring to Scheduler.check_idle_saturated (:pr:`2721`) `Matthew Rocklin`_ +- Refer to LocalCluster in Client docstring (:pr:`2719`) `Matthew Rocklin`_ +- Remove special casing of Scikit-Learn BaseEstimator serialization (:pr:`2713`) `Matthew Rocklin`_ +- Fix two typos in Pub class docstring (:pr:`2714`) `Magnus Nord`_ +- Support uploading files with multiple modules (:pr:`2587`) `Sam Grayson`_ +- Change the main workers bokeh page to /status (:pr:`2689`) `Ben Zaitlen`_ +- Cleanly stop periodic callbacks in Client (:pr:`2705`) `Matthew Rocklin`_ +- Disable pan tool for the Progress, Byte Stored and Tasks Processing plot (:pr:`2703`) `Mathieu Dugré`_ +- Except errors in Nanny's memory monitor if process no longer exists (:pr:`2701`) `Matthew Rocklin`_ +- Handle heartbeat when worker has just left (:pr:`2702`) `Matthew Rocklin`_ +- Modify styling of histograms for many-worker dashboard plots (:pr:`2695`) `Mathieu Dugré`_ +- Add method to wait for n workers before continuing (:pr:`2688`) `Daniel Farrell`_ +- Support computation on delayed(None) (:pr:`2697`) `Matthew Rocklin`_ +- Cleanup localcluster (:pr:`2693`) `Matthew Rocklin`_ +- Use 'temporary-directory' from dask.config for Worker's directory (:pr:`2654`) `Matthew Rocklin`_ +- Remove support for Iterators and Queues (:pr:`2671`) `Matthew Rocklin`_ + + 1.28.1 - 2019-05-13 ------------------- This is a small bugfix release due to a config change upstream. -- Use config accessor method for "scheduler-address" (#2676) `James Bourbeau`_ +- Use config accessor method for "scheduler-address" (:pr:`2676`) `James Bourbeau`_ 1.28.0 - 2019-05-08 @@ -1039,3 +1093,10 @@ significantly without many new features. .. _`condoratberlin`: https://github.com/condoratberlin .. _`K.-Michael Aye`: https://github.com/michaelaye .. _`@plbertrand`: https://github.com/plbertrand +.. _`Michael Spiegel`: https://github.com/Spiegel0 +.. _`Caleb`: https://github.com/calebho +.. _`Ben Zaitlen`: https://github.com/quasiben +.. _`Manuel Garrido`: https://github.com/manugarri +.. _`Magnus Nord`: https://github.com/magnunor +.. _`Sam Grayson`: https://github.com/charmoniumQ +.. 
_`Mathieu Dugré`: https://github.com/mathdugre diff --git a/requirements.txt b/requirements.txt index a6c6b0f62f6..e376b2a50cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ click >= 6.6 cloudpickle >= 0.2.2 -dask >= 0.18.0 +dask >= 2 msgpack psutil >= 5.0 six @@ -9,7 +9,4 @@ tblib toolz >= 0.7.4 tornado >= 5 zict >= 0.1.3 -# Compatibility packages -futures; python_version < '3.0' -singledispatch; python_version < '3.4' pyyaml diff --git a/setup.cfg b/setup.cfg index 5533437121b..042a8b86f35 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,9 +35,6 @@ versionfile_build = distributed/_version.py tag_prefix = parentdir_prefix = distributed- -[bdist_wheel] -universal=1 - [tool:pytest] addopts = -rsx -v --durations=10 minversion = 3.2 diff --git a/setup.py b/setup.py index 0df22f3f911..6c4bce91d83 100755 --- a/setup.py +++ b/setup.py @@ -53,7 +53,6 @@ "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", - "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", From ded3f30ac2d161cede56ce58253e50e02b75b188 Mon Sep 17 00:00:00 2001 From: Brett Naul Date: Tue, 25 Jun 2019 14:55:20 -0700 Subject: [PATCH 0336/1550] Fix diagnostics page for memory_limit=None (#2770) * Fix diagnostics page for memory_limit=None * Apply black --- distributed/dashboard/templates/worker-table.html | 2 +- distributed/worker.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/distributed/dashboard/templates/worker-table.html b/distributed/dashboard/templates/worker-table.html index a3566f90c3f..c12061fab46 100644 --- a/distributed/dashboard/templates/worker-table.html +++ b/distributed/dashboard/templates/worker-table.html @@ -16,7 +16,7 @@ - + diff --git a/distributed/worker.py b/distributed/worker.py index 63dfaed3114..bd4907372d5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2456,7 +2456,9 @@ def memory_monitor(self): "Process memory: %s -- Worker memory limit: %s", int(frac * 100), format_bytes(proc.memory_info().rss), - format_bytes(self.memory_limit), + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None", ) self.paused = True elif self.paused: @@ -2465,7 +2467,9 @@ def memory_monitor(self): "Process memory: %s -- Worker memory limit: %s", int(frac * 100), format_bytes(proc.memory_info().rss), - format_bytes(self.memory_limit), + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None", ) self.paused = False self.ensure_computing() @@ -2483,7 +2487,9 @@ def memory_monitor(self): "is leaking memory? 
Process memory: %s -- " "Worker memory limit: %s", format_bytes(proc.memory_info().rss), - format_bytes(self.memory_limit), + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None", ) break k, v, weight = self.data.fast.evict() From 437c573627a89a11e0cc2e0fdd99209f586161f2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 26 Jun 2019 10:21:53 +0200 Subject: [PATCH 0337/1550] Correctly manage tasks beyond deque limit in TaskStream plot (#2797) Fixes #2501 --- distributed/dashboard/scheduler.py | 9 ++++-- .../dashboard/tests/test_scheduler_bokeh.py | 28 +++++++++++++++++++ distributed/diagnostics/task_stream.py | 16 +++++++++-- distributed/distributed.yaml | 5 ++++ 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index f6f1fef7590..bed13950fc0 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -35,6 +35,7 @@ from bokeh.themes import Theme from bokeh.transform import factor_cmap from bokeh.io import curdoc +import dask from toolz import pipe, merge from tornado import escape @@ -1417,7 +1418,9 @@ def tasks_doc(scheduler, extra, doc): with log_errors(): ts = TaskStream( scheduler, - n_rectangles=100000, + n_rectangles=dask.config.get( + "distributed.scheduler.dashboard.tasks.task-stream-length" + ), clear_interval="60s", sizing_mode="stretch_both", ) @@ -1447,7 +1450,9 @@ def status_doc(scheduler, extra, doc): with log_errors(): task_stream = TaskStream( scheduler, - n_rectangles=1000, + n_rectangles=dask.config.get( + "distributed.scheduler.dashboard.status.task-stream-length" + ), clear_interval="10s", sizing_mode="stretch_both", ) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index d9a83caf00b..4d60f304876 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -627,3 +627,31 @@ def test_proxy_to_workers(c, s, a, b): assert b"pip install jupyter-server-proxy" in response_proxy.body assert response_direct.code == 200 assert b"Crossfilter" in response_direct.body + + +@gen_cluster( + client=True, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, + config={ + "distributed.scheduler.dashboard.tasks.task-stream-length": 10, + "distributed.scheduler.dashboard.status.task-stream-length": 10, + }, +) +async def test_lots_of_tasks(c, s, a, b): + import toolz + + ts = TaskStream(s) + ts.update() + futures = c.map(toolz.identity, range(100)) + await wait(futures) + + tsp = [p for p in s.plugins if "taskstream" in type(p).__name__.lower()][0] + assert len(tsp.buffer) == 10 + ts.update() + assert len(ts.source.data["start"]) == 10 + assert "identity" in str(ts.source.data) + + futures = c.map(lambda x: x, range(100), pure=False) + await wait(futures) + ts.update() + assert "lambda" in str(ts.source.data) diff --git a/distributed/diagnostics/task_stream.py b/distributed/diagnostics/task_stream.py index 89cacb67c97..17e62c3045e 100644 --- a/distributed/diagnostics/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -3,6 +3,7 @@ from collections import deque import logging +import dask from .progress_stream import color_of from .plugin import SchedulerPlugin from ..utils import key_split, format_time, parse_timedelta @@ -13,7 +14,16 @@ class TaskStreamPlugin(SchedulerPlugin): - def __init__(self, scheduler, maxlen=100000): + def __init__(self, scheduler, maxlen=None): + if 
maxlen is None: + maxlen = max( + dask.config.get( + "distributed.scheduler.dashboard.status.task-stream-length" + ), + dask.config.get( + "distributed.scheduler.dashboard.tasks.task-stream-length" + ), + ) self.buffer = deque(maxlen=maxlen) self.scheduler = scheduler scheduler.add_plugin(self) @@ -74,8 +84,8 @@ def rectangles(self, istart, istop=None, workers=None, start_boundary=0): msgs = [] diff = self.index - len(self.buffer) if istop is None: - istop = len(self.buffer) - for i in range((istart or 0) - diff, istop - diff if istop else istop): + istop = self.index + for i in range(max(0, (istart or 0) - diff), istop - diff if istop else istop): msg = self.buffer[i] msgs.append(msg) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 4d78a698e69..235c735946c 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -22,6 +22,11 @@ distributed: worker-ttl: null # like '60s'. Time to live for workers. They must heartbeat faster than this preload: [] preload-argv: [] + dashboard: + status: + task-stream-length: 1000 + tasks: + task-stream-length: 100000 worker: blocked-handlers: [] From 8990c98e593d90107bccd212a983aa4d3e5707cb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 26 Jun 2019 15:23:14 +0200 Subject: [PATCH 0338/1550] Add python_requires entry to setup.py (#2807) Alternative to #2806 Fixes #2804 --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6c4bce91d83..125ddc9c328 100755 --- a/setup.py +++ b/setup.py @@ -23,9 +23,10 @@ version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), description="Distributed scheduler for Dask", - url="https://distributed.readthedocs.io/en/latest/", + url="https://distributed.dask.org", maintainer="Matthew Rocklin", maintainer_email="mrocklin@gmail.com", + python_requires=">=3.5", license="BSD", package_data={ "": ["templates/index.html", "template.html"], From da6a01bdee1c6d90934c61ae056b14610cd56a6c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 26 Jun 2019 15:25:43 +0200 Subject: [PATCH 0339/1550] bump version to 2.0.1 --- docs/source/changelog.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 37a44591ee5..6162935b70d 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,18 @@ Changelog ========= +2.0.1 - 2019-06-26 +------------------ + +We neglected to include ``python_requires=`` in our setup.py file, resulting in +confusion for Python 2 users who erroneously get packages for 2.0.0. +This is fixed in 2.0.1 and we have removed the 2.0.0 files from PyPI. 
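One usage note for the dashboard settings added in the task-stream patch above: the two
``task-stream-length`` values are ordinary dask configuration entries, so they can be
overridden before a scheduler starts. A small sketch, with illustrative numbers only::

    import dask

    # Override the defaults added to distributed.yaml in this patch:
    # a shorter history for the /status page, a longer one for /tasks.
    with dask.config.set({
        "distributed.scheduler.dashboard.status.task-stream-length": 500,
        "distributed.scheduler.dashboard.tasks.task-stream-length": 50000,
    }):
        ...  # start the Scheduler / LocalCluster inside this block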
+ +- Add python_requires entry to setup.py (:pr:`2807`) `Matthew Rocklin`_ +- Correctly manage tasks beyond deque limit in TaskStream plot (:pr:`2797`) `Matthew Rocklin`_ +- Fix diagnostics page for memory_limit=None (:pr:`2770`) `Brett Naul`_ + + 2.0.0 - 2019-06-25 ------------------ From f8af742fba80451b0db281eebe515951c53de9d4 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 27 Jun 2019 10:50:11 -0500 Subject: [PATCH 0340/1550] CLN: Use dask.utils.format_bytes (#2810) --- distributed/client.py | 3 +-- distributed/dashboard/scheduler.py | 3 ++- distributed/dashboard/scheduler_html.py | 3 ++- distributed/dashboard/worker.py | 3 ++- distributed/deploy/cluster.py | 10 ++------ distributed/tests/test_utils.py | 5 ++++ distributed/tests/test_worker.py | 3 ++- distributed/utils.py | 32 +++---------------------- distributed/utils_perf.py | 3 ++- distributed/worker.py | 2 +- 10 files changed, 22 insertions(+), 45 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 7ad897bf616..a5e2b7f6103 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -28,7 +28,7 @@ from dask.core import flatten, get_dependencies from dask.optimization import SubgraphCallable from dask.compatibility import apply, unicode -from dask.utils import ensure_dict +from dask.utils import ensure_dict, format_bytes try: from cytoolz import first, groupby, merge, valmap, keymap @@ -82,7 +82,6 @@ log_errors, str_graph, key_split, - format_bytes, asciitable, thread_state, no_default, diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index bed13950fc0..7cce430346a 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -36,6 +36,7 @@ from bokeh.transform import factor_cmap from bokeh.io import curdoc import dask +from dask.utils import format_bytes from toolz import pipe, merge from tornado import escape @@ -55,7 +56,7 @@ from .worker import SystemMonitor, counters_doc from .utils import transpose, BOKEH_VERSION, without_property_validation from ..metrics import time -from ..utils import log_errors, format_bytes, format_time +from ..utils import log_errors, format_time from ..diagnostics.progress_stream import color_of, progress_quads, nbytes_bar from ..diagnostics.progress import AllProgress from ..diagnostics.graph_layout import GraphLayout diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 65a89b33fbb..08829241d47 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -1,10 +1,11 @@ from datetime import datetime +from dask.utils import format_bytes import toolz from tornado import escape from tornado import gen -from ..utils import log_errors, format_bytes, format_time +from ..utils import log_errors, format_time from .proxy import GlobalProxyHandler from .utils import RequestHandler, redirect diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index c6633a170aa..d8f8adc1c7d 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -22,6 +22,7 @@ from bokeh.plotting import figure from bokeh.palettes import RdBu from bokeh.themes import Theme +from dask.utils import format_bytes from toolz import merge, partition_all from .components import ( @@ -35,7 +36,7 @@ from ..compatibility import WINDOWS from ..diagnostics.progress_stream import color_of from ..metrics import time -from ..utils import log_errors, key_split, format_bytes, format_time +from ..utils import 
log_errors, key_split, format_time logger = logging.getLogger(__name__) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 74d61a995e1..866910784e4 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -4,19 +4,13 @@ from weakref import ref import dask +from dask.utils import format_bytes from tornado import gen from .adaptive import Adaptive from ..compatibility import get_thread_identity -from ..utils import ( - format_bytes, - PeriodicCallback, - log_errors, - ignoring, - sync, - thread_state, -) +from ..utils import PeriodicCallback, log_errors, ignoring, sync, thread_state logger = logging.getLogger(__name__) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index b82dce4e7d9..df98bbe59e1 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -543,3 +543,8 @@ def test_warn_on_duration(): assert record assert any("foo" in str(rec.message) for rec in record) + + +def test_format_bytes_compat(): + # moved to dask, but exported here for compatibility + from distributed.utils import format_bytes # noqa diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0f462d7de05..562a0e037b7 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -15,6 +15,7 @@ import dask from dask import delayed +from dask.utils import format_bytes import pytest from toolz import pluck, sliding_window, first import tornado @@ -28,7 +29,7 @@ from distributed.scheduler import Scheduler from distributed.metrics import time from distributed.worker import Worker, error_message, logger, parse_memory_limit -from distributed.utils import tmpfile, format_bytes +from distributed.utils import tmpfile from distributed.utils_test import ( inc, mul, diff --git a/distributed/utils.py b/distributed/utils.py index e8de0bc5108..46982fd6324 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -36,6 +36,9 @@ import dask from dask import istask + +# provide format_bytes here for backwards compatibility +from dask.utils import format_bytes # noqa import toolz import tornado from tornado import gen @@ -1112,35 +1115,6 @@ def __reduce__(self): return (itemgetter, (self.index,)) -def format_bytes(n): - """ Format bytes as text - - >>> format_bytes(1) - '1 B' - >>> format_bytes(1234) - '1.23 kB' - >>> format_bytes(12345678) - '12.35 MB' - >>> format_bytes(1234567890) - '1.23 GB' - >>> format_bytes(1234567890000) - '1.23 TB' - >>> format_bytes(1234567890000000) - '1.23 PB' - """ - if n > 1e15: - return "%0.2f PB" % (n / 1e15) - if n > 1e12: - return "%0.2f TB" % (n / 1e12) - if n > 1e9: - return "%0.2f GB" % (n / 1e9) - if n > 1e6: - return "%0.2f MB" % (n / 1e6) - if n > 1e3: - return "%0.2f kB" % (n / 1000) - return "%d B" % n - - byte_sizes = { "kB": 10 ** 3, "MB": 10 ** 6, diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index 9f300c5f567..b1f65256c1e 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -5,9 +5,10 @@ import logging import threading +from dask.utils import format_bytes + from .compatibility import PY2, PYPY from .metrics import thread_time -from .utils import format_bytes logger = _logger = logging.getLogger(__name__) diff --git a/distributed/worker.py b/distributed/worker.py index bd4907372d5..582564e3f6d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -19,6 +19,7 @@ import dask from dask.core import istask from dask.compatibility import apply +from dask.utils import 
format_bytes try: from cytoolz import pluck, partial, merge, first @@ -59,7 +60,6 @@ thread_state, json_load_robust, key_split, - format_bytes, PeriodicCallback, parse_bytes, parse_timedelta, From bf65f7afccc952fa0d16d2974323e34a438934e1 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Sat, 29 Jun 2019 10:06:27 -0500 Subject: [PATCH 0341/1550] Add HTTPS support for the dashboard (#2812) Adds optional HTTPS support for the scheduler dashboard. This is only available via the configuration file, by setting the following fields: - `distributed.scheduler.tls.cert`: the certificate file - `distributed.scheduler.tls.key`: the key file, optional if the key file is concatenated with the cert above - `distributed.scheduler.tls.ca-file`: the CA file, optional These certs *may* be the same as those used for the scheduler/worker/client communication, but aren't required to be. The user is responsible for making this decision and providing the proper configuration. Likewise, the user is responsible for providing trusted certificates, or understanding the security implications of telling their browser "I understand the risks, trust this certificate" (this is more likely, given the transient nature of dask clusters). The generated dashboard links now format on an optional `scheme` parameter, which is either `http` or `https`, depending on if the TLS configuration fields above are configured. --- distributed/client.py | 4 +- distributed/dashboard/scheduler.py | 19 +++++++ .../dashboard/tests/test_scheduler_bokeh.py | 52 +++++++++++++++++-- distributed/deploy/cluster.py | 14 +++-- distributed/distributed.yaml | 7 ++- distributed/tests/test_client.py | 9 ++-- distributed/utils.py | 9 ++++ 7 files changed, 99 insertions(+), 15 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index a5e2b7f6103..ec564694c49 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -91,6 +91,7 @@ shutting_down, Any, has_keyword, + format_dashboard_link, ) from .versions import get_versions @@ -818,8 +819,7 @@ def _repr_html_(self): host = "localhost" else: host = rest.split(":")[0] - template = dask.config.get("distributed.dashboard.link") - address = template.format(host=host, port=port, **os.environ) + address = format_dashboard_link(host, port) text += ( "
      • Dashboard: %(web)s\n" % {"web": address} diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 7cce430346a..013edb39ace 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1597,6 +1597,25 @@ def __init__(self, scheduler, io_loop=None, prefix="", **kwargs): self.server_kwargs = kwargs + # TLS configuration + http_server_kwargs = kwargs.setdefault("http_server_kwargs", {}) + tls_key = dask.config.get("distributed.scheduler.dashboard.tls.key") + tls_cert = dask.config.get("distributed.scheduler.dashboard.tls.cert") + tls_ca_file = dask.config.get("distributed.scheduler.dashboard.tls.ca-file") + if tls_cert and "ssl_options" not in http_server_kwargs: + import ssl + + ctx = ssl.create_default_context( + cafile=tls_ca_file, purpose=ssl.Purpose.SERVER_AUTH + ) + ctx.load_cert_chain(tls_cert, keyfile=tls_key) + # Unlike the client/scheduler/worker TLS handling, we don't care + # about authenticating the user's webclient, TLS here is just for + # encryption. Disable these checks. + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + http_server_kwargs["ssl_options"] = ctx + self.server_kwargs["prefix"] = prefix or None self.apps = { diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 4d60f304876..3c7f85dc89a 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -2,6 +2,7 @@ import json import re +import ssl import sys from time import sleep @@ -10,13 +11,13 @@ pytest.importorskip("bokeh") from toolz import first from tornado import gen -from tornado.httpclient import AsyncHTTPClient +from tornado.httpclient import AsyncHTTPClient, HTTPRequest from dask.core import flatten -from distributed.utils import tokey +from distributed.utils import tokey, format_dashboard_link from distributed.client import wait from distributed.metrics import time -from distributed.utils_test import gen_cluster, inc, dec, slowinc, div +from distributed.utils_test import gen_cluster, inc, dec, slowinc, div, get_cert from distributed.dashboard.worker import Counters, BokehWorker from distributed.dashboard.scheduler import ( BokehScheduler, @@ -655,3 +656,48 @@ async def test_lots_of_tasks(c, s, a, b): await wait(futures) ts.update() assert "lambda" in str(ts.source.data) + + +@gen_cluster( + client=True, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, + config={ + "distributed.scheduler.dashboard.tls.key": get_cert("tls-key.pem"), + "distributed.scheduler.dashboard.tls.cert": get_cert("tls-cert.pem"), + "distributed.scheduler.dashboard.tls.ca-file": get_cert("tls-ca-cert.pem"), + }, +) +def test_https_support(c, s, a, b): + assert isinstance(s.services["dashboard"], BokehScheduler) + port = s.services["dashboard"].port + + assert ( + format_dashboard_link("localhost", port) == "https://localhost:%d/status" % port + ) + + ctx = ssl.create_default_context() + ctx.load_verify_locations(get_cert("tls-ca-cert.pem")) + + http_client = AsyncHTTPClient() + for suffix in [ + "system", + "counters", + "workers", + "status", + "tasks", + "stealing", + "graph", + "individual-task-stream", + "individual-progress", + "individual-graph", + "individual-nbytes", + "individual-nprocessing", + "individual-profile", + ]: + req = HTTPRequest( + url="https://localhost:%d/%s" % (port, suffix), ssl_options=ctx + ) + response = yield http_client.fetch(req) + body = response.body.decode() + 
assert "bokeh" in body.lower() + assert not re.search("href=./", body) # no absolute links diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 866910784e4..9819c1ad017 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -1,16 +1,21 @@ from datetime import timedelta import logging -import os from weakref import ref -import dask from dask.utils import format_bytes from tornado import gen from .adaptive import Adaptive from ..compatibility import get_thread_identity -from ..utils import PeriodicCallback, log_errors, ignoring, sync, thread_state +from ..utils import ( + PeriodicCallback, + log_errors, + ignoring, + sync, + thread_state, + format_dashboard_link, +) logger = logging.getLogger(__name__) @@ -80,10 +85,9 @@ def scheduler_address(self): @property def dashboard_link(self): - template = dask.config.get("distributed.dashboard.link") host = self.scheduler.address.split("://")[1].split(":")[0] port = self.scheduler.services["dashboard"].port - return template.format(host=host, port=port, **os.environ) + return format_dashboard_link(host, port) def scale(self, n): """ Scale cluster to n workers diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 235c735946c..e5bd3dd3140 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -27,6 +27,11 @@ distributed: task-stream-length: 1000 tasks: task-stream-length: 100000 + tls: + ca-file: null + key: null + cert: null + worker: blocked-handlers: [] @@ -88,7 +93,7 @@ distributed: ################### dashboard: - link: "http://{host}:{port}/status" + link: "{scheme}://{host}:{port}/status" export-tool: False ################## diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index dca108b57ee..37d5550941f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5224,20 +5224,21 @@ def test_quiet_scheduler_loss(c, s): assert "BrokenPipeError" not in text -@pytest.mark.skipif("USER" not in os.environ, reason="no USER env variable") -def test_diagnostics_link_env_variable(loop): +def test_dashboard_link(loop, monkeypatch): pytest.importorskip("bokeh") from distributed.dashboard import BokehScheduler + monkeypatch.setenv("USER", "myusername") + with cluster( scheduler_kwargs={"services": {("dashboard", 12355): BokehScheduler}} ) as (s, [a, b]): with Client(s["address"], loop=loop) as c: with dask.config.set( - {"distributed.dashboard.link": "http://foo-{USER}:{port}/status"} + {"distributed.dashboard.link": "{scheme}://foo-{USER}:{port}/status"} ): text = c._repr_html_() - link = "http://foo-" + os.environ["USER"] + ":12355/status" + link = "http://foo-myusername:12355/status" assert link in text diff --git a/distributed/utils.py b/distributed/utils.py index 46982fd6324..afe43040bd9 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1483,3 +1483,12 @@ def typename(typ): return typ.__module__ + "." 
+ typ.__name__ except AttributeError: return str(typ) + + +def format_dashboard_link(host, port): + template = dask.config.get("distributed.dashboard.link") + if dask.config.get("distributed.scheduler.dashboard.tls.cert"): + scheme = "https" + else: + scheme = "http" + return template.format(scheme=scheme, host=host, port=port, **os.environ) From fdc94d113c7fd2e29b530d9c277e76f2bf06a8d9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 30 Jun 2019 10:56:22 +0100 Subject: [PATCH 0342/1550] Relax check for worker references in cluster context manager (#2813) --- distributed/client.py | 5 +++-- distributed/utils_test.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ec564694c49..3c15e68fe99 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3278,9 +3278,10 @@ def _profile( if plot == "save" and not filename: filename = "dask-profile.html" - from bokeh.plotting import save + if filename: + from bokeh.plotting import save - save(figure, title="Dask Profile", filename=filename) + save(figure, title="Dask Profile", filename=filename) raise gen.Return((state, figure)) else: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 293cf5c0737..77568bb7595 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -732,9 +732,10 @@ def cluster( client.close() start = time() - while len(ws): - sleep(0.1) - assert time() < start + 3, ("Workers still around after two seconds", list(ws)) + while any(proc.is_alive() for proc in ws): + text = str(list(ws)) + sleep(0.2) + assert time() < start + 5, ("Workers still around after five seconds", text) @gen.coroutine From 9aa0ea60c6ba0fa8dd65b2a8005d3dbf3d65db75 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 30 Jun 2019 21:00:24 +0100 Subject: [PATCH 0343/1550] Use Keyword-only arguments (#2814) Previously for functions with both `*args` and `**kwargs` inputs we often explicitly popped off values explicitly. Now that we no longer support Python 2 we can use these as keyword arguments directly. * Fix pytest 5.0 issues --- distributed/client.py | 90 +++++++++++++++----------- distributed/deploy/cluster.py | 6 +- distributed/diagnostics/progressbar.py | 6 +- distributed/recreate_exceptions.py | 4 +- distributed/tests/test_client.py | 46 +++++-------- distributed/utils.py | 17 ++--- distributed/worker.py | 10 +-- 7 files changed, 85 insertions(+), 94 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 3c15e68fe99..8509d9e31a9 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -748,20 +748,20 @@ def asynchronous(self): """ return self._asynchronous and self.loop is IOLoop.current() - def sync(self, func, *args, **kwargs): - asynchronous = kwargs.pop("asynchronous", None) + def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): if ( asynchronous or self.asynchronous or getattr(thread_state, "asynchronous", False) ): - callback_timeout = kwargs.pop("callback_timeout", None) future = func(*args, **kwargs) if callback_timeout is not None: future = gen.with_timeout(timedelta(seconds=callback_timeout), future) return future else: - return sync(self.loop, func, *args, **kwargs) + return sync( + self.loop, func, *args, callback_timeout=callback_timeout, **kwargs + ) def __repr__(self): # Note: avoid doing I/O here... 
@@ -1351,7 +1351,23 @@ def get_executor(self, **kwargs): """ return ClientExecutor(self, **kwargs) - def submit(self, func, *args, **kwargs): + def submit( + self, + func, + *args, + key=None, + workers=None, + resources=None, + retries=None, + priority=0, + fifo_timeout="100 ms", + allow_other_workers=False, + actor=False, + actors=False, + pure=None, + **kwargs + ): + """ Submit a function application to the scheduler Parameters @@ -1393,15 +1409,9 @@ def submit(self, func, *args, **kwargs): if not callable(func): raise TypeError("First input to submit must be a callable function") - key = kwargs.pop("key", None) - workers = kwargs.pop("workers", None) - resources = kwargs.pop("resources", None) - retries = kwargs.pop("retries", None) - priority = kwargs.pop("priority", 0) - fifo_timeout = kwargs.pop("fifo_timeout", "100ms") - allow_other_workers = kwargs.pop("allow_other_workers", False) - actor = kwargs.pop("actor", kwargs.pop("actors", False)) - pure = kwargs.pop("pure", not actor) + actor = actor or actors + if pure is None: + pure = not actor if allow_other_workers not in (True, False, None): raise TypeError("allow_other_workers= must be True or False") @@ -1452,7 +1462,22 @@ def submit(self, func, *args, **kwargs): return futures[skey] - def map(self, func, *iterables, **kwargs): + def map( + self, + func, + *iterables, + key=None, + workers=None, + retries=None, + resources=None, + priority=0, + allow_other_workers=False, + fifo_timeout="100 ms", + actor=False, + actors=False, + pure=None, + **kwargs + ): """ Map a function on a sequence of arguments Arguments can be normal objects or Futures @@ -1494,6 +1519,11 @@ def map(self, func, *iterables, **kwargs): -------- Client.submit: Submit a single function """ + key = key or funcname(func) + actor = actor or actors + if pure is None: + pure = not actor + if not callable(func): raise TypeError("First input to map must be a callable function") @@ -1505,17 +1535,6 @@ def map(self, func, *iterables, **kwargs): "Consider using a normal for loop and Client.submit" ) - key = kwargs.pop("key", None) - key = key or funcname(func) - workers = kwargs.pop("workers", None) - retries = kwargs.pop("retries", None) - resources = kwargs.pop("resources", None) - user_priority = kwargs.pop("priority", 0) - allow_other_workers = kwargs.pop("allow_other_workers", False) - fifo_timeout = kwargs.pop("fifo_timeout", "100ms") - actor = kwargs.pop("actor", kwargs.pop("actors", False)) - pure = kwargs.pop("pure", not actor) - if allow_other_workers and workers is None: raise ValueError("Only use allow_other_workers= if using workers=") @@ -1581,7 +1600,7 @@ def map(self, func, *iterables, **kwargs): else: loose_restrictions = set() - priority = dict(zip(keys, range(len(keys)))) + internal_priority = dict(zip(keys, range(len(keys)))) if resources: resources = {k: resources for k in keys} @@ -1593,10 +1612,10 @@ def map(self, func, *iterables, **kwargs): keys, restrictions, loose_restrictions, - priority=priority, + priority=internal_priority, resources=resources, retries=retries, - user_priority=user_priority, + user_priority=priority, fifo_timeout=fifo_timeout, actors=actor, ) @@ -2051,7 +2070,7 @@ def retry(self, futures, asynchronous=None): return self.sync(self._retry, futures, asynchronous=asynchronous) @gen.coroutine - def _publish_dataset(self, *args, **kwargs): + def _publish_dataset(self, *args, name=None, **kwargs): with log_errors(): coroutines = [] @@ -2063,7 +2082,6 @@ def add_coro(name, data): ) ) - name = kwargs.pop("name", None) if name: if 
len(args) == 0: raise ValueError( @@ -2179,8 +2197,7 @@ def get_dataset(self, name, **kwargs): return self.sync(self._get_dataset, name, **kwargs) @gen.coroutine - def _run_on_scheduler(self, function, *args, **kwargs): - wait = kwargs.pop("wait", True) + def _run_on_scheduler(self, function, *args, wait=True, **kwargs): response = yield self.scheduler.run_function( function=dumps(function), args=dumps(args), kwargs=dumps(kwargs), wait=wait ) @@ -2222,10 +2239,7 @@ def run_on_scheduler(self, function, *args, **kwargs): return self.sync(self._run_on_scheduler, function, *args, **kwargs) @gen.coroutine - def _run(self, function, *args, **kwargs): - nanny = kwargs.pop("nanny", False) - workers = kwargs.pop("workers", None) - wait = kwargs.pop("wait", True) + def _run(self, function, *args, nanny=False, workers=None, wait=True, **kwargs): responses = yield self.scheduler.broadcast( msg=dict( op="run", @@ -2582,6 +2596,7 @@ def compute( priority=0, fifo_timeout="60s", actors=None, + traverse=True, **kwargs ): """ Compute dask collections on cluster @@ -2645,7 +2660,6 @@ def compute( collections = [collections] singleton = True - traverse = kwargs.pop("traverse", True) if traverse: collections = tuple( dask.delayed(a) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 9819c1ad017..d48f27603ff 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -234,9 +234,9 @@ def asynchronous(self): and self.loop._thread_identity == get_thread_identity() ) - def sync(self, func, *args, **kwargs): - if kwargs.pop("asynchronous", None) or self.asynchronous: - callback_timeout = kwargs.pop("callback_timeout", None) + def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): + asynchronous = asynchronous or self.asynchronous + if asynchronous: future = func(*args, **kwargs) if callback_timeout is not None: future = gen.with_timeout(timedelta(seconds=callback_timeout), future) diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 8a381562f27..f25bf32a871 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -405,7 +405,7 @@ def _draw_bar(self, remaining, all, status, **kwargs): ) -def progress(*futures, **kwargs): +def progress(*futures, notebook=None, multi=True, complete=True, **kwargs): """ Track progress of futures This operates differently in the notebook and the console @@ -436,10 +436,6 @@ def progress(*futures, **kwargs): >>> progress(futures) # doctest: +SKIP [########################################] | 100% Completed | 1.7s """ - notebook = kwargs.pop("notebook", None) - multi = kwargs.pop("multi", True) - complete = kwargs.pop("complete", True) - futures = futures_of(futures) if not isinstance(futures, (set, list)): futures = [futures] diff --git a/distributed/recreate_exceptions.py b/distributed/recreate_exceptions.py index d5351bb4d59..78b0f4de9ba 100644 --- a/distributed/recreate_exceptions.py +++ b/distributed/recreate_exceptions.py @@ -23,7 +23,7 @@ def __init__(self, scheduler): self.scheduler.handlers["cause_of_failure"] = self.cause_of_failure self.scheduler.extensions["exceptions"] = self - def cause_of_failure(self, *args, **kwargs): + def cause_of_failure(self, *args, keys=(), **kwargs): """ Return details of first failed task required by set of keys @@ -38,8 +38,6 @@ def cause_of_failure(self, *args, **kwargs): task: the definition of that key deps: keys that the task depends on """ - - keys = kwargs.pop("keys", []) for 
key in keys: if isinstance(key, list): key = tuple(key) # ensure not a list from msgpack diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 37d5550941f..8fa14269a1f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -197,18 +197,15 @@ def test_map_retries(c, s, a, b): x, y, z = c.map(*map_varying(args), retries=1, pure=False) assert (yield x) == 2 assert (yield y) == 4 - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="eight"): yield z - exc_info.match("eight") x, y, z = c.map(*map_varying(args), retries=0, pure=False) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="one"): yield x - exc_info.match("one") assert (yield y) == 4 - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="seven"): yield z - exc_info.match("seven") @gen_cluster(client=True) @@ -217,15 +214,13 @@ def test_compute_retries(c, s, a, b): # Sanity check for varying() use x = c.compute(delayed(varying(args))()) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="one"): yield x - exc_info.match("one") # Same retries for all x = c.compute(delayed(varying(args))(), retries=1) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="two"): yield x - exc_info.match("two") x = c.compute(delayed(varying(args))(), retries=2) assert (yield x) == 3 @@ -244,16 +239,14 @@ def test_compute_retries(c, s, a, b): gc.collect() assert (yield x) == 30 - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="five"): yield y - exc_info.match("five") x, y, z = [delayed(varying(args))() for args in (xargs, yargs, zargs)] x, y, z = c.compute([x, y, z], retries={(y, z): 2}) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="one"): yield x - exc_info.match("one") assert (yield y) == 70 assert (yield z) == 80 @@ -276,15 +269,13 @@ def test_compute_persisted_retries(c, s, a, b): # Sanity check x = c.persist(delayed(varying(args))()) fut = c.compute(x) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="one"): yield fut - exc_info.match("one") x = c.persist(delayed(varying(args))()) fut = c.compute(x, retries=1) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="two"): yield fut - exc_info.match("two") x = c.persist(delayed(varying(args))()) fut = c.compute(x, retries=2) @@ -303,9 +294,8 @@ def test_persist_retries(c, s, a, b): x = c.persist(delayed(varying(args))(), retries=1) x = c.compute(x) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="two"): yield x - exc_info.match("two") x = c.persist(delayed(varying(args))(), retries=2) x = c.compute(x) @@ -320,9 +310,8 @@ def test_persist_retries(c, s, a, b): x, y, z = c.persist([x, y, z], retries={(y, z): 2}) x, y, z = c.compute([x, y, z]) - with pytest.raises(ZeroDivisionError) as exc_info: + with pytest.raises(ZeroDivisionError, match="one"): yield x - exc_info.match("one") assert (yield y) == 70 assert (yield z) == 80 @@ -2575,9 +2564,8 @@ def test_run_coroutine(c, s, a, b): results = yield c.run(geninc, 1, workers=[]) assert results == {} - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises(RuntimeError, match="hello"): yield c.run(throws, 1) - 
assert "hello" in str(exc_info) if sys.version_info >= (3, 5): results = yield c.run(asyncinc, 2, delay=0.01) @@ -2603,9 +2591,8 @@ def raise_exception(exc_type, exc_msg): raise exc_type(exc_msg) for exc_type in [ValueError, RuntimeError]: - with pytest.raises(exc_type) as excinfo: + with pytest.raises(exc_type, match="informative message"): c.run(raise_exception, exc_type, "informative message") - assert "informative message" in str(excinfo.value) def test_diagnostic_ui(loop): @@ -4420,16 +4407,15 @@ def test_recreate_error_sync(c): tot = c.submit(sum, x, y) f = c.compute(tot) - with pytest.raises(ZeroDivisionError) as e: + with pytest.raises(ZeroDivisionError): c.recreate_error_locally(f) assert f.status == "error" def test_recreate_error_not_error(c): f = c.submit(dec, 2) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="No errored futures passed"): c.recreate_error_locally(f) - assert "No errored futures passed" in str(e) @gen_cluster(client=True) @@ -4497,7 +4483,7 @@ def __call__(self, *args): return 1 future = c.submit(Foo(), 1) - with pytest.raises(MyException) as e: + with pytest.raises(MyException): yield future futures = c.map(inc, range(10)) diff --git a/distributed/utils.py b/distributed/utils.py index afe43040bd9..2f4657439cf 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -283,7 +283,7 @@ def quiet(): raise gen.Return(results) -def sync(loop, func, *args, **kwargs): +def sync(loop, func, *args, callback_timeout=None, **kwargs): """ Run coroutine in loop running in separate thread. """ @@ -299,8 +299,6 @@ def sync(loop, func, *args, **kwargs): except AttributeError: pass - timeout = kwargs.pop("callback_timeout", None) - e = threading.Event() main_tid = get_thread_identity() result = [None] @@ -314,8 +312,8 @@ def f(): yield gen.moment thread_state.asynchronous = True future = func(*args, **kwargs) - if timeout is not None: - future = gen.with_timeout(timedelta(seconds=timeout), future) + if callback_timeout is not None: + future = gen.with_timeout(timedelta(seconds=callback_timeout), future) result[0] = yield future except Exception as exc: error[0] = sys.exc_info() @@ -324,9 +322,9 @@ def f(): e.set() loop.add_callback(f) - if timeout is not None: - if not e.wait(timeout): - raise gen.TimeoutError("timed out after %s s." % (timeout,)) + if callback_timeout is not None: + if not e.wait(callback_timeout): + raise gen.TimeoutError("timed out after %s s." 
% (callback_timeout,)) else: while not e.is_set(): e.wait(10) @@ -1352,8 +1350,7 @@ class DequeHandler(logging.Handler): _instances = weakref.WeakSet() - def __init__(self, *args, **kwargs): - n = kwargs.pop("n", 10000) + def __init__(self, *args, n=10000, **kwargs): self.deque = deque(maxlen=n) super(DequeHandler, self).__init__(*args, **kwargs) self._instances.add(self) diff --git a/distributed/worker.py b/distributed/worker.py index 582564e3f6d..4ce564f1b7b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -308,6 +308,8 @@ def __init__( nanny=None, plugins=(), low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), + validate=False, + profile_cycle_interval=None, **kwargs ): self.tasks = dict() @@ -369,7 +371,7 @@ def __init__( self.target_message_size = 50e6 # 50 MB self.log = deque(maxlen=100000) - self.validate = kwargs.pop("validate", False) + self.validate = validate self._transitions = { ("waiting", "ready"): self.transition_waiting_ready, @@ -404,10 +406,8 @@ def __init__( self.latency = 0.001 self._client = None - profile_cycle_interval = kwargs.pop( - "profile_cycle_interval", - dask.config.get("distributed.worker.profile.cycle"), - ) + if profile_cycle_interval is None: + profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle") profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms") self._setup_logging(logger) From b538d246d9f766f1e96a2635dce48f3f38fa6011 Mon Sep 17 00:00:00 2001 From: tjb900 Date: Tue, 2 Jul 2019 14:57:18 +0800 Subject: [PATCH 0344/1550] Fix case where key, rather than TaskState, could end up in ts.waiting_on (#2819) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d370705e9af..66d8fdaac90 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3559,7 +3559,7 @@ def transition_no_worker_waiting(self, key): for dts in ts.dependencies: dep = dts.key if not dts.who_has: - ts.waiting_on.add(dep) + ts.waiting_on.add(dts) if dts.state == "released": recommendations[dep] = "waiting" else: From e42c124787bcaa82ff18b892450e58a34e3af92d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 2 Jul 2019 10:33:48 +0100 Subject: [PATCH 0345/1550] Fix Client repr with memory_info=None (#2816) --- distributed/client.py | 16 ++++++++++------ distributed/tests/test_client.py | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 8509d9e31a9..74e33716cb6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -801,7 +801,7 @@ def _repr_html_(self): info = sync(self.loop, self.scheduler.identity) scheduler = self.scheduler else: - info = False + info = self._scheduler_identity scheduler = self.scheduler if scheduler is not None: @@ -828,10 +828,14 @@ def _repr_html_(self): text += "\n" if info: - workers = len(info["workers"]) - cores = sum(w["nthreads"] for w in info["workers"].values()) - memory = sum(w["memory_limit"] for w in info["workers"].values()) - memory = format_bytes(memory) + workers = list(info["workers"].values()) + cores = sum(w["nthreads"] for w in workers) + if all(isinstance(w["memory_limit"], Number) for w in workers): + memory = sum(w["memory_limit"] for w in workers) + memory = format_bytes(memory) + else: + memory = "" + text2 = ( "

<h3>Cluster</h3>\n"
                 "<ul>\n"
                 "  <li><b>Workers: </b>%d</li>\n"
                 "  <li><b>Cores: </b>%d</li>\n"
                 "  <li><b>Memory: </b>%s</li>\n"
                 "</ul>\n"
-            ) % (workers, cores, memory)
+            ) % (len(workers), cores, memory)
 
             return (
        \n' diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8fa14269a1f..1f9678583b0 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1892,6 +1892,11 @@ def test_repr_async(c, s, a, b): c._repr_html_() +@gen_cluster(client=True, worker_kwargs={"memory_limit": None}) +def test_repr_no_memory_limit(c, s, a, b): + c._repr_html_() + + @gen_test() def test_repr_localcluster(): cluster = yield LocalCluster( From bda0ab6f35be915d5d74660c90195912dd7e5355 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 2 Jul 2019 12:00:11 -0500 Subject: [PATCH 0346/1550] Updates to use update_graph in task journey docs (#2821) --- docs/source/journey.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/journey.rst b/docs/source/journey.rst index 9dd1e15d62b..dd7e60e8d42 100644 --- a/docs/source/journey.rst +++ b/docs/source/journey.rst @@ -41,15 +41,15 @@ Step 2: Arrive in the Scheduler A few milliseconds later, the scheduler receives this message on an open socket. The scheduler updates its state with this little graph that shows how to compute -``z``.:: +``z``:: - scheduler.tasks.update(msg['tasks']) + scheduler.update_graph(tasks=msg['tasks'], keys=msg['keys']) The scheduler also updates *a lot* of other state. Notably, it has to identify that ``x`` and ``y`` are themselves variables, and connect all of those dependencies. This is a long and detail oriented process that involves updating roughly 10 sets and dictionaries. Interested readers should -investigate ``distributed/scheduler.py::update_state()``. While this is fairly +investigate ``distributed/scheduler.py::update_graph()``. While this is fairly complex and tedious to describe rest assured that it all happens in constant time and in about a millisecond. From eca16ed89e40563bab4a01c7cb7e41db30fc10a8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 5 Jul 2019 10:23:51 +0100 Subject: [PATCH 0347/1550] Remove dask-mpi (#2824) This has moved to the dask-mpi project documented at https://mpi.dask.org Fixes #2823 --- distributed/cli/dask_mpi.py | 147 ------------------------- distributed/cli/tests/test_dask_mpi.py | 106 ------------------ setup.py | 1 - 3 files changed, 254 deletions(-) delete mode 100644 distributed/cli/dask_mpi.py delete mode 100644 distributed/cli/tests/test_dask_mpi.py diff --git a/distributed/cli/dask_mpi.py b/distributed/cli/dask_mpi.py deleted file mode 100644 index 7b9aeaca213..00000000000 --- a/distributed/cli/dask_mpi.py +++ /dev/null @@ -1,147 +0,0 @@ -from functools import partial - -import click -from mpi4py import MPI -from tornado.ioloop import IOLoop -from tornado import gen -from warnings import warn - -from distributed import Scheduler, Nanny, Worker -from distributed.dashboard import BokehWorker -from distributed.cli.utils import check_python_3 -from distributed.comm.addressing import uri_from_host_port -from distributed.utils import get_ip_interface - - -comm = MPI.COMM_WORLD -rank = comm.Get_rank() -loop = IOLoop() - - -@click.command() -@click.option( - "--scheduler-file", - type=str, - default="scheduler.json", - help="Filename to JSON encoded scheduler information. ", -) -@click.option( - "--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'" -) -@click.option("--nthreads", type=int, default=0, help="Number of threads per worker.") -@click.option( - "--memory-limit", - default="auto", - help="Number of bytes before spilling data to disk. 
" - "This can be an integer (nbytes) " - "float (fraction of total memory) " - "or 'auto'", -) -@click.option( - "--local-directory", default="", type=str, help="Directory to place worker files" -) -@click.option( - "--scheduler/--no-scheduler", - default=True, - help=( - "Whether or not to include a scheduler. " - "Use --no-scheduler to increase an existing dask cluster" - ), -) -@click.option( - "--nanny/--no-nanny", - default=True, - help="Start workers in nanny process for management", -) -@click.option( - "--bokeh-port", type=int, default=8787, help="Bokeh port for visual diagnostics" -) -@click.option( - "--bokeh-worker-port", - type=int, - default=8789, - help="Worker's Bokeh port for visual diagnostics", -) -@click.option("--bokeh-prefix", type=str, default=None, help="Prefix for the bokeh app") -@click.version_option() -def main( - scheduler_file, - interface, - nthreads, - local_directory, - memory_limit, - scheduler, - bokeh_port, - bokeh_prefix, - nanny, - bokeh_worker_port, -): - if interface: - host = get_ip_interface(interface) - else: - host = None - - if rank == 0 and scheduler: - try: - from distributed.dashboard import BokehScheduler - except ImportError: - services = {} - else: - services = { - ("dashboard", bokeh_port): partial(BokehScheduler, prefix=bokeh_prefix) - } - scheduler = Scheduler( - scheduler_file=scheduler_file, loop=loop, services=services - ) - addr = uri_from_host_port(host, None, 8786) - scheduler.start(addr) - try: - loop.start() - loop.close() - finally: - scheduler.stop() - else: - W = Nanny if nanny else Worker - worker = W( - scheduler_file=scheduler_file, - loop=loop, - name=rank if scheduler else None, - nthreads=nthreads, - local_dir=local_directory, - services={("dashboard", bokeh_worker_port): BokehWorker}, - memory_limit=memory_limit, - ) - addr = uri_from_host_port(host, None, 0) - - @gen.coroutine - def run(): - yield worker._start(addr) - while worker.status != "closed": - yield gen.sleep(0.2) - - try: - loop.run_sync(run) - loop.close() - finally: - pass - - @gen.coroutine - def close(): - yield worker._close(timeout=2) - - loop.run_sync(close) - - -def go(): - check_python_3() - warn( - "The dask-mpi command line utility in the `distributed` " - "package is deprecated. " - "Please install the `dask-mpi` package instead. 
" - "More information is available at https://mpi.dask.org" - ) - main() - - -if __name__ == "__main__": - go() diff --git a/distributed/cli/tests/test_dask_mpi.py b/distributed/cli/tests/test_dask_mpi.py deleted file mode 100644 index 89f1140bfab..00000000000 --- a/distributed/cli/tests/test_dask_mpi.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import print_function, division, absolute_import - -import subprocess -from time import sleep - -import pytest - -pytest.importorskip("mpi4py") - -import requests -from click.testing import CliRunner - -from distributed import Client -from distributed.utils import tmpfile -from distributed.metrics import time -from distributed.utils_test import popen -from distributed.utils_test import loop # noqa: F401 -from distributed.cli.dask_remote import main - - -@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) -def test_basic(loop, nanny): - with tmpfile() as fn: - with popen( - ["mpirun", "--np", "4", "dask-mpi", "--scheduler-file", fn, nanny], - stdin=subprocess.DEVNULL, - ): - with Client(scheduler_file=fn) as c: - - start = time() - while len(c.scheduler_info()["workers"]) != 3: - assert time() < start + 10 - sleep(0.2) - - assert c.submit(lambda x: x + 1, 10, workers=1).result() == 11 - - -def test_no_scheduler(loop): - with tmpfile() as fn: - with popen( - ["mpirun", "--np", "2", "dask-mpi", "--scheduler-file", fn], - stdin=subprocess.DEVNULL, - ): - with Client(scheduler_file=fn) as c: - - start = time() - while len(c.scheduler_info()["workers"]) != 1: - assert time() < start + 10 - sleep(0.2) - - assert c.submit(lambda x: x + 1, 10).result() == 11 - with popen( - [ - "mpirun", - "--np", - "1", - "dask-mpi", - "--scheduler-file", - fn, - "--no-scheduler", - ] - ): - - start = time() - while len(c.scheduler_info()["workers"]) != 2: - assert time() < start + 10 - sleep(0.2) - - -def test_bokeh(loop): - with tmpfile() as fn: - with popen( - [ - "mpirun", - "--np", - "2", - "dask-mpi", - "--scheduler-file", - fn, - "--bokeh-port", - "59583", - "--bokeh-worker-port", - "59584", - ], - stdin=subprocess.DEVNULL, - ): - - for port in [59853, 59584]: - start = time() - while True: - try: - response = requests.get("http://localhost:%d/status/" % port) - assert response.ok - break - except Exception: - sleep(0.1) - assert time() < start + 20 - - with pytest.raises(Exception): - requests.get("http://localhost:59583/status/") - - -def test_version_option(): - runner = CliRunner() - result = runner.invoke(main, ["--version"]) - assert result.exit_code == 0 diff --git a/setup.py b/setup.py index 125ddc9c328..84054d199e0 100755 --- a/setup.py +++ b/setup.py @@ -67,7 +67,6 @@ dask-remote=distributed.cli.dask_remote:go dask-scheduler=distributed.cli.dask_scheduler:go dask-worker=distributed.cli.dask_worker:go - dask-mpi=distributed.cli.dask_mpi:go """, zip_safe=False, ) From 776bb6b745448ca64b8e5ad99d2f7fab7bbd8a8e Mon Sep 17 00:00:00 2001 From: Russ Bubley Date: Fri, 5 Jul 2019 19:36:18 +0100 Subject: [PATCH 0348/1550] Fix typo that prevented error message (#2825) --- distributed/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index 79c726eed6d..6f08c17ac77 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -369,7 +369,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): op = msg.pop("op") except KeyError: raise ValueError( - "Received unexpected message without 'op' key: " % str(msg) + "Received unexpected message without 'op' key: " + str(msg) ) if 
self.counters is not None: self.counters["op"].add(op) From 99ac4550d0231ae81076307dda8b821b51bc5792 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 8 Jul 2019 10:47:33 -0500 Subject: [PATCH 0349/1550] bump version to 2.1.0 --- docs/source/changelog.rst | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 6162935b70d..cb98ab79fc7 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,20 @@ Changelog ========= +2.1.0 - 2019-07-08 +------------------ + +- Fix typo that prevented error message (:pr:`2825`) `Russ Bubley`_ +- Remove ``dask-mpi`` (:pr:`2824`) `Matthew Rocklin`_ +- Updates to use ``update_graph`` in task journey docs (:pr:`2821`) `James Bourbeau`_ +- Fix Client repr with ``memory_info=None`` (:pr:`2816`) `Matthew Rocklin`_ +- Fix case where key, rather than ``TaskState``, could end up in ``ts.waiting_on`` (:pr:`2819`) `tjb900`_ +- Use Keyword-only arguments (:pr:`2814`) `Matthew Rocklin`_ +- Relax check for worker references in cluster context manager (:pr:`2813`) `Matthew Rocklin`_ +- Add HTTPS support for the dashboard (:pr:`2812`) `Jim Crist`_ +- Use ``dask.utils.format_bytes`` (:pr:`2810`) `Tom Augspurger`_ + + 2.0.1 - 2019-06-26 ------------------ @@ -1084,7 +1098,6 @@ significantly without many new features. .. _`Diane Trout`: https://github.com/detrout .. _`tjb900`: https://github.com/tjb900 .. _`Stephan Hoyer`: https://github.com/shoyer -.. _`tjb900`: https://github.com/tjb900 .. _`Dirk Petersen`: https://github.com/dirkpetersen .. _`Daniel Farrell`: https://github.com/danpf .. _`George Sakkis`: https://github.com/gsakkis From 21370fa8f3904f548731bb93111e0d730a8e80da Mon Sep 17 00:00:00 2001 From: Russ Bubley Date: Wed, 10 Jul 2019 13:23:22 +0100 Subject: [PATCH 0350/1550] Respect security configuration in LocalCluster (#2822) Fixes #2815 --- distributed/deploy/local.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 554459e43ac..cb1a1511e20 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -12,6 +12,7 @@ from .spec import SpecCluster from ..nanny import Nanny from ..scheduler import Scheduler +from ..security import Security from ..worker import Worker, parse_memory_limit logger = logging.getLogger(__name__) @@ -123,11 +124,12 @@ def __init__( self.status = None self.processes = processes + security = security or Security() if protocol is None: if host and "://" in host: protocol = host.split("://")[0] - elif security: + elif security and security.require_encryption: protocol = "tls://" elif not self.processes and not scheduler_port: protocol = "inproc://" From 5b31a87b823792c4d8646dd3bf249fee116fa567 Mon Sep 17 00:00:00 2001 From: Christian Hudon Date: Sun, 14 Jul 2019 09:48:34 -0500 Subject: [PATCH 0351/1550] Add Nanny to worker docs (#2826) Fixes #2771 --- distributed/nanny.py | 10 +++++++++- docs/source/resilience.rst | 14 ++++++++------ docs/source/scheduling-state.rst | 4 ++-- docs/source/worker.rst | 9 +++++++++ 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index f518b330d7c..b6b43116a73 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -42,7 +42,15 @@ class Nanny(ServerNode): """ A process to manage worker processes The nanny spins up Worker processes, watches then, and kills or restarts - them as necessary. + them as necessary. 
It is necessary if you want to use the + ``Client.restart`` method, or to restart the worker automatically if + it gets to the terminate fractiom of its memory limit. + + The parameters for the Nanny are mostly the same as those for the Worker. + + See Also + -------- + Worker """ _instances = weakref.WeakSet() diff --git a/docs/source/resilience.rst b/docs/source/resilience.rst index f5300d4fbcd..1936d7ee995 100644 --- a/docs/source/resilience.rst +++ b/docs/source/resilience.rst @@ -48,11 +48,12 @@ This has some fail cases. causes a segmentation fault, then that bad function will repeatedly be called on other workers. This function will be marked as "bad" after it kills a fixed number of workers (defaults to three). -3. Data scattered out to the workers is not kept in the scheduler (it is - often quite large) and so the loss of this data is irreparable. You may - wish to call ``Client.replicate`` on the data with a suitable replication - factor to ensure that it remains long-lived or else back the data off of - some resilient store, like a file system. +3. Data sent out directly to the workers via a call to ``scatter()`` (instead + of being created from a Dask task graph via other Dask functions) is not + kept in the scheduler, as it is often quite large, and so the loss of this + data is irreparable. You may wish to call ``Client.replicate`` on the data + with a suitable replication factor to ensure that it remains long-lived or + else back the data off of some resilient store, like a file system. Hardware Failures @@ -81,4 +82,5 @@ The client provides a mechanism to restart all of the workers in the cluster. This is convenient if, during the course of experimentation, you find your workers in an inconvenient state that makes them unresponsive. The ``Client.restart`` method kills all workers, flushes all scheduler state, and -then brings all workers back online, resulting in a clean cluster. +then brings all workers back online, resulting in a clean cluster. This +requires the nanny process (which is started by default). diff --git a/docs/source/scheduling-state.rst b/docs/source/scheduling-state.rst index 515bb26cdb0..4bffd182439 100644 --- a/docs/source/scheduling-state.rst +++ b/docs/source/scheduling-state.rst @@ -131,8 +131,8 @@ Conversely, "saturated" workers may see their workload lightened through Client State ------------ -Information about each individual client is kept in a :class:`ClientState` -object: +Information about each individual client of the scheduler is kept +in a :class:`ClientState` object: .. autoclass:: ClientState diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 530a27b9505..be288ccf68c 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -233,6 +233,15 @@ YARN, Mesos, SGE, etc..). After termination the nanny will restart the worker in a fresh state. +Nanny +~~~~~ + +Dask workers are by default launched, monitored, and managed by a small Nanny +process. + +.. 
autoclass:: distributed.worker.Nanny + + API Documentation ----------------- From d4934986301b5c944e6bd049912086ac2dfb928f Mon Sep 17 00:00:00 2001 From: tjb900 Date: Tue, 16 Jul 2019 00:12:58 +0800 Subject: [PATCH 0352/1550] Don't make False add-keys report to scheduler (#2421) Fixes #2420 --- distributed/worker.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 4ce564f1b7b..e124ba6ab1f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1884,11 +1884,6 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): self.incoming_count += 1 self.log.append(("receive-dep", worker, list(response["data"]))) - - if response["data"]: - self.batched_stream.send( - {"op": "add-keys", "keys": list(response["data"])} - ) except EnvironmentError as e: logger.exception("Worker stream died during communication: %s", worker) self.log.append(("receive-dep-failed", worker)) From e5ec8daab0d2b30702cfce4acb6259c10aef8e05 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Mon, 15 Jul 2019 20:33:37 +0100 Subject: [PATCH 0353/1550] Include type name in SpecCluster repr (#2834) --- distributed/deploy/spec.py | 3 ++- distributed/deploy/tests/test_spec_cluster.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index bb46f81db88..445949dc200 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -315,7 +315,8 @@ async def scale_down(self, workers): scale_up = scale # backwards compatibility def __repr__(self): - return "SpecCluster(%r, workers=%d)" % ( + return "%s(%r, workers=%d)" % ( + type(self).__name__, self.scheduler_address, len(self.workers), ) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 0c062d3d3e0..e51e8f14260 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -81,6 +81,19 @@ def test_loop_started(): ) +@pytest.mark.asyncio +async def test_repr(): + worker = {"cls": Worker, "options": {"nthreads": 1}} + + class MyCluster(SpecCluster): + pass + + async with MyCluster( + asynchronous=True, scheduler=scheduler, worker=worker + ) as cluster: + assert "MyCluster" in str(cluster) + + @pytest.mark.asyncio async def test_scale(): worker = {"cls": Worker, "options": {"nthreads": 1}} From af64e07a01e8ce0d76744099a93ca2155d835ba8 Mon Sep 17 00:00:00 2001 From: Gabriel Sailer Date: Tue, 16 Jul 2019 17:24:26 +0200 Subject: [PATCH 0354/1550] Extend prometheus metrics endpoint (#2792) (#2833) * Expose tasks prometheus metric at scheduler * Add basic task metrics to worker Number of tasks in states and number of threads are exposed on the workers /metrics endpoints. 
* Add worker metrics and reformat tasks * Change log mesage in case of missing crick --- distributed/dashboard/scheduler_html.py | 14 +++- .../dashboard/tests/test_worker_bokeh_html.py | 2 +- distributed/dashboard/worker_html.py | 81 +++++++++++++++---- 3 files changed, 77 insertions(+), 20 deletions(-) diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 08829241d47..3f119a929b9 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -184,14 +184,24 @@ def collect(self): yield GaugeMetricFamily( "dask_scheduler_workers", - "Number of workers.", + "Number of workers connected.", value=len(self.server.workers), ) yield GaugeMetricFamily( "dask_scheduler_clients", - "Number of clients.", + "Number of clients connected.", value=len(self.server.clients), ) + yield GaugeMetricFamily( + "dask_scheduler_received_tasks", + "Number of tasks received at scheduler", + value=len(self.server.tasks), + ) + yield GaugeMetricFamily( + "dask_scheduler_unrunnable_tasks", + "Number of unrunnable tasks at scheduler", + value=len(self.server.unrunnable), + ) class PrometheusHandler(RequestHandler): diff --git a/distributed/dashboard/tests/test_worker_bokeh_html.py b/distributed/dashboard/tests/test_worker_bokeh_html.py index 99916b3fdc7..7a4d70a037c 100644 --- a/distributed/dashboard/tests/test_worker_bokeh_html.py +++ b/distributed/dashboard/tests/test_worker_bokeh_html.py @@ -25,7 +25,7 @@ def test_prometheus(c, s, a, b): txt = response.body.decode("utf8") families = {familiy.name for familiy in text_string_to_metric_families(txt)} - assert len(families) > 0 + assert "dask_worker_latency_seconds" in families @gen_cluster(client=True, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}) diff --git a/distributed/dashboard/worker_html.py b/distributed/dashboard/worker_html.py index e1ae50f3afc..27e1f9fe9d2 100644 --- a/distributed/dashboard/worker_html.py +++ b/distributed/dashboard/worker_html.py @@ -1,23 +1,72 @@ +import logging from .utils import RequestHandler, redirect class _PrometheusCollector(object): - def __init__(self, server, prometheus_client): - self.server = server + def __init__(self, server): + self.worker = server + self.logger = logging.getLogger("distributed.dask_worker") + self.crick_available = True + try: + import crick # noqa: F401 + except ImportError: + self.crick_available = False + self.logger.info( + "Not all prometheus metrics available are exported. Digest-based metrics require crick to be installed" + ) def collect(self): - # add your metrics here: - # - # 1. remove the following lines - while False: - yield None - # - # 2. 
yield your metrics - # yield prometheus_client.core.GaugeMetricFamily( - # 'dask_worker_connections', - # 'Number of connections currently open.', - # value=???, - # ) + from prometheus_client.core import GaugeMetricFamily + + tasks = GaugeMetricFamily( + "dask_worker_tasks", "Number of tasks at worker.", labels=["state"] + ) + tasks.add_metric(["stored"], len(self.worker.data)) + tasks.add_metric(["ready"], len(self.worker.ready)) + tasks.add_metric(["waiting"], len(self.worker.waiting_for_data)) + tasks.add_metric(["serving"], len(self.worker._comms)) + yield tasks + + yield GaugeMetricFamily( + "dask_worker_connections", + "Number of task connections to other workers.", + value=len(self.worker.in_flight_workers), + ) + + yield GaugeMetricFamily( + "dask_worker_threads", + "Number of worker threads.", + value=self.worker.nthreads, + ) + + yield GaugeMetricFamily( + "dask_worker_latency_seconds", + "Latency of worker connection.", + value=self.worker.latency, + ) + + # all metrics using digests require crick to be installed + # the following metrics will export NaN, if the corresponding digests are None + if self.crick_available: + yield GaugeMetricFamily( + "dask_worker_tick_duration_median_seconds", + "Median tick duration at worker.", + value=self.worker.digests["tick-duration"].components[1].quantile(50), + ) + + yield GaugeMetricFamily( + "dask_worker_task_duration_median_seconds", + "Median task runtime at worker.", + value=self.worker.digests["task-duration"].components[1].quantile(50), + ) + + yield GaugeMetricFamily( + "dask_worker_transfer_bandwidth_median_bytes", + "Bandwidth for transfer at worker in Bytes.", + value=self.worker.digests["transfer-bandwidth"] + .components[1] + .quantile(50), + ) class PrometheusHandler(RequestHandler): @@ -31,9 +80,7 @@ def __init__(self, *args, **kwargs): if PrometheusHandler._initialized: return - prometheus_client.REGISTRY.register( - _PrometheusCollector(self.server, prometheus_client) - ) + prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) PrometheusHandler._initialized = True From df2addc62be91fc017b429c947afa8acd1a64127 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 17 Jul 2019 21:59:23 -0500 Subject: [PATCH 0355/1550] Add alternative SSHCluster implementation (#2827) This is a proof of concept here for two reasons: 1. It opens up a possible alternative for SSH deployment (which was surprisingly popular in the user survey) 2. It is the first non-local application of `SpecCluster` and so serves as a proof of concept for other future deployments that are mostly defined by creating a remote Worker/Scheduler object This forced some changes in `SpecCluster`, notably we now have an `rpc` object that does remote calls rather than accessing the scheduler directly. Also, we're going to have to figure out how to handle all of the keyword arguments. In this case we need to pass them from Python down to the CLI, and presumably we'll also want a `dask-ssh` CLI command which has to translate the other way. 
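(Editor's note, not part of the patch: the experimental API introduced below is exercised roughly as follows, mirroring the new test_ssh2.py added in this diff. The localhost addresses and known_hosts=None disable host-key checking and are only suitable for local testing; the module path distributed.deploy.ssh2 is the one created by this patch and is explicitly marked experimental.)

    import asyncio
    from dask.distributed import Client
    from distributed.deploy.ssh2 import SSHCluster  # experimental module added below

    async def main():
        # The first host runs the scheduler, the remaining hosts run workers.
        async with SSHCluster(
            ["127.0.0.1"] * 3, connect_kwargs=dict(known_hosts=None), asynchronous=True
        ) as cluster:
            async with Client(cluster, asynchronous=True) as client:
                result = await client.submit(lambda x: x + 1, 10)
                print(result)  # 11

    asyncio.get_event_loop().run_until_complete(main())
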
--- .travis.yml | 1 + continuous_integration/travis/install.sh | 1 + continuous_integration/travis/setup-ssh.sh | 2 + distributed/deploy/local.py | 1 + distributed/deploy/spec.py | 30 ++-- distributed/deploy/ssh2.py | 171 +++++++++++++++++++++ distributed/deploy/tests/test_ssh2.py | 17 ++ 7 files changed, 214 insertions(+), 9 deletions(-) create mode 100644 continuous_integration/travis/setup-ssh.sh create mode 100644 distributed/deploy/ssh2.py create mode 100644 distributed/deploy/tests/test_ssh2.py diff --git a/.travis.yml b/.travis.yml index bcc09351eff..35f4383748e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,6 +19,7 @@ matrix: install: - if [[ $TESTS == true ]]; then source continuous_integration/travis/install.sh ; fi + - if [[ $TESTS == true ]]; then source continuous_integration/travis/setup-ssh.sh ; fi script: - if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 2ab9724db25..b2fab6afb52 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -67,6 +67,7 @@ pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps pip install -q sortedcollections msgpack --no-deps pip install -q keras --upgrade --no-deps +pip install -q asyncssh if [[ $CRICK == true ]]; then conda install -q cython diff --git a/continuous_integration/travis/setup-ssh.sh b/continuous_integration/travis/setup-ssh.sh new file mode 100644 index 00000000000..f102612bc96 --- /dev/null +++ b/continuous_integration/travis/setup-ssh.sh @@ -0,0 +1,2 @@ +ssh-keygen -t rsa -f ~/.ssh/id_rsa -N "" -q +cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys \ No newline at end of file diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index cb1a1511e20..5b0aec4e80c 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -200,6 +200,7 @@ def __init__( loop=loop, asynchronous=asynchronous, silence_logs=silence_logs, + security=security, ) def __repr__(self): diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 445949dc200..6b95e2107fe 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -5,8 +5,10 @@ from tornado import gen from .cluster import Cluster +from ..core import rpc from ..utils import LoopRunner, silence_logging, ignoring from ..scheduler import Scheduler +from ..security import Security class SpecCluster(Cluster): @@ -107,6 +109,7 @@ def __init__( worker=None, asynchronous=False, loop=None, + security=None, silence_logs=False, ): self._created = weakref.WeakSet() @@ -125,6 +128,8 @@ def __init__( self.workers = {} self._i = 0 self._asynchronous = asynchronous + self.security = security or Security() + self.scheduler_comm = None if silence_logs: self._old_logging_level = silence_logging(level=silence_logs) @@ -156,6 +161,10 @@ async def _start(self): self._lock = asyncio.Lock() self.status = "starting" self.scheduler = await self.scheduler + self.scheduler_comm = rpc( + self.scheduler.address, + connection_args=self.security.get_connection_args("client"), + ) self.status = "running" def _correct_state(self): @@ -174,11 +183,13 @@ async def _correct_state_internal(self): pre = list(set(self.workers)) to_close = set(self.workers) - set(self.worker_spec) if to_close: - await self.scheduler.retire_workers(workers=list(to_close)) + if self.scheduler.status == "running": + await 
self.scheduler_comm.retire_workers(workers=list(to_close)) tasks = [self.workers[w].close() for w in to_close] await asyncio.wait(tasks) for task in tasks: # for tornado gen.coroutine support - await task + with ignoring(RuntimeError): + await task for name in to_close: del self.workers[name] @@ -214,11 +225,10 @@ async def _(): return _().__await__() async def _wait_for_workers(self): - # TODO: this function needs to query scheduler and worker state - # remotely without assuming that they are local - while {d["name"] for d in self.scheduler.identity()["workers"].values()} != set( - self.workers - ): + while { + str(d["name"]) + for d in (await self.scheduler_comm.identity())["workers"].values() + } != set(map(str, self.workers)): if ( any(w.status == "closed" for w in self.workers.values()) and self.scheduler.status == "running" @@ -240,12 +250,14 @@ async def _close(self): return self.status = "closing" - async with self._lock: - await self.scheduler.close(close_workers=True) self.scale(0) await self._correct_state() + async with self._lock: + await self.scheduler_comm.close(close_workers=True) + await self.scheduler.close() for w in self._created: assert w.status == "closed" + self.scheduler_comm.close_rpc() if hasattr(self, "_old_logging_level"): silence_logging(self._old_logging_level) diff --git a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py new file mode 100644 index 00000000000..0f8823cdab8 --- /dev/null +++ b/distributed/deploy/ssh2.py @@ -0,0 +1,171 @@ +import asyncio +import logging +import sys +import warnings +import weakref + +import asyncssh + +from .spec import SpecCluster + +logger = logging.getLogger(__name__) + +warnings.warn( + "the distributed.deploy.ssh2 module is experimental " + "and will move/change in the future without notice" +) + + +class Process: + """ A superclass for SSH Workers and Nannies + + See Also + -------- + Worker + Scheduler + """ + + def __init__(self): + self.lock = asyncio.Lock() + self.connection = None + self.proc = None + self.status = "created" + + def __await__(self): + async def _(): + async with self.lock: + if not self.connection: + await self.start() + assert self.connection + weakref.finalize(self, self.proc.terminate) + return self + + return _().__await__() + + async def close(self): + self.proc.terminate() + self.connection.close() + self.status = "closed" + + def __repr__(self): + return "" % (type(self).__name__, self.status) + + +class Worker(Process): + """ A Remote Dask Worker controled by SSH + + Parameters + ---------- + scheduler: str + The address of the scheduler + address: str + The hostname where we should run this worker + connect_kwargs: dict + kwargs to be passed to asyncssh connections + kwargs: + TODO + """ + + def __init__(self, scheduler: str, address: str, connect_kwargs: dict, **kwargs): + self.address = address + self.scheduler = scheduler + self.connect_kwargs = connect_kwargs + self.kwargs = kwargs + + super().__init__() + + async def start(self): + self.connection = await asyncssh.connect(self.address, **self.connect_kwargs) + self.proc = await self.connection.create_process( + " ".join( + [ + sys.executable, + "-m", + "distributed.cli.dask_worker", + self.scheduler, + "--name", # we need to have name for SpecCluster + str(self.kwargs["name"]), + ] + ) + ) + + # We watch stderr in order to get the address, then we return + while True: + line = await self.proc.stderr.readline() + if "worker at" in line: + self.address = line.split("worker at:")[1].strip() + self.status = "running" + break + 
logger.debug("%s", line) + + +class Scheduler(Process): + """ A Remote Dask Scheduler controled by SSH + + Parameters + ---------- + address: str + The hostname where we should run this worker + connect_kwargs: dict + kwargs to be passed to asyncssh connections + kwargs: + TODO + """ + + def __init__(self, address: str, connect_kwargs: dict, **kwargs): + self.address = address + self.kwargs = kwargs + self.connect_kwargs = connect_kwargs + + super().__init__() + + async def start(self): + logger.debug("Created Scheduler Connection") + + self.connection = await asyncssh.connect(self.address, **self.connect_kwargs) + + self.proc = await self.connection.create_process( + " ".join([sys.executable, "-m", "distributed.cli.dask_scheduler"]) + ) + + # We watch stderr in order to get the address, then we return + while True: + line = await self.proc.stderr.readline() + if "Scheduler at" in line: + self.address = line.split("Scheduler at:")[1].strip() + break + logger.debug("%s", line) + + +def SSHCluster(hosts, connect_kwargs, **kwargs): + """ Deploy a Dask cluster using SSH + + Parameters + ---------- + hosts: List[str] + List of hostnames or addresses on which to launch our cluster + The first will be used for the scheduler and the rest for workers + connect_kwargs: + known_hosts: List[str] or None + The list of keys which will be used to validate the server host + key presented during the SSH handshake. If this is not specified, + the keys will be looked up in the file .ssh/known_hosts. If this + is explicitly set to None, server host key validation will be disabled. + TODO + kwargs: + TODO + ---- + This doesn't handle any keyword arguments yet. It is a proof of concept + """ + scheduler = { + "cls": Scheduler, + "options": {"address": hosts[0], "connect_kwargs": connect_kwargs}, + } + workers = { + i: { + "cls": Worker, + "options": {"address": host, "connect_kwargs": connect_kwargs}, + } + for i, host in enumerate(hosts[1:]) + } + return SpecCluster(workers, scheduler, **kwargs) diff --git a/distributed/deploy/tests/test_ssh2.py b/distributed/deploy/tests/test_ssh2.py new file mode 100644 index 00000000000..beb1c6ef91e --- /dev/null +++ b/distributed/deploy/tests/test_ssh2.py @@ -0,0 +1,17 @@ +import pytest + +pytest.importorskip("asyncssh") + +from dask.distributed import Client +from distributed.deploy.ssh2 import SSHCluster + + +@pytest.mark.asyncio +async def test_basic(): + async with SSHCluster( + ["127.0.0.1"] * 3, connect_kwargs=dict(known_hosts=None), asynchronous=True + ) as cluster: + assert len(cluster.workers) == 2 + async with Client(cluster, asynchronous=True) as client: + result = await client.submit(lambda x: x + 1, 10) + assert result == 11 From 27e8e6548e7a8401c9199a55c7dea2fa7331cb04 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Thu, 18 Jul 2019 08:27:09 +0200 Subject: [PATCH 0356/1550] Dont reuse closed worker in get_worker (#2841) --- distributed/deploy/tests/test_local.py | 25 ++++++++++++++++++++++++- distributed/worker.py | 2 +- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 520996f64a0..ebe2d7ec99e 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -14,7 +14,7 @@ from tornado import gen import pytest -from distributed import Client, Worker, Nanny +from distributed import Client, Worker, Nanny, get_client from distributed.deploy.local import LocalCluster, nprocesses_nthreads from distributed.metrics import time from 
distributed.utils_test import ( @@ -831,3 +831,26 @@ def test_starts_up_sync(loop): assert len(cluster.scheduler.workers) == 2 finally: cluster.close() + + +def test_dont_select_closed_worker(): + # Make sure distributed does not try to reuse a client from a + # closed cluster (https://github.com/dask/distributed/issues/2840). + with clean(threads=False): + cluster = LocalCluster(n_workers=0) + c = Client(cluster) + cluster.scale(2) + assert c == get_client() + + c.close() + cluster.close() + + cluster2 = LocalCluster(n_workers=0) + c2 = Client(cluster2) + cluster2.scale(2) + + current_client = get_client() + assert c2 == current_client + + cluster2.close() + c2.close() diff --git a/distributed/worker.py b/distributed/worker.py index e124ba6ab1f..60884ebaba3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2852,7 +2852,7 @@ def get_worker(): return thread_state.execution_state["worker"] except AttributeError: try: - return first(Worker._instances) + return first(w for w in Worker._instances if w.status == "running") except StopIteration: raise ValueError("No workers found") From 0ada76c91bbf35f1e783159b63d3effe01c803dd Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 18 Jul 2019 15:21:25 +0100 Subject: [PATCH 0357/1550] SpecCluster: move init logic into start (#2850) Move the scheduler creation from `__init__` to `_start`. This allows clusters to make async calls within subclassed `_start` methods before the scheduler object is created. This also ignores exceptions from closing the scheduler if the scheduler has already timed out. --- distributed/deploy/spec.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 6b95e2107fe..0d2d2d37021 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -5,7 +5,7 @@ from tornado import gen from .cluster import Cluster -from ..core import rpc +from ..core import rpc, CommClosedError from ..utils import LoopRunner, silence_logging, ignoring from ..scheduler import Scheduler from ..security import Security @@ -113,14 +113,6 @@ def __init__( silence_logs=False, ): self._created = weakref.WeakSet() - if scheduler is None: - try: - from distributed.dashboard import BokehScheduler - except ImportError: - services = {} - else: - services = {("dashboard", 8787): BokehScheduler} - scheduler = {"cls": Scheduler, "options": {"services": services}} self.scheduler_spec = scheduler self.worker_spec = workers or {} @@ -137,9 +129,6 @@ def __init__( self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop - self.scheduler = self.scheduler_spec["cls"]( - loop=self.loop, **self.scheduler_spec["options"] - ) self.status = "created" self._instances.add(self) self._correct_state_waiting = None @@ -158,6 +147,18 @@ async def _start(self): if self.status == "closed": raise ValueError("Cluster is closed") + if self.scheduler_spec is None: + try: + from distributed.dashboard import BokehScheduler + except ImportError: + services = {} + else: + services = {("dashboard", 8787): BokehScheduler} + self.scheduler_spec = {"cls": Scheduler, "options": {"services": services}} + self.scheduler = self.scheduler_spec["cls"]( + loop=self.loop, **self.scheduler_spec["options"] + ) + self._lock = asyncio.Lock() self.status = "starting" self.scheduler = await self.scheduler @@ -253,7 +254,8 @@ async def _close(self): self.scale(0) await self._correct_state() async with self._lock: - await 
self.scheduler_comm.close(close_workers=True) + with ignoring(CommClosedError): + await self.scheduler_comm.close(close_workers=True) await self.scheduler.close() for w in self._created: assert w.status == "closed" From 36b7585c7c84470c7ada34f1d0272f1696a6ecbf Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 19 Jul 2019 16:20:54 -0400 Subject: [PATCH 0358/1550] Document distributed.Reschedule in API docs (#2860) --- docs/source/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index adefe5b86c4..8d739334b07 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -167,6 +167,7 @@ Other .. autofunction:: distributed.get_client .. autofunction:: distributed.secede .. autofunction:: distributed.rejoin +.. autoclass:: distributed.Reschedule .. autoclass:: get_task_stream .. autoclass:: Lock From 4c105056af643898d06ecd98ab2802cde495ac64 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 19 Jul 2019 16:51:13 -0400 Subject: [PATCH 0359/1550] Add fsspec to installation of test builds (#2859) --- continuous_integration/setup_conda_environment.cmd | 1 + continuous_integration/travis/install.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index 3df89fa85fe..6fff1a5ca6a 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -40,6 +40,7 @@ call deactivate tblib ^ tornado=5 ^ zict ^ + fsspec ^ -c conda-forge call activate %CONDA_ENV% diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index b2fab6afb52..82993032e0b 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -63,6 +63,7 @@ pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio pip install -q git+https://github.com/dask/dask.git --upgrade --no-deps pip install -q git+https://github.com/joblib/joblib.git --upgrade --no-deps +pip install -q git+https://github.com/intake/filesystem_spec.git --upgrade --no-deps pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps pip install -q sortedcollections msgpack --no-deps From 96ff5d3409c0146759cdc68a43d86c96e159daa4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 20 Jul 2019 13:19:10 -0500 Subject: [PATCH 0360/1550] Make await/start more consistent across Scheduler/Worker/Nanny (#2831) Now every ServerNode has a start async method that returns itself. 
And the __await__ method is handled in the superclass --- distributed/cli/tests/test_dask_worker.py | 35 ++++----- distributed/comm/addressing.py | 10 ++- distributed/deploy/spec.py | 2 +- distributed/deploy/tests/test_spec_cluster.py | 13 +++- distributed/nanny.py | 63 ++++++---------- distributed/node.py | 10 +++ distributed/scheduler.py | 36 +++------- distributed/tests/test_core.py | 10 +-- distributed/tests/test_failed_workers.py | 23 +++--- distributed/tests/test_nanny.py | 6 +- distributed/tests/test_scheduler.py | 36 ++++------ distributed/tests/test_stress.py | 5 +- distributed/tests/test_utils_test.py | 5 +- distributed/tests/test_worker.py | 71 ++++++++++++++++--- distributed/tests/test_worker_plugins.py | 3 +- distributed/utils_test.py | 13 ++-- distributed/worker.py | 64 ++++------------- 17 files changed, 201 insertions(+), 204 deletions(-) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index b6c7d393e3b..e268229767d 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -268,29 +268,30 @@ def test_dashboard_non_standard_ports(loop): except ImportError: proxy_exists = False - with popen(["dask-scheduler", "--port", "3449"]): + with popen(["dask-scheduler", "--port", "3449"]) as s: with popen( - ["dask-worker", "tcp://127.0.0.1:3449", "--dashboard-address", ":4833"] + [ + "dask-worker", + "tcp://127.0.0.1:3449", + "--dashboard-address", + ":4833", + "--host", + "127.0.0.1", + ] ) as proc: with Client("127.0.0.1:3449", loop=loop) as c: + c.wait_for_workers(1) pass - start = time() - while True: - try: - response = requests.get("http://127.0.0.1:4833/status") + response = requests.get("http://127.0.0.1:4833/status") + assert response.ok + redirect_resp = requests.get("http://127.0.0.1:4833/main") + redirect_resp.ok + # TEST PROXYING WORKS + if proxy_exists: + url = "http://127.0.0.1:8787/proxy/4833/127.0.0.1/status" + response = requests.get(url) assert response.ok - redirect_resp = requests.get("http://127.0.0.1:4833/main") - redirect_resp.ok - # TEST PROXYING WORKS - if proxy_exists: - url = "http://127.0.0.1:8787/proxy/4833/127.0.0.1/status" - response = requests.get(url) - assert response.ok - break - except Exception: - sleep(0.5) - assert time() < start + 20 with pytest.raises(Exception): requests.get("http://localhost:4833/status/") diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index d707adb84ac..54e37b77f6b 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -211,7 +211,13 @@ def uri_from_host_port(host_arg, port_arg, default_port): def address_from_user_args( - host=None, port=None, interface=None, protocol=None, peer=None, security=None + host=None, + port=None, + interface=None, + protocol=None, + peer=None, + security=None, + default_port=0, ): """ Get an address to listen on from common user provided arguments """ if security and security.require_encryption and not protocol: @@ -235,7 +241,7 @@ def address_from_user_args( host = protocol.rstrip("://") + "://" + host if host or port: - addr = uri_from_host_port(host, port, 0) + addr = uri_from_host_port(host, port, default_port) else: addr = "" diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 0d2d2d37021..441ef10a595 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -156,7 +156,7 @@ async def _start(self): services = {("dashboard", 8787): BokehScheduler} self.scheduler_spec = {"cls": Scheduler, 
"options": {"services": services}} self.scheduler = self.scheduler_spec["cls"]( - loop=self.loop, **self.scheduler_spec["options"] + loop=self.loop, **self.scheduler_spec.get("options", {}) ) self._lock = asyncio.Lock() diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index e51e8f14260..723c62a80c1 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,4 +1,4 @@ -from dask.distributed import SpecCluster, Worker, Client, Scheduler +from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny from distributed.deploy.spec import close_clusters from distributed.utils_test import loop # noqa: F401 import pytest @@ -153,3 +153,14 @@ def new_worker_spec(self): cluster.scale(3) for i in range(3): assert cluster.worker_spec[i]["options"]["nthreads"] == i + 1 + + +@pytest.mark.asyncio +async def test_nanny_port(): + scheduler = {"cls": Scheduler} + workers = {0: {"cls": Nanny, "options": {"port": 9200}}} + + async with SpecCluster( + scheduler=scheduler, workers=workers, asynchronous=True + ) as cluster: + pass diff --git a/distributed/nanny.py b/distributed/nanny.py index b6b43116a73..00370b81adb 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -17,7 +17,7 @@ from tornado.ioloop import IOLoop, TimeoutError from tornado.locks import Event -from .comm import get_address_host, get_local_address_for, unparse_host_port +from .comm import get_address_host, unparse_host_port from .comm.addressing import address_from_user_args from .core import RPCClosed, CommClosedError, coerce_to_address from .metrics import time @@ -120,6 +120,14 @@ def __init__( self.preload_argv = preload_argv self.Worker = Worker if worker_class is None else worker_class self.env = env or {} + worker_kwargs.update( + { + "port": worker_port, + "interface": interface, + "protocol": protocol, + "host": host, + } + ) self.worker_kwargs = worker_kwargs self.contact_address = contact_address @@ -161,6 +169,13 @@ def __init__( pc = PeriodicCallback(self.memory_monitor, 100, io_loop=self.loop) self.periodic_callbacks["memory"] = pc + if ( + not host + and not interface + and not self.scheduler_addr.startswith("inproc://") + ): + host = get_ip(get_address_host(self.scheduler.address)) + self._start_address = address_from_user_args( host=host, port=port, @@ -208,25 +223,10 @@ def worker_dir(self): return None if self.process is None else self.process.worker_dir @gen.coroutine - def _start(self, addr_or_port=0): + def start(self): """ Start nanny, start local process, start watching """ - addr_or_port = addr_or_port or self._start_address - - # XXX Factor this out - if not addr_or_port: - # Default address is the required one to reach the scheduler - self.listen( - get_local_address_for(self.scheduler.address), - listen_args=self.listen_args, - ) - self.ip = get_address_host(self.address) - elif isinstance(addr_or_port, int): - # addr_or_port is an integer => assume TCP - self.ip = get_ip(get_address_host(self.scheduler.address)) - self.listen((self.ip, addr_or_port), listen_args=self.listen_args) - else: - self.listen(addr_or_port, listen_args=self.listen_args) - self.ip = get_address_host(self.address) + self.listen(self._start_address, listen_args=self.listen_args) + self.ip = get_address_host(self.address) logger.info(" Start Nanny at: %r", self.address) response = yield self.instantiate() @@ -238,13 +238,7 @@ def _start(self, addr_or_port=0): self.start_periodic_callbacks() - raise 
gen.Return(self) - - def __await__(self): - return self._start().__await__() - - def start(self, addr_or_port=0): - self.loop.add_callback(self._start, addr_or_port) + return self @gen.coroutine def kill(self, comm=None, timeout=2): @@ -295,7 +289,6 @@ def instantiate(self, comm=None): ) worker_kwargs.update(self.worker_kwargs) self.process = WorkerProcess( - worker_args=tuple(), worker_kwargs=worker_kwargs, worker_start_args=(start_arg,), silence_logs=self.silence_logs, @@ -432,18 +425,10 @@ def close(self, comm=None, timeout=5, report=None): class WorkerProcess(object): def __init__( - self, - worker_args, - worker_kwargs, - worker_start_args, - silence_logs, - on_exit, - worker, - env, + self, worker_kwargs, worker_start_args, silence_logs, on_exit, worker, env ): self.status = "init" self.silence_logs = silence_logs - self.worker_args = worker_args self.worker_kwargs = worker_kwargs self.worker_start_args = worker_start_args self.on_exit = on_exit @@ -475,7 +460,6 @@ def start(self): target=self._run, name="Dask Worker process (from Nanny)", kwargs=dict( - worker_args=self.worker_args, worker_kwargs=self.worker_kwargs, worker_start_args=self.worker_start_args, silence_logs=self.silence_logs, @@ -615,7 +599,6 @@ def _wait_until_connected(self, uid): @classmethod def _run( cls, - worker_args, worker_kwargs, worker_start_args, silence_logs, @@ -639,7 +622,7 @@ def _run( IOLoop.clear_instance() loop = IOLoop() loop.make_current() - worker = Worker(*worker_args, **worker_kwargs) + worker = Worker(**worker_kwargs) @gen.coroutine def do_stop(timeout=5, executor_wait=True): @@ -679,7 +662,7 @@ def run(): Try to start worker and inform parent of outcome. """ try: - yield worker._start(*worker_start_args) + yield worker except Exception as e: logger.exception("Failed to start worker") init_result_q.put({"uid": uid, "exception": e}) diff --git a/distributed/node.py b/distributed/node.py index 8bd81ffe5ae..323e2c3e49d 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -4,6 +4,7 @@ import logging from tornado.ioloop import IOLoop +from tornado import gen import dask from .compatibility import unicode, finalize @@ -159,3 +160,12 @@ async def __aenter__(self): async def __aexit__(self, typ, value, traceback): await self.close() + + def __await__(self): + if self.status == "running": + return gen.sleep(0).__await__() + else: + return self.start().__await__() + + async def start(self): # subclasses should implement this + return self diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 66d8fdaac90..63ae1ef33b2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -48,7 +48,6 @@ from .utils import ( All, ignoring, - get_ip, get_fileno_limit, log_errors, key_split, @@ -842,7 +841,7 @@ def __init__( idle_timeout=None, interface=None, host=None, - port=8786, + port=0, protocol=None, dashboard_address=None, **kwargs @@ -1098,6 +1097,7 @@ def __init__( interface=interface, protocol=protocol, security=security, + default_port=self.default_port, ) super(Scheduler, self).__init__( @@ -1177,12 +1177,11 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): else: return ws.host, port - def start(self, addr_or_port=None, start_queues=True): + @gen.coroutine + def start(self): """ Clear out old state and restart all running coroutines """ enable_gc_diagnosis() - addr_or_port = addr_or_port or self._start_address - self.clear_task_state() with ignoring(AttributeError): @@ -1196,21 +1195,14 @@ def start(self, addr_or_port=None, start_queues=True): 
raise exc if self.status != "running": - if isinstance(addr_or_port, int): - # Listen on all interfaces. `get_ip()` is not suitable - # as it would prevent connecting via 127.0.0.1. - self.listen(("", addr_or_port), listen_args=self.listen_args) - self.ip = get_ip() - listen_ip = "" - else: - self.listen(addr_or_port, listen_args=self.listen_args) - self.ip = get_address_host(self.listen_address) - listen_ip = self.ip + self.listen(self._start_address, listen_args=self.listen_args) + self.ip = get_address_host(self.listen_address) + listen_ip = self.ip if listen_ip == "0.0.0.0": listen_ip = "" - if isinstance(addr_or_port, str) and addr_or_port.startswith("inproc://"): + if self._start_address.startswith("inproc://"): listen_ip = "localhost" # Services listen on all addresses @@ -1239,16 +1231,8 @@ def del_scheduler_file(): setproctitle("dask-scheduler [%s]" % (self.address,)) - return self.finished() - - def __await__(self): - self.start() - - @gen.coroutine - def _(): - return self - - return _().__await__() + yield self.finished() + return self @gen.coroutine def finished(self): diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index f53340d1004..38a43a1a5c8 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -373,13 +373,13 @@ def g(): server = Server({"ping": pingpong}) server.listen(listen_arg) - remote = rpc(server.address) - yield [g() for i in range(10)] + with rpc(server.address) as remote: + yield [g() for i in range(10)] - server.stop() + server.stop() - remote.close_comms() - assert all(comm.closed() for comm in remote.comms) + remote.close_comms() + assert all(comm.closed() for comm in remote.comms) @gen_test() diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index b39dd3f3ae7..5465a7dd5f0 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -39,8 +39,7 @@ def test_submit_after_failed_worker_sync(loop): @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) def test_submit_after_failed_worker_async(c, s, a, b): - n = Nanny(s.address, nthreads=2, loop=s.loop) - n.start(0) + n = yield Nanny(s.address, nthreads=2, loop=s.loop) while len(s.workers) < 3: yield gen.sleep(0.1) @@ -128,7 +127,7 @@ def test_failed_worker_without_warning(c, s, a, b): assert all(len(keys) > 0 for keys in s.has_what.values()) nthreads2 = dict(s.nthreads) - yield c._restart() + yield c.restart() L = c.map(inc, range(10)) yield wait(L) @@ -148,7 +147,7 @@ def test_restart(c, s, a, b): assert set(s.who_has) == {x.key, y.key} - f = yield c._restart() + f = yield c.restart() assert f is c assert len(s.workers) == 2 @@ -171,7 +170,7 @@ def test_restart_cleared(c, s, a, b): f = c.compute(x) yield wait([f]) - yield c._restart() + yield c.restart() for coll in [s.tasks, s.unrunnable]: assert not coll @@ -212,7 +211,7 @@ def test_restart_fast(c, s, a, b): L = c.map(sleep, range(10)) start = time() - yield c._restart() + yield c.restart() assert time() - start < 10 assert len(s.nthreads) == 2 @@ -255,7 +254,7 @@ def test_fast_kill(c, s, a, b): L = c.map(sleep, range(10)) start = time() - yield c._restart() + yield c.restart() assert time() - start < 10 assert all(x.status == "cancelled" for x in L) @@ -302,7 +301,7 @@ def test_restart_scheduler(s, a, b): @gen_cluster(Worker=Nanny, client=True, timeout=60) def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): x = c.submit(inc, 1) - yield c._restart() + yield c.restart() y = c.submit(inc, 1) 
del x import gc @@ -315,8 +314,7 @@ def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) def test_broken_worker_during_computation(c, s, a, b): s.allowed_failures = 100 - n = Nanny(s.address, nthreads=2, loop=s.loop) - n.start(0) + n = yield Nanny(s.address, nthreads=2, loop=s.loop) start = time() while len(s.nthreads) < 3: @@ -365,7 +363,7 @@ def test_restart_during_computation(c, s, a, b): yield gen.sleep(0.5) assert s.rprocessing - yield c._restart() + yield c.restart() assert not s.rprocessing assert len(s.nthreads) == 2 @@ -374,8 +372,7 @@ def test_restart_during_computation(c, s, a, b): @gen_cluster(client=True, timeout=60) def test_worker_who_has_clears_after_failed_connection(c, s, a, b): - n = Nanny(s.address, nthreads=2, loop=s.loop) - n.start(0) + n = yield Nanny(s.address, nthreads=2, loop=s.loop) start = time() while len(s.nthreads) < 3: diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 40c8d49012d..7722476a2c5 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -216,8 +216,7 @@ def test_num_fds(s): @gen_cluster(client=True, nthreads=[]) def test_worker_uses_same_host_as_nanny(c, s): for host in ["tcp://0.0.0.0", "tcp://127.0.0.2"]: - n = Nanny(s.address) - yield n._start(host) + n = yield Nanny(s.address, host=host) def func(dask_worker): return dask_worker.listener.listen_address @@ -230,8 +229,7 @@ def func(dask_worker): @gen_test() def test_scheduler_file(): with tmpfile() as fn: - s = Scheduler(scheduler_file=fn) - s.start(8008) + s = yield Scheduler(scheduler_file=fn, port=8008) w = yield Nanny(scheduler_file=fn) assert set(s.workers) == {w.worker_address} yield w.close() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e8d2a96ee60..3d29bc79a1e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -256,7 +256,6 @@ def test_add_worker(s, a, b): w = Worker(s.address, nthreads=3) w.data["x-5"] = 6 w.data["y"] = 1 - yield w dsk = {("x-%d" % i): (inc, i) for i in range(10)} s.update_graph( @@ -265,11 +264,8 @@ def test_add_worker(s, a, b): client="client", dependencies={k: set() for k in dsk}, ) - - s.add_worker( - address=w.address, keys=list(w.data), nthreads=w.nthreads, services=s.services - ) - + s.validate_state() + yield w s.validate_state() assert w.ip in s.host_info @@ -665,7 +661,7 @@ def test_scatter_no_workers(c, s): assert time() < start + 1.5 w = Worker(s.address, nthreads=3) - yield [c.scatter(data={"y": 2}, timeout=5), w._start()] + yield [c.scatter(data={"y": 2}, timeout=5), w] assert w.data["y"] == 2 yield w.close() @@ -1172,7 +1168,7 @@ def test_correct_bad_time_estimate(c, s, *workers): @gen_test() -def test_service_hosts(): +async def test_service_hosts(): pytest.importorskip("bokeh") from distributed.dashboard import BokehScheduler @@ -1184,26 +1180,20 @@ def test_service_hosts(): ]: services = {("dashboard", port): BokehScheduler} - s = Scheduler(services=services) - yield s.start(url) - - sock = first(s.services["dashboard"].server._http._sockets.values()) - if isinstance(expected, tuple): - assert sock.getsockname()[0] in expected - else: - assert sock.getsockname()[0] == expected - yield s.close() + async with Scheduler(host=url, services=services) as s: + sock = first(s.services["dashboard"].server._http._sockets.values()) + if isinstance(expected, tuple): + assert sock.getsockname()[0] in expected + else: + assert 
sock.getsockname()[0] == expected port = ("127.0.0.1", 0) for url in ["tcp://0.0.0.0", "tcp://127.0.0.1", "tcp://127.0.0.1:38275"]: services = {("dashboard", port): BokehScheduler} - s = Scheduler(services=services) - yield s.start(url) - - sock = first(s.services["dashboard"].server._http._sockets.values()) - assert sock.getsockname()[0] == "127.0.0.1" - yield s.close() + async with Scheduler(services=services, host=url) as s: + sock = first(s.services["dashboard"].server._http._sockets.values()) + assert sock.getsockname()[0] == "127.0.0.1" @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 81d7c4360f7..6a5dbe72736 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -108,11 +108,8 @@ def test_stress_creation_and_deletion(c, s): def create_and_destroy_worker(delay): start = time() while time() < start + 5: - n = Nanny(s.address, nthreads=2, loop=s.loop) - n.start(0) - + n = yield Nanny(s.address, nthreads=2, loop=s.loop) yield gen.sleep(delay) - yield n.close() print("Killed nanny") diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index c0afb9e2c7f..eac2ec71529 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -176,10 +176,9 @@ def test_tls_cluster(tls_client): def test_tls_scheduler(security, loop): - s = Scheduler(security=security, loop=loop) - s.start("localhost") + s = yield Scheduler(security=security, loop=loop, host="localhost") assert s.address.startswith("tls") - s.close() + yield s.close() if sys.version_info >= (3, 5): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 562a0e037b7..88186885db1 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -22,10 +22,18 @@ from tornado import gen from tornado.ioloop import TimeoutError -from distributed import Nanny, get_client, wait, default_client, get_worker, Reschedule +from distributed import ( + Client, + Nanny, + get_client, + wait, + default_client, + get_worker, + Reschedule, + wait, +) from distributed.compatibility import WINDOWS, cache_from_source from distributed.core import rpc -from distributed.client import wait from distributed.scheduler import Scheduler from distributed.metrics import time from distributed.worker import Worker, error_message, logger, parse_memory_limit @@ -978,20 +986,23 @@ def test_service_hosts_match_worker(s): services = {("dashboard", ":0"): BokehWorker} - w = Worker(s.address, services={("dashboard", ":0"): BokehWorker}) - yield w._start("tcp://0.0.0.0") + w = yield Worker( + s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://0.0.0.0" + ) sock = first(w.services["dashboard"].server._http._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") yield w.close() - w = Worker(s.address, services={("dashboard", ":0"): BokehWorker}) - yield w._start("tcp://127.0.0.1") + w = yield Worker( + s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://127.0.0.1" + ) sock = first(w.services["dashboard"].server._http._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") yield w.close() - w = Worker(s.address, services={("dashboard", 0): BokehWorker}) - yield w._start("tcp://127.0.0.1") + w = yield Worker( + s.address, services={("dashboard", 0): BokehWorker}, host="tcp://127.0.0.1" + ) sock = first(w.services["dashboard"].server._http._sockets.values()) assert 
sock.getsockname()[0] == "127.0.0.1" yield w.close() @@ -1004,8 +1015,7 @@ def test_start_services(s): services = {("dashboard", ":1234"): BokehWorker} - w = Worker(s.address, services=services) - yield w._start() + w = yield Worker(s.address, services=services) assert w.services["dashboard"].server.port == 1234 yield w.close() @@ -1440,3 +1450,44 @@ def test_resource_limit(): assert parse_memory_limit(hard_limit, 1, total_cores=1) == new_limit except OSError: pytest.skip("resource could not set the RSS limit") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("Worker", [Worker, Nanny]) +async def test_interface_async(loop, Worker): + from distributed.utils import get_ip_interface + + psutil = pytest.importorskip("psutil") + if_names = sorted(psutil.net_if_addrs()) + for if_name in if_names: + try: + ipv4_addr = get_ip_interface(if_name) + except ValueError: + pass + else: + if ipv4_addr == "127.0.0.1": + break + else: + pytest.skip( + "Could not find loopback interface. " + "Available interfaces are: %s." % (if_names,) + ) + + async with Scheduler(interface=if_name) as s: + assert s.address.startswith("tcp://127.0.0.1") + async with Worker(s.address, interface=if_name) as w: + assert w.address.startswith("tcp://127.0.0.1") + assert w.ip == "127.0.0.1" + async with Client(s.address, asynchronous=True) as c: + info = c.scheduler_info() + assert "tcp://127.0.0.1" in info["address"] + assert all("127.0.0.1" == d["host"] for d in info["workers"].values()) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("Worker", [Worker, Nanny]) +async def test_worker_listens_on_same_interface_by_default(Worker): + async with Scheduler(host="localhost") as s: + assert s.ip in {"127.0.0.1", "localhost"} + async with Worker(s.address) as w: + assert s.ip == w.ip diff --git a/distributed/tests/test_worker_plugins.py b/distributed/tests/test_worker_plugins.py index 425a267923a..bbba39943fb 100644 --- a/distributed/tests/test_worker_plugins.py +++ b/distributed/tests/test_worker_plugins.py @@ -23,8 +23,7 @@ def teardown(self, worker): def test_create_with_client(c, s): yield c.register_worker_plugin(MyPlugin(123)) - worker = Worker(s.address, loop=s.loop) - yield worker._start() + worker = yield Worker(s.address, loop=s.loop) assert worker._my_plugin_status == "setup" assert worker._my_plugin_data == 123 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 77568bb7595..af28fe66168 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -486,7 +486,7 @@ def run_worker(q, scheduler_q, **kwargs): with pristine_loop() as loop: scheduler_addr = scheduler_q.get() worker = Worker(scheduler_addr, validate=True, **kwargs) - loop.run_sync(lambda: worker._start()) + loop.run_sync(worker.start) q.put(worker.address) try: @@ -504,7 +504,7 @@ def run_nanny(q, scheduler_q, **kwargs): with pristine_loop() as loop: scheduler_addr = scheduler_q.get() worker = Nanny(scheduler_addr, validate=True, **kwargs) - loop.run_sync(lambda: worker._start()) + loop.run_sync(worker.start) q.put(worker.address) try: loop.start() @@ -794,9 +794,14 @@ def start_cluster( worker_kwargs={}, ): s = Scheduler( - loop=loop, validate=True, security=security, port=0, **scheduler_kwargs + loop=loop, + validate=True, + security=security, + port=0, + host=scheduler_addr, + **scheduler_kwargs ) - done = s.start(scheduler_addr) + done = s.start() workers = [ Worker( s.address, diff --git a/distributed/worker.py b/distributed/worker.py index 60884ebaba3..791fc0ba101 100644 --- a/distributed/worker.py +++ 
b/distributed/worker.py @@ -32,7 +32,7 @@ from . import profile, comm from .batched import BatchedSend -from .comm import get_address_host, get_local_address_for, connect +from .comm import get_address_host, connect from .comm.utils import offload from .comm.addressing import address_from_user_args from .compatibility import unicode, get_thread_identity, MutableMapping @@ -423,6 +423,11 @@ def __init__( scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port)) self.contact_address = contact_address + # Target interface on which we contact the scheduler by default + # TODO: it is unfortunate that we special-case inproc here + if not host and not interface and not scheduler_addr.startswith("inproc://"): + host = get_ip(get_address_host(scheduler_addr)) + self._start_address = address_from_user_args( host=host, port=port, @@ -892,39 +897,18 @@ def gather(self, comm=None, who_has=None): ############# @gen.coroutine - def _start(self, addr_or_port=0): + def start(self): assert self.status is None - addr_or_port = addr_or_port or self._start_address enable_gc_diagnosis() thread_state.on_event_loop_thread = True - # XXX Factor this out - if not addr_or_port: - # Default address is the required one to reach the scheduler - listen_host = get_address_host(self.scheduler.address) - self.listen( - get_local_address_for(self.scheduler.address), - listen_args=self.listen_args, - ) - self.ip = get_address_host(self.address) - elif isinstance(addr_or_port, int): - # addr_or_port is an integer => assume TCP - listen_host = self.ip = get_ip(get_address_host(self.scheduler.address)) - self.listen((listen_host, addr_or_port), listen_args=self.listen_args) - else: - self.listen(addr_or_port, listen_args=self.listen_args) - self.ip = get_address_host(self.address) - try: - listen_host = get_address_host(addr_or_port) - except ValueError: - listen_host = addr_or_port - - if "://" in listen_host: - protocol, listen_host = listen_host.split("://") + self.listen(self._start_address, listen_args=self.listen_args) + self.ip = get_address_host(self.address) if self.name is None: self.name = self.address + preload_modules( self.preload, parameter=self, @@ -934,21 +918,17 @@ def _start(self, addr_or_port=0): # Services listen on all addresses # Note Nanny is not a "real" service, just some metadata # passed in service_ports... 
- self.start_services(listen_host) + self.start_services(self.ip) try: - listening_address = "%s%s:%d" % ( - self.listener.prefix, - listen_host, - self.port, - ) + listening_address = "%s%s:%d" % (self.listener.prefix, self.ip, self.port) except Exception: - listening_address = "%s%s" % (self.listener.prefix, listen_host) + listening_address = "%s%s" % (self.listener.prefix, self.ip) logger.info(" Start worker at: %26s", self.address) logger.info(" Listening to: %26s", listening_address) for k, v in self.service_ports.items(): - logger.info(" %16s at: %26s" % (k, listen_host + ":" + str(v))) + logger.info(" %16s at: %26s" % (k, self.ip + ":" + str(v))) logger.info("Waiting to connect to: %26s", self.scheduler.address) logger.info("-" * 49) logger.info(" Threads: %26d", self.nthreads) @@ -964,21 +944,7 @@ def _start(self, addr_or_port=0): yield self._register_with_scheduler() self.start_periodic_callbacks() - raise gen.Return(self) - - def __await__(self): - if self.status is not None: - - @gen.coroutine # idempotent - def _(): - raise gen.Return(self) - - return _().__await__() - else: - return self._start().__await__() - - def start(self, port=0): - self.loop.add_callback(self._start, port) + return self def _close(self, *args, **kwargs): warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) From 2a145ac5ccbe34af44ec4d2814897a87b3fef3e7 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 20 Jul 2019 13:19:40 -0500 Subject: [PATCH 0361/1550] Add cleanup fixture for asyncio tests (#2866) --- distributed/core.py | 4 ++-- distributed/deploy/tests/test_local.py | 5 +++-- distributed/deploy/tests/test_spec_cluster.py | 10 +++++----- distributed/tests/test_scheduler.py | 7 ++++--- distributed/tests/test_worker.py | 13 ++++++------- distributed/utils_test.py | 13 +++++++++++++ 6 files changed, 33 insertions(+), 19 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 6f08c17ac77..4d18547151a 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -154,11 +154,11 @@ def __init__( if not hasattr(self.io_loop, "profile"): ref = weakref.ref(self.io_loop) - if hasattr(self.io_loop, "closing"): + if hasattr(self.io_loop, "asyncio_loop"): def stop(): loop = ref() - return loop is None or loop.closing + return loop is None or loop.asyncio_loop.is_closed() else: diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index ebe2d7ec99e..d489b84df0f 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -17,8 +17,9 @@ from distributed import Client, Worker, Nanny, get_client from distributed.deploy.local import LocalCluster, nprocesses_nthreads from distributed.metrics import time -from distributed.utils_test import ( +from distributed.utils_test import ( # noqa: F401 clean, + cleanup, inc, gen_test, slowinc, @@ -801,7 +802,7 @@ class MyNanny(Nanny): @pytest.mark.asyncio -async def test_worker_class_nanny_async(): +async def test_worker_class_nanny_async(cleanup): class MyNanny(Nanny): pass diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 723c62a80c1..bb992f8b7c7 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,6 +1,6 @@ from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny from distributed.deploy.spec import close_clusters -from distributed.utils_test import loop # noqa: F401 +from distributed.utils_test import loop, cleanup # 
noqa: F401 import pytest @@ -25,7 +25,7 @@ async def _(): @pytest.mark.asyncio -async def test_specification(): +async def test_specification(cleanup): async with SpecCluster( workers=worker_spec, scheduler=scheduler, asynchronous=True ) as cluster: @@ -82,7 +82,7 @@ def test_loop_started(): @pytest.mark.asyncio -async def test_repr(): +async def test_repr(cleanup): worker = {"cls": Worker, "options": {"nthreads": 1}} class MyCluster(SpecCluster): @@ -95,7 +95,7 @@ class MyCluster(SpecCluster): @pytest.mark.asyncio -async def test_scale(): +async def test_scale(cleanup): worker = {"cls": Worker, "options": {"nthreads": 1}} async with SpecCluster( asynchronous=True, scheduler=scheduler, worker=worker @@ -143,7 +143,7 @@ def test_spec_close_clusters(loop): @pytest.mark.asyncio -async def test_new_worker_spec(): +async def test_new_worker_spec(cleanup): class MyCluster(SpecCluster): def new_worker_spec(self): i = len(self.worker_spec) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 3d29bc79a1e..9b512ddad64 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -24,7 +24,8 @@ from distributed.protocol.pickle import dumps from distributed.worker import dumps_function, dumps_task from distributed.utils import tmpfile -from distributed.utils_test import ( +from distributed.utils_test import ( # noqa: F401 + cleanup, inc, dec, gen_cluster, @@ -1590,7 +1591,7 @@ async def test_adaptive_target(c, s, a, b): @pytest.mark.asyncio -async def test_async_context_manager(): +async def test_async_context_manager(cleanup): async with Scheduler(port=0) as s: assert s.status == "running" async with Worker(s.address) as w: @@ -1600,7 +1601,7 @@ async def test_async_context_manager(): @pytest.mark.asyncio -async def test_allowed_failures_config(): +async def test_allowed_failures_config(cleanup): async with Scheduler(port=0, allowed_failures=10) as s: assert s.allowed_failures == 10 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 88186885db1..2e6af6e0bdc 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -38,7 +38,8 @@ from distributed.metrics import time from distributed.worker import Worker, error_message, logger, parse_memory_limit from distributed.utils import tmpfile -from distributed.utils_test import ( +from distributed.utils_test import ( # noqa: F401 + cleanup, inc, mul, gen_cluster, @@ -370,12 +371,10 @@ def test_gather(s, a, b): @pytest.mark.asyncio -async def test_io_loop(): - s = await Scheduler(port=0) - w = await Worker(s.address, loop=s.loop) - assert w.io_loop is s.loop - await s.close() - await w.close() +async def test_io_loop(cleanup): + async with Scheduler(port=0) as s: + async with Worker(s.address, loop=s.loop) as w: + assert w.io_loop is s.loop @gen_cluster(client=True, nthreads=[]) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index af28fe66168..f0e75f8ed2c 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1559,3 +1559,16 @@ def null(): with ignoring(AttributeError): del thread_state.on_event_loop_thread + + +@pytest.fixture +def cleanup(): + with check_thread_leak(): + with check_process_leak(): + with check_instances(): + reset_config() + dask.config.set({"distributed.comm.timeouts.connect": "5s"}) + for name, level in logging_levels.items(): + logging.getLogger(name).setLevel(level) + + yield From 967d97128ed6fa169231a83e6a7b1d22a92c9111 Mon Sep 17 00:00:00 2001 From: Matthew 
Rocklin Date: Sat, 20 Jul 2019 13:20:22 -0500 Subject: [PATCH 0362/1550] Use only remote connection to scheduler in Adaptive (#2865) This modifies the Adaptive class to only touch the scheduler through communication, rather than direct access. This should enable adaptive scheduling when the scheduler is deployed in a remote process. Fixes #2858 * Pickle worker_key function in adaptive --- distributed/deploy/adaptive.py | 98 +++++++++------------- distributed/deploy/cluster.py | 4 +- distributed/deploy/tests/test_adaptive.py | 99 ++++++++--------------- distributed/distributed.yaml | 1 + distributed/scheduler.py | 19 ++++- 5 files changed, 94 insertions(+), 127 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 401acc3dc1d..761a7d300ee 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -2,12 +2,12 @@ from collections import deque import logging -import math from tornado import gen from ..metrics import time from ..utils import log_errors, PeriodicCallback, parse_timedelta +from ..protocol import pickle logger = logging.getLogger(__name__) @@ -80,14 +80,10 @@ class Adaptive(object): resized. The default implementation checks if there are too many tasks per worker or too little memory available (see :meth:`Adaptive.needs_cpu` and :meth:`Adaptive.needs_memory`). - - :meth:`Adaptive.get_scale_up_kwargs` method controls the arguments passed to - the cluster's ``scale_up`` method. ''' def __init__( self, - scheduler, cluster=None, interval="1s", startup_cost="1s", @@ -96,20 +92,19 @@ def __init__( maximum=None, wait_count=3, target_duration="5s", - worker_key=lambda x: x, + worker_key=None, **kwargs ): interval = parse_timedelta(interval, default="ms") self.worker_key = worker_key - self.scheduler = scheduler self.cluster = cluster self.startup_cost = parse_timedelta(startup_cost, default="s") self.scale_factor = scale_factor if self.cluster: self._adapt_callback = PeriodicCallback( - self._adapt, interval * 1000, io_loop=scheduler.loop + self._adapt, interval * 1000, io_loop=self.loop ) - self.scheduler.loop.add_callback(self._adapt_callback.start) + self.loop.add_callback(self._adapt_callback.start) self._adapting = False self._workers_to_close_kwargs = kwargs self.minimum = minimum @@ -119,7 +114,9 @@ def __init__( self.wait_count = wait_count self.target_duration = parse_timedelta(target_duration) - self.scheduler.handlers["adaptive_recommendations"] = self.recommendations + @property + def scheduler(self): + return self.cluster.scheduler_comm def stop(self): if self.cluster: @@ -127,7 +124,7 @@ def stop(self): self._adapt_callback = None del self._adapt_callback - def workers_to_close(self, **kwargs): + async def workers_to_close(self, **kwargs): """ Determine which, if any, workers should potentially be removed from the cluster. 
@@ -145,73 +142,53 @@ def workers_to_close(self, **kwargs): -------- Scheduler.workers_to_close """ - if len(self.scheduler.workers) <= self.minimum: + if len(self.cluster.workers) <= self.minimum: return [] kw = dict(self._workers_to_close_kwargs) kw.update(kwargs) - if self.maximum is not None and len(self.scheduler.workers) > self.maximum: - kw["n"] = len(self.scheduler.workers) - self.maximum + if self.maximum is not None and len(self.cluster.workers) > self.maximum: + kw["n"] = len(self.cluster.workers) - self.maximum - L = self.scheduler.workers_to_close(**kw) - if len(self.scheduler.workers) - len(L) < self.minimum: - L = L[: len(self.scheduler.workers) - self.minimum] + L = await self.scheduler.workers_to_close(**kw) + if len(self.cluster.workers) - len(L) < self.minimum: + L = L[: len(self.cluster.workers) - self.minimum] return L - @gen.coroutine - def _retire_workers(self, workers=None): + async def _retire_workers(self, workers=None): if workers is None: - workers = self.workers_to_close(key=self.worker_key, minimum=self.minimum) + workers = await self.workers_to_close( + key=pickle.dumps(self.worker_key) if self.worker_key else None, + minimum=self.minimum, + ) if not workers: raise gen.Return(workers) with log_errors(): - yield self.scheduler.retire_workers( + await self.scheduler.retire_workers( workers=workers, remove=True, close_workers=True ) logger.info("Retiring workers %s", workers) f = self.cluster.scale_down(workers) if hasattr(f, "__await__"): - yield f - - raise gen.Return(workers) - - def get_scale_up_kwargs(self): - """ - Get the arguments to be passed to ``self.cluster.scale_up``. - - Notes - ----- - By default the desired number of total workers is returned (``n``). - Subclasses should ensure that the return dictionary includes a key- - value pair for ``n``, either by implementing it or by calling the - parent's ``get_scale_up_kwargs``. 
+ await f - See Also - -------- - LocalCluster.scale_up - """ - target = math.ceil(self.scheduler.total_occupancy / self.target_duration) - instances = max( - 1, len(self.scheduler.workers) * self.scale_factor, target, self.minimum - ) - - if self.maximum: - instances = min(self.maximum, instances) + return workers - instances = int(instances) - logger.info("Scaling up to %d workers", instances) - return {"n": instances} - - def recommendations(self, comm=None): - n = self.scheduler.adaptive_target(target_duration=self.target_duration) + async def recommendations(self, comm=None): + n = await self.scheduler.adaptive_target(target_duration=self.target_duration) if self.maximum is not None: n = min(self.maximum, n) if self.minimum is not None: n = max(self.minimum, n) - workers = set(self.workers_to_close(key=self.worker_key, minimum=self.minimum)) + workers = set( + await self.workers_to_close( + key=pickle.dumps(self.worker_key) if self.worker_key else None, + minimum=self.minimum, + ) + ) try: current = len(self.cluster.worker_spec) except AttributeError: @@ -249,14 +226,13 @@ def recommendations(self, comm=None): self.close_counts.clear() return None - @gen.coroutine - def _adapt(self): + async def _adapt(self): if self._adapting: # Semaphore to avoid overlapping adapt calls return self._adapting = True try: - recommendations = self.recommendations() + recommendations = await self.recommendations() if not recommendations: return status = recommendations.pop("status") @@ -264,13 +240,17 @@ def _adapt(self): f = self.cluster.scale(**recommendations) self.log.append((time(), "up", recommendations)) if hasattr(f, "__await__"): - yield f + await f elif status == "down": self.log.append((time(), "down", recommendations["workers"])) - workers = yield self._retire_workers(workers=recommendations["workers"]) + workers = await self._retire_workers(workers=recommendations["workers"]) finally: self._adapting = False def adapt(self): - self.scheduler.loop.add_callback(self._adapt) + self.loop.add_callback(self._adapt) + + @property + def loop(self): + return self.cluster.loop diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index d48f27603ff..f5d991cd737 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -62,7 +62,7 @@ def scale_down(self, workers: List[str]): LocalCluster: a simple implementation with local workers """ - def adapt(self, **kwargs): + def adapt(self, Adaptive=Adaptive, **kwargs): """ Turn on adaptivity For keyword arguments see dask.distributed.Adaptive @@ -76,7 +76,7 @@ def adapt(self, **kwargs): if not hasattr(self, "_adaptive_options"): self._adaptive_options = {} self._adaptive_options.update(kwargs) - self._adaptive = Adaptive(self.scheduler, self, **self._adaptive_options) + self._adaptive = Adaptive(self, **self._adaptive_options) return self._adaptive @property diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 146d7b95dbb..2258a83cfe5 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -2,33 +2,19 @@ from time import sleep +import pytest from toolz import frequencies, pluck from tornado import gen from tornado.ioloop import IOLoop from distributed import Client, wait, Adaptive, LocalCluster, SpecCluster, Worker -from distributed.utils_test import gen_cluster, gen_test, slowinc, clean +from distributed.utils_test import gen_test, slowinc, clean from distributed.utils_test import loop, nodebug # noqa: F401 from 
distributed.metrics import time -def test_get_scale_up_kwargs(loop): - with LocalCluster( - 0, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop - ) as cluster: - - alc = Adaptive(cluster.scheduler, cluster, interval=100, scale_factor=3) - assert alc.get_scale_up_kwargs() == {"n": 1} - - with Client(cluster, loop=loop) as c: - future = c.submit(lambda x: x + 1, 1) - assert future.result() == 2 - assert c.nthreads() - assert alc.get_scale_up_kwargs() == {"n": 3} - - -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_simultaneous_scale_up_and_down(c, s, *workers): +@pytest.mark.asyncio +async def test_simultaneous_scale_up_and_down(): class TestAdaptive(Adaptive): def get_scale_up_kwargs(self): assert False @@ -36,34 +22,35 @@ def get_scale_up_kwargs(self): def _retire_workers(self): assert False - class TestCluster(object): + class TestCluster(LocalCluster): def scale_up(self, n, **kwargs): assert False def scale_down(self, workers): assert False - cluster = TestCluster() + async with TestCluster(n_workers=4, processes=False, asynchronous=True) as cluster: + async with Client(cluster, asynchronous=True) as c: + s = cluster.scheduler + s.task_duration["a"] = 4 + s.task_duration["b"] = 4 + s.task_duration["c"] = 1 - s.task_duration["a"] = 4 - s.task_duration["b"] = 4 - s.task_duration["c"] = 1 + future = c.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) - future = c.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) + while len(s.rprocessing) < 3: + await gen.sleep(0.001) - while len(s.rprocessing) < 3: - yield gen.sleep(0.001) + ta = cluster.adapt(interval="100 ms", scale_factor=2, Adaptive=TestAdaptive) - ta = TestAdaptive(s, cluster, interval=100, scale_factor=2) - - yield gen.sleep(0.3) + await gen.sleep(0.3) def test_adaptive_local_cluster(loop): with LocalCluster( 0, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop ) as cluster: - alc = Adaptive(cluster.scheduler, cluster, interval=100) + alc = cluster.adapt(interval="100 ms") with Client(cluster, loop=loop) as c: assert not c.nthreads() future = c.submit(lambda x: x + 1, 1) @@ -128,41 +115,34 @@ def test_adaptive_local_cluster_multi_workers(): yield cluster.close() -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, active_rpc_timeout=10) -def test_adaptive_scale_down_override(c, s, *workers): +@pytest.mark.asyncio +async def test_adaptive_scale_down_override(): class TestAdaptive(Adaptive): def __init__(self, *args, **kwargs): self.min_size = kwargs.pop("min_size", 0) Adaptive.__init__(self, *args, **kwargs) - def workers_to_close(self, **kwargs): - num_workers = len(self.scheduler.workers) - to_close = self.scheduler.workers_to_close(**kwargs) + async def workers_to_close(self, **kwargs): + num_workers = len(self.cluster.workers) + to_close = await self.scheduler.workers_to_close(**kwargs) if num_workers - len(to_close) < self.min_size: to_close = to_close[: num_workers - self.min_size] return to_close - class TestCluster(object): + class TestCluster(LocalCluster): def scale_up(self, n, **kwargs): assert False - def scale_down(self, workers): - assert False - - @property - def workers(self): - return s.workers - - assert len(s.workers) == 10 - - # Assert that adaptive cycle does not reduce cluster below minimum size - # as determined via override. 
- cluster = TestCluster() - ta = TestAdaptive(s, cluster, min_size=2, interval=0.1, scale_factor=2) - yield gen.sleep(0.3) + async with TestCluster(n_workers=10, processes=False, asynchronous=True) as cluster: + ta = cluster.adapt( + min_size=2, interval=0.1, scale_factor=2, Adaptive=TestAdaptive + ) + await gen.sleep(0.3) - assert len(s.workers) == 2 + # Assert that adaptive cycle does not reduce cluster below minimum size + # as determined via override. + assert len(cluster.scheduler.workers) == 2 @gen_test() @@ -176,14 +156,7 @@ def test_min_max(): asynchronous=True, ) try: - adapt = Adaptive( - cluster.scheduler, - cluster, - minimum=1, - maximum=2, - interval="20 ms", - wait_count=10, - ) + adapt = cluster.adapt(minimum=1, maximum=2, interval="20 ms", wait_count=10) c = yield Client(cluster, asynchronous=True) start = time() @@ -237,7 +210,7 @@ def test_avoid_churn(): ) client = yield Client(cluster, asynchronous=True) try: - adapt = Adaptive(cluster.scheduler, cluster, interval="20 ms", wait_count=5) + adapt = cluster.adapt(interval="20 ms", wait_count=5) for i in range(10): yield client.submit(slowinc, i, delay=0.040) @@ -267,7 +240,7 @@ def test_adapt_quickly(): dashboard_address=None, ) client = yield Client(cluster, asynchronous=True) - adapt = Adaptive(cluster.scheduler, cluster, interval=20, wait_count=5, maximum=10) + adapt = cluster.adapt(interval=20, wait_count=5, maximum=10) try: future = client.submit(slowinc, 1, delay=0.100) yield wait(future) @@ -346,9 +319,7 @@ def test_no_more_workers_than_tasks(): ) yield cluster._start() try: - adapt = Adaptive( - cluster.scheduler, cluster, minimum=0, maximum=4, interval="10 ms" - ) + adapt = cluster.adapt(minimum=0, maximum=4, interval="10 ms") client = yield Client(cluster, asynchronous=True, loop=loop) cluster.scheduler.task_duration["slowinc"] = 1000 diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index e5bd3dd3140..c6c3e3d1ba2 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -20,6 +20,7 @@ distributed: transition-log-length: 100000 work-stealing: True # workers should steal tasks from each other worker-ttl: null # like '60s'. Time to live for workers. They must heartbeat faster than this + pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] preload-argv: [] dashboard: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 63ae1ef33b2..6e58fb36ffc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1069,6 +1069,7 @@ def __init__( "get_task_stream": self.get_task_stream, "register_worker_plugin": self.register_worker_plugin, "adaptive_target": self.adaptive_target, + "workers_to_close": self.workers_to_close, } self._transitions = { @@ -2919,7 +2920,9 @@ def replicate( }, ) - def workers_to_close(self, memory_ratio=None, n=None, key=None, minimum=None): + def workers_to_close( + self, comm=None, memory_ratio=None, n=None, key=None, minimum=None + ): """ Find workers that we can close with low cost @@ -2981,6 +2984,10 @@ def workers_to_close(self, memory_ratio=None, n=None, key=None, minimum=None): if key is None: key = lambda ws: ws.address + if isinstance(key, bytes) and dask.config.get( + "distributed.scheduler.pickle" + ): + key = pickle.loads(key) groups = groupby(key, self.workers.values()) @@ -3209,6 +3216,14 @@ def feed( Caution: this runs arbitrary Python code on the scheduler. This should eventually be phased out. It is mostly used by diagnostics. 
""" + if not dask.config.get("distributed.scheduler.pickle"): + logger.warn( + "Tried to call 'feed' route with custom fucntions, but " + "pickle is disallowed. Set the 'distributed.scheduler.pickle'" + "config value to True to use the 'feed' route (this is mostly " + "commonly used with progress bars)" + ) + return import pickle interval = parse_timedelta(interval) @@ -4714,7 +4729,7 @@ def check_idle(self): if close: self.loop.add_callback(self.close) - def adaptive_target(self, target_duration="5s"): + def adaptive_target(self, comm=None, target_duration="5s"): """ Desired number of workers based on the current workload This looks at the current running tasks and memory use, and returns a From d28d885437fdea1182277d974f61f4843e6612b7 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 20 Jul 2019 17:49:14 -0500 Subject: [PATCH 0363/1550] Add Server.finished async function (#2864) --- distributed/cli/dask_worker.py | 3 +-- distributed/core.py | 9 ++++++++- distributed/deploy/tests/test_adaptive.py | 6 +++--- distributed/nanny.py | 9 +++++---- distributed/scheduler.py | 20 +------------------- distributed/tests/test_scheduler.py | 11 +++++++++++ distributed/utils_test.py | 7 ++++--- distributed/worker.py | 1 + 8 files changed, 34 insertions(+), 32 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index e86cfa41618..8752cd52448 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -380,8 +380,7 @@ def on_signal(signum): @gen.coroutine def run(): yield nannies - while all(n.status != "closed" for n in nannies): - yield gen.sleep(0.2) + yield [n.finished() for n in nannies] install_signal_handlers(loop, cleanup=on_signal) diff --git a/distributed/core.py b/distributed/core.py index 4d18547151a..8aac2edfc33 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -73,7 +73,7 @@ def _raise(*args, **kwargs): class Server(object): - """ Distributed TCP Server + """ Dask Distributed Server Superclass for endpoints in a distributed cluster, such as Worker and Scheduler objects. 
@@ -146,6 +146,7 @@ def __init__( self.events = None self.event_counts = None self._ongoing_coroutines = weakref.WeakSet() + self._event_finished = Event() self.listener = None self.io_loop = io_loop or IOLoop.current() @@ -211,6 +212,10 @@ def set_thread_ident(): self.__stopped = False + async def finished(self): + """ Wait until the server has finished """ + await self._event_finished.wait() + def start_periodic_callbacks(self): """ Start Periodic Callbacks consistently @@ -507,6 +512,8 @@ def close(self): else: yield gen.sleep(0.01) + self._event_finished.set() + def pingpong(comm): return b"pong" diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 2258a83cfe5..861c5107348 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -9,12 +9,12 @@ from distributed import Client, wait, Adaptive, LocalCluster, SpecCluster, Worker from distributed.utils_test import gen_test, slowinc, clean -from distributed.utils_test import loop, nodebug # noqa: F401 +from distributed.utils_test import loop, nodebug, cleanup # noqa: F401 from distributed.metrics import time @pytest.mark.asyncio -async def test_simultaneous_scale_up_and_down(): +async def test_simultaneous_scale_up_and_down(cleanup): class TestAdaptive(Adaptive): def get_scale_up_kwargs(self): assert False @@ -116,7 +116,7 @@ def test_adaptive_local_cluster_multi_workers(): @pytest.mark.asyncio -async def test_adaptive_scale_down_override(): +async def test_adaptive_scale_down_override(cleanup): class TestAdaptive(Adaptive): def __init__(self, *args, **kwargs): self.min_size = kwargs.pop("min_size", 0) diff --git a/distributed/nanny.py b/distributed/nanny.py index 00370b81adb..f3bebb1dcac 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -355,7 +355,7 @@ def memory_monitor(self): process.terminate() def is_alive(self): - return self.process is not None and self.process.status == "running" + return self.process is not None and self.process.is_alive() def run(self, *args, **kwargs): return run(self, *args, **kwargs) @@ -401,11 +401,12 @@ def close(self, comm=None, timeout=5, report=None): """ Close the worker process, stop all comms. """ - while self.status == "closing": - yield gen.sleep(0.01) + if self.status == "closing": + yield self.finished() + assert self.status == "closed" if self.status == "closed": - raise gen.Return("OK") + return "OK" self.status = "closing" logger.info("Closing Nanny at %r", self.address) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6e58fb36ffc..4441b815642 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -818,8 +818,6 @@ class Scheduler(ServerNode): report results * **task_duration:** ``{key-prefix: time}`` Time we expect certain functions to take, e.g. 
``{'sum': 0.25}`` - * **coroutines:** ``[Futures]``: - A list of active futures that control operation """ default_port = 8786 @@ -896,7 +894,6 @@ def __init__( self.loop = loop or IOLoop.current() self.client_comms = dict() self.stream_comms = dict() - self.coroutines = [] self._worker_coroutines = [] self._ipython_kernel = None @@ -1189,12 +1186,6 @@ def start(self): for c in self._worker_coroutines: c.cancel() - for cor in self.coroutines: - if cor.done(): - exc = cor.exception() - if exc: - raise exc - if self.status != "running": self.listen(self._start_address, listen_args=self.listen_args) self.ip = get_address_host(self.listen_address) @@ -1232,15 +1223,8 @@ def del_scheduler_file(): setproctitle("dask-scheduler [%s]" % (self.address,)) - yield self.finished() return self - @gen.coroutine - def finished(self): - """ Wait until all coroutines have ceased """ - while any(not c.done() for c in self.coroutines): - yield All(self.coroutines) - @gen.coroutine def close(self, comm=None, fast=False, close_workers=False): """ Send cleanup signal to all coroutines then wait until finished @@ -1250,6 +1234,7 @@ def close(self, comm=None, fast=False, close_workers=False): Scheduler.cleanup """ if self.status.startswith("clos"): + yield self.finished() return self.status = "closing" @@ -1287,9 +1272,6 @@ def close(self, comm=None, fast=False, close_workers=False): for future in futures: yield future - if not fast: - yield self.finished() - for comm in self.client_comms.values(): comm.abort() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 9b512ddad64..20de5e7b7fd 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1082,6 +1082,7 @@ def test_close_nanny(c, s, a, b): yield gen.sleep(0.1) assert time() < start + 5 + assert not a.is_alive() assert a.pid is None for i in range(10): @@ -1612,3 +1613,13 @@ async def test_allowed_failures_config(cleanup): with dask.config.set({"distributed.scheduler.allowed_failures": 0}): async with Scheduler(port=0) as s: assert s.allowed_failures == 0 + + +@pytest.mark.asyncio +async def test_finished(): + async with Scheduler(port=0) as s: + async with Worker(s.address) as w: + pass + + await s.finished() + await w.finished() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index f0e75f8ed2c..5f3dff548cf 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1508,9 +1508,10 @@ def check_instances(): _global_clients.clear() for w in Worker._instances: - w.close(report=False, executor_wait=False) - if w.status == "running": - w.close() + with ignoring(RuntimeError): # closed IOLoop + w.close(report=False, executor_wait=False) + if w.status == "running": + w.close() Worker._instances.clear() for i in range(5): diff --git a/distributed/worker.py b/distributed/worker.py index 791fc0ba101..b052dd05799 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -954,6 +954,7 @@ def _close(self, *args, **kwargs): def close(self, report=True, timeout=10, nanny=True, executor_wait=True): with log_errors(): if self.status in ("closed", "closing"): + yield self.finished() return disable_gc_diagnosis() From 5f8a4f9a0f549e01354f96e373fbe3abda23e25f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 21 Jul 2019 09:46:03 -0500 Subject: [PATCH 0364/1550] Align text and remove bullets in Client HTML repr (#2867) --- distributed/client.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/distributed/client.py 
b/distributed/client.py index 74e33716cb6..1dc5af16b81 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -804,14 +804,15 @@ def _repr_html_(self): info = self._scheduler_identity scheduler = self.scheduler + text = ( + '<h3 style="text-align: left;">Client</h3>\n' + '<ul style="text-align: left; list-style: none; margin: 0; padding: 0;">\n' + ) if scheduler is not None: - text = ( - "<h3>Client</h3>\n" "<ul>\n" "<li><b>Scheduler: </b>%s\n" - ) % scheduler.address + text += "<li><b>Scheduler: </b>%s\n" % scheduler.address else: - text = ( - "<h3>Client</h3>\n" "<ul>\n" "<li><b>Scheduler: not connected</b>\n" - ) + text += "<li><b>Scheduler: not connected</b>\n" + if info and "dashboard" in info["services"]: protocol, rest = scheduler.address.split("://") port = info["services"]["dashboard"] @@ -837,8 +838,8 @@ def _repr_html_(self): memory = "" text2 = ( - "<h3>Cluster</h3>\n" - "<ul>\n" + '<h3 style="text-align: left;">Cluster</h3>\n' + '<ul style="text-align: left; list-style: none; margin: 0; padding: 0;">\n' "<li><b>Workers: </b>%d</li>\n" "<li><b>Cores: </b>%d</li>\n" "<li><b>Memory: </b>%s</li>\n" From 586ead997dc72c1b7170dbf30af253257f53b1fc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 21 Jul 2019 13:06:59 -0500 Subject: [PATCH 0365/1550] Test dask-scheduler --idle-timeout flag (#2862) Fixes #2846 --- distributed/cli/dask_scheduler.py | 15 +++++++++++++-- distributed/cli/tests/test_dask_scheduler.py | 11 +++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index f2799164a36..7142331e861 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -120,6 +120,12 @@ @click.argument( "preload_argv", nargs=-1, type=click.UNPROCESSED, callback=validate_preload_argv ) +@click.option( + "--idle-timeout", + default=None, + type=str, + help="Time of inactivity after which to kill the scheduler", +) @click.version_option() def main( host, @@ -141,6 +147,7 @@ def main( tls_cert, tls_key, dashboard_address, + idle_timeout, ): g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 gc.set_threshold(g0 * 3, g1 * 3, g2 * 3) @@ -211,6 +218,7 @@ def del_pid_file(): protocol=protocol, dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, + idle_timeout=idle_timeout, ) scheduler.start() if not preload: @@ -226,9 +234,12 @@ def del_pid_file(): install_signal_handlers(loop) + async def run(): + await scheduler + await scheduler.finished() + try: - loop.start() - loop.close() + loop.run_sync(run) finally: scheduler.stop() if local_directory_created: diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 7de7e881270..24737474165 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -383,3 +383,14 @@ def test_version_option(): runner = CliRunner() result = runner.invoke(distributed.cli.dask_scheduler.main, ["--version"]) assert result.exit_code == 0 + + +@pytest.mark.slow +def test_idle_timeout(loop): start = time() runner = CliRunner() result = runner.invoke( distributed.cli.dask_scheduler.main, ["--idle-timeout", "1s"] ) stop = time() assert 1 < stop - start < 10 From b13403727fc2dc1b1173aa0c347af2881418e36b Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 23 Jul 2019 18:06:42 -0500 Subject: [PATCH 0366/1550] Remove `Client.upload_environment` (#2877) This method was undocumented and broken. Modifying the user environment is better handled by the deployment solution, the current method is not resilient to worker additions/removals, and uploading and replacing the current environment is an inefficient method to add additional packages. If users really want to upload packages to each worker at runtime, they can use `upload_file` with a zipfile/egg.
--- distributed/client.py | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 1dc5af16b81..d8fb5dfa13d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -9,7 +9,6 @@ from datetime import timedelta import errno from functools import partial -from glob import glob import itertools import json import logging @@ -2823,38 +2822,6 @@ def persist( else: return result - @gen.coroutine - def _upload_environment(self, zipfile): - name = os.path.split(zipfile)[1] - yield self._upload_large_file(zipfile, name) - - def unzip(dask_worker=None): - from distributed.utils import log_errors - import zipfile - import shutil - - with log_errors(): - a = os.path.join(dask_worker.worker_dir, name) - b = os.path.join(dask_worker.local_dir, name) - c = os.path.dirname(b) - shutil.move(a, b) - - with zipfile.ZipFile(b) as f: - f.extractall(path=c) - - for fn in glob(os.path.join(c, name[:-4], "bin", "*")): - st = os.stat(fn) - os.chmod(fn, st.st_mode | 64) # chmod u+x fn - - assert os.path.exists(os.path.join(c, name[:-4])) - return c - - yield self._run(unzip, nanny=True) - raise gen.Return(name[:-4]) - - def upload_environment(self, name, zipfile): - return self.sync(self._upload_environment, name, zipfile) - @gen.coroutine def _restart(self, timeout=no_default): if timeout == no_default: From f16ee17bc9481f2d9e19479f390d4248fc7c83a3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 24 Jul 2019 09:02:42 -0700 Subject: [PATCH 0367/1550] Replace gen.coroutine with async/await in core (#2871) --- distributed/actor.py | 30 +- distributed/cli/dask_scheduler.py | 14 +- distributed/client.py | 350 ++++++++---------- distributed/comm/core.py | 9 +- distributed/comm/inproc.py | 31 +- distributed/comm/tcp.py | 53 ++- distributed/comm/tests/test_comms.py | 12 +- distributed/core.py | 93 ++--- distributed/dashboard/components.py | 6 +- distributed/dashboard/scheduler_html.py | 16 +- .../dashboard/tests/test_components.py | 4 +- .../tests/test_scheduler_bokeh_html.py | 4 +- .../dashboard/tests/test_worker_bokeh.py | 2 +- distributed/deploy/spec.py | 2 +- distributed/deploy/tests/test_adaptive.py | 12 +- distributed/deploy/tests/test_local.py | 5 +- distributed/diagnostics/eventstream.py | 15 +- distributed/diagnostics/progress.py | 10 +- distributed/diagnostics/progress_stream.py | 14 +- distributed/diagnostics/progressbar.py | 37 +- .../diagnostics/tests/test_progressbar.py | 31 +- distributed/lock.py | 19 +- distributed/nanny.py | 126 +++---- distributed/process.py | 2 +- distributed/pubsub.py | 12 +- distributed/queues.py | 70 ++-- distributed/recreate_exceptions.py | 21 +- distributed/scheduler.py | 167 ++++----- distributed/tests/py3_test_pubsub.py | 3 +- distributed/tests/test_as_completed.py | 2 +- distributed/tests/test_batched.py | 6 +- distributed/tests/test_client.py | 26 +- distributed/tests/test_locks.py | 2 +- distributed/tests/test_nanny.py | 28 +- distributed/tests/test_priorities.py | 52 +-- distributed/tests/test_queues.py | 9 +- distributed/tests/test_resources.py | 4 +- distributed/tests/test_scheduler.py | 32 +- distributed/tests/test_steal.py | 4 +- distributed/tests/test_stress.py | 29 +- distributed/tests/test_tls_functional.py | 4 +- distributed/tests/test_utils_test.py | 9 +- distributed/tests/test_worker.py | 12 +- distributed/tests/test_worker_plugins.py | 2 +- distributed/utils.py | 25 +- distributed/utils_comm.py | 34 +- distributed/utils_test.py | 188 +++++----- distributed/variable.py | 
34 +- distributed/worker.py | 203 +++++----- docs/source/adaptive.rst | 6 +- docs/source/client.rst | 16 +- docs/source/foundations.rst | 59 ++- 52 files changed, 921 insertions(+), 1035 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index 4bbe6faf78d..e45f089effd 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -1,3 +1,4 @@ +import asyncio from tornado import gen import functools @@ -128,10 +129,9 @@ def __getattr__(self, key): @functools.wraps(attr) def func(*args, **kwargs): - @gen.coroutine - def run_actor_function_on_worker(): + async def run_actor_function_on_worker(): try: - result = yield self._worker_rpc.actor_execute( + result = await self._worker_rpc.actor_execute( function=key, actor=self.key, args=[to_serialize(arg) for arg in args], @@ -139,21 +139,20 @@ def run_actor_function_on_worker(): ) except OSError: if self._future: - yield self._future + await self._future else: raise OSError("Unable to contact Actor's worker") - raise gen.Return(result["result"]) + return result["result"] if self._asynchronous: - return run_actor_function_on_worker() + return asyncio.ensure_future(run_actor_function_on_worker()) else: # TODO: this mechanism is error prone # we should endeavor to make dask's standard code work here q = Queue() - @gen.coroutine - def wait_then_add_to_queue(): - x = yield run_actor_function_on_worker() + async def wait_then_add_to_queue(): + x = await run_actor_function_on_worker() q.put(x) self._io_loop.add_callback(wait_then_add_to_queue) @@ -164,11 +163,11 @@ def wait_then_add_to_queue(): else: - @gen.coroutine - def get_actor_attribute_from_worker(): - x = yield self._worker_rpc.actor_attribute( + async def get_actor_attribute_from_worker(): + x = await self._worker_rpc.actor_attribute( attribute=key, actor=self.key ) + return x["result"] raise gen.Return(x["result"]) return self._sync(get_actor_attribute_from_worker) @@ -188,11 +187,10 @@ def __init__(self, rpc, address): self._address = address def __getattr__(self, key): - @gen.coroutine - def func(**msg): + async def func(**msg): msg["op"] = key - result = yield self.rpc.proxy(worker=self._address, msg=msg) - raise gen.Return(result) + result = await self.rpc.proxy(worker=self._address, msg=msg) + return result return func diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 7142331e861..d41b98eb310 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -1,7 +1,6 @@ from __future__ import print_function, division, absolute_import import atexit -import dask import logging import gc import os @@ -16,9 +15,9 @@ from tornado.ioloop import IOLoop from distributed import Scheduler +from distributed.preloading import validate_preload_argv from distributed.security import Security from distributed.cli.utils import check_python_3, install_signal_handlers -from distributed.preloading import preload_modules, validate_preload_argv from distributed.proctitle import ( enable_proctitle_on_children, enable_proctitle_on_current, @@ -219,16 +218,9 @@ def del_pid_file(): dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, idle_timeout=idle_timeout, + preload=preload, + preload_argv=preload_argv, ) - scheduler.start() - if not preload: - preload = dask.config.get("distributed.scheduler.preload") - if not preload_argv: - preload_argv = dask.config.get("distributed.scheduler.preload-argv") - preload_modules( - preload, parameter=scheduler, 
file_dir=local_directory, argv=preload_argv - ) - logger.info("Local Directory: %26s", local_directory) logger.info("-" * 47) diff --git a/distributed/client.py b/distributed/client.py index d8fb5dfa13d..c4b2f51426a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -43,6 +43,7 @@ from tornado.ioloop import IOLoop from tornado.queues import Queue +import asyncio from asyncio import iscoroutine from .batched import BatchedSend @@ -230,32 +231,30 @@ def result(self, timeout=None): else: return result - @gen.coroutine - def _result(self, raiseit=True): - yield self._state.wait() + async def _result(self, raiseit=True): + await self._state.wait() if self.status == "error": exc = clean_exception(self._state.exception, self._state.traceback) if raiseit: six.reraise(*exc) else: - raise gen.Return(exc) + return exc elif self.status == "cancelled": exception = CancelledError(self.key) if raiseit: raise exception else: - raise gen.Return(exception) + return exception else: - result = yield self.client._gather([self]) - raise gen.Return(result[0]) + result = await self.client._gather([self]) + return result[0] - @gen.coroutine - def _exception(self): - yield self._state.wait() + async def _exception(self): + await self._state.wait() if self.status == "error": - raise gen.Return(self._state.exception) + return self._state.exception else: - raise gen.Return(None) + return None def exception(self, timeout=None, **kwargs): """ Return the exception of a failed task @@ -320,13 +319,12 @@ def cancelled(self): """ Returns True if the future has been cancelled """ return self._state.status == "cancelled" - @gen.coroutine - def _traceback(self): - yield self._state.wait() + async def _traceback(self): + await self._state.wait() if self.status == "error": - raise gen.Return(self._state.traceback) + return self._state.traceback else: - raise gen.Return(None) + return None def traceback(self, timeout=None, **kwargs): """ Return the traceback of a failed task @@ -447,6 +445,7 @@ def _get_event(self): def cancel(self): self.status = "cancelled" + self.exception = CancelledError() self._get_event().set() def finish(self, type=None): @@ -479,19 +478,17 @@ def reset(self): if self._event is not None: self._event.clear() - @gen.coroutine - def wait(self, timeout=None): - yield self._get_event().wait(timeout) + async def wait(self, timeout=None): + await self._get_event().wait(timeout) def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.status) -@gen.coroutine -def done_callback(future, callback): +async def done_callback(future, callback): """ Coroutine that waits on future, then calls callback """ while future.status == "pending": - yield future._state.wait() + await future._state.wait() callback(future) @@ -867,7 +864,7 @@ def start(self, **kwargs): self.status = "connecting" if self.asynchronous: - self._started = self._start(**kwargs) + self._started = asyncio.ensure_future(self._start(**kwargs)) else: sync(self.loop, self._start, **kwargs) @@ -876,9 +873,8 @@ def __await__(self): return self._started.__await__() else: - @gen.coroutine - def _(): - raise gen.Return(self) + async def _(): + return self return _().__await__() @@ -901,8 +897,7 @@ def _send_to_scheduler(self, msg): "Message: %s" % (self.status, msg) ) - @gen.coroutine - def _start(self, timeout=no_default, **kwargs): + async def _start(self, timeout=no_default, **kwargs): if timeout == no_default: timeout = self._timeout if timeout is not None: @@ -912,7 +907,7 @@ def _start(self, timeout=no_default, **kwargs): if 
self.cluster is not None: # Ensure the cluster is started (no-op if already running) try: - yield self.cluster._start() + await self.cluster._start() except AttributeError: # Some clusters don't have this method pass except Exception: @@ -923,7 +918,7 @@ def _start(self, timeout=no_default, **kwargs): address = self.cluster.scheduler_address elif self.scheduler_file is not None: while not os.path.exists(self.scheduler_file): - yield gen.sleep(0.01) + await gen.sleep(0.01) for i in range(10): try: with open(self.scheduler_file) as f: @@ -931,33 +926,31 @@ def _start(self, timeout=no_default, **kwargs): address = cfg["address"] break except (ValueError, KeyError): # JSON file not yet flushed - yield gen.sleep(0.01) + await gen.sleep(0.01) elif self._start_arg is None: from .deploy import LocalCluster try: - self.cluster = LocalCluster( + self.cluster = await LocalCluster( loop=self.loop, asynchronous=True, **self._startup_kwargs ) - yield self.cluster except (OSError, socket.error) as e: if e.errno != errno.EADDRINUSE: raise # The default port was taken, use a random one - self.cluster = LocalCluster( + self.cluster = await LocalCluster( scheduler_port=0, loop=self.loop, asynchronous=True, **self._startup_kwargs ) - yield self.cluster # Wait for all workers to be ready # XXX should be a LocalCluster method instead while not self.cluster.workers or len(self.cluster.scheduler.workers) < len( self.cluster.workers ): - yield gen.sleep(0.01) + await gen.sleep(0.01) address = self.cluster.scheduler_address @@ -965,18 +958,17 @@ def _start(self, timeout=no_default, **kwargs): self.scheduler = self.rpc(address) self.scheduler_comm = None - yield self._ensure_connected(timeout=timeout) + await self._ensure_connected(timeout=timeout) for pc in self._periodic_callbacks.values(): pc.start() - self._handle_scheduler_coroutine = self._handle_report() + self._handle_scheduler_coroutine = asyncio.ensure_future(self._handle_report()) self.coroutines.append(self._handle_scheduler_coroutine) - raise gen.Return(self) + return self - @gen.coroutine - def _reconnect(self): + async def _reconnect(self): with log_errors(): assert self.scheduler_comm.comm.closed() @@ -991,11 +983,11 @@ def _reconnect(self): deadline = self.loop.time() + timeout while timeout > 0 and self.status == "connecting": try: - yield self._ensure_connected(timeout=timeout) + await self._ensure_connected(timeout=timeout) break except EnvironmentError: # Wait a bit before retrying - yield gen.sleep(0.1) + await gen.sleep(0.1) timeout = deadline - self.loop.time() else: logger.error( @@ -1003,10 +995,9 @@ def _reconnect(self): "seconds, closing client", self._timeout, ) - yield self._close() + await self._close() - @gen.coroutine - def _ensure_connected(self, timeout=None): + async def _ensure_connected(self, timeout=None): if ( self.scheduler_comm and not self.scheduler_comm.closed() @@ -1018,27 +1009,27 @@ def _ensure_connected(self, timeout=None): self._connecting_to_scheduler = True try: - comm = yield connect( + comm = await connect( self.scheduler.address, timeout=timeout, connection_args=self.connection_args, ) comm.name = "Client->Scheduler" if timeout is not None: - yield gen.with_timeout( + await gen.with_timeout( timedelta(seconds=timeout), self._update_scheduler_info() ) else: - yield self._update_scheduler_info() - yield comm.write( + await self._update_scheduler_info() + await comm.write( {"op": "register-client", "client": self.id, "reply": False} ) finally: self._connecting_to_scheduler = False if timeout is not None: - msg = 
yield gen.with_timeout(timedelta(seconds=timeout), comm.read()) + msg = await gen.with_timeout(timedelta(seconds=timeout), comm.read()) else: - msg = yield comm.read() + msg = await comm.read() assert len(msg) == 1 assert msg[0]["op"] == "stream-start" @@ -1055,21 +1046,19 @@ def _ensure_connected(self, timeout=None): logger.debug("Started scheduling coroutines. Synchronized") - @gen.coroutine - def _update_scheduler_info(self): + async def _update_scheduler_info(self): if self.status not in ("running", "connecting"): return try: - self._scheduler_identity = yield self.scheduler.identity() + self._scheduler_identity = await self.scheduler.identity() except EnvironmentError: logger.debug("Not able to query scheduler for identity") - @gen.coroutine - def _wait_for_workers(self, n_workers=0): - info = yield self.scheduler.identity() + async def _wait_for_workers(self, n_workers=0): + info = await self.scheduler.identity() while n_workers and len(info["workers"]) < n_workers: - yield gen.sleep(0.1) - info = yield self.scheduler.identity() + await gen.sleep(0.1) + info = await self.scheduler.identity() def wait_for_workers(self, n_workers=0): """Blocking call to wait for n workers before continuing""" @@ -1084,14 +1073,12 @@ def __enter__(self): self.start() return self - @gen.coroutine - def __aenter__(self): - yield self._started - raise gen.Return(self) + async def __aenter__(self): + await self._started + return self - @gen.coroutine - def __aexit__(self, typ, value, traceback): - yield self._close() + async def __aexit__(self, typ, value, traceback): + await self._close() def __exit__(self, type, value, traceback): self.close() @@ -1121,8 +1108,7 @@ def _release_key(self, key): {"op": "client-releases-keys", "keys": [key], "client": self.id} ) - @gen.coroutine - def _handle_report(self): + async def _handle_report(self): """ Listen to scheduler """ with log_errors(): try: @@ -1130,13 +1116,13 @@ def _handle_report(self): if self.scheduler_comm is None: break try: - msgs = yield self.scheduler_comm.comm.read() + msgs = await self.scheduler_comm.comm.read() except CommClosedError: if self.status == "running": logger.info("Client report stream closed to scheduler") logger.info("Reconnecting...") self.status = "connecting" - yield self._reconnect() + await self._reconnect() continue else: break @@ -1212,8 +1198,7 @@ def _handle_error(self, exception=None): logger.warning("Scheduler exception:") logger.exception(exception) - @gen.coroutine - def _close(self, fast=False): + async def _close(self, fast=False): """ Send close signal and wait until scheduler completes """ self.status = "closing" @@ -1243,7 +1228,7 @@ def _close(self, fast=False): # Give the scheduler 'stream-closed' message 100ms to come through # This makes the shutdown slightly smoother and quieter with ignoring(AttributeError, gen.TimeoutError): - yield gen.with_timeout( + await gen.with_timeout( timedelta(milliseconds=100), self._handle_scheduler_coroutine, quiet_exceptions=(CancelledError,), @@ -1254,12 +1239,12 @@ def _close(self, fast=False): and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed() ): - yield self.scheduler_comm.close() + await self.scheduler_comm.close() for key in list(self.futures): self._release_key(key=key) if self._start_arg is None: with ignoring(AttributeError): - yield self.cluster._close() + await self.cluster._close() self.rpc.close() self.status = "closed" if _get_global_client() is self: @@ -1275,7 +1260,7 @@ def _close(self, fast=False): del self.coroutines[:] if not fast: with 
ignoring(TimeoutError): - yield gen.with_timeout(timedelta(seconds=2), list(coroutines)) + await gen.with_timeout(timedelta(seconds=2), list(coroutines)) with ignoring(AttributeError): self.scheduler.close_rpc() self.scheduler = None @@ -1627,8 +1612,7 @@ def map( return [futures[tokey(k)] for k in keys] - @gen.coroutine - def _gather(self, futures, errors="raise", direct=None, local_worker=None): + async def _gather(self, futures, errors="raise", direct=None, local_worker=None): unpacked, future_set = unpack_remotedata(futures, byte_keys=True) keys = [tokey(future.key) for future in future_set] bad_data = dict() @@ -1645,11 +1629,10 @@ def _gather(self, futures, errors="raise", direct=None, local_worker=None): if w.scheduler.address == self.scheduler.address: direct = True - @gen.coroutine - def wait(k): + async def wait(k): """ Want to stop the All(...) early if we find an error """ st = self.futures[k] - yield st.wait() + await st.wait() if st.status != "finished" and errors == "raise": raise AllExit() @@ -1657,7 +1640,7 @@ def wait(k): logger.debug("Waiting on futures to clear before gather") with ignoring(AllExit): - yield All( + await All( [wait(key) for key in keys if key in self.futures], quiet_exceptions=AllExit, ) @@ -1696,15 +1679,17 @@ def wait(k): # We now do an actual remote communication with workers or scheduler if self._gather_future: # attach onto another pending gather request self._gather_keys |= set(keys) - response = yield self._gather_future + response = await self._gather_future else: # no one waiting, go ahead self._gather_keys = set(keys) - future = self._gather_remote(direct, local_worker) + future = asyncio.ensure_future( + self._gather_remote(direct, local_worker) + ) if self._gather_keys is None: self._gather_future = None else: self._gather_future = future - response = yield future + response = await future if response["status"] == "error": log = logger.warning if errors == "raise" else logger.debug @@ -1728,40 +1713,39 @@ def wait(k): data.update(response["data"]) result = pack_data(unpacked, merge(data, bad_data)) - raise gen.Return(result) + return result - @gen.coroutine - def _gather_remote(self, direct, local_worker): + async def _gather_remote(self, direct, local_worker): """ Perform gather with workers or scheduler This method exists to limit and batch many concurrent gathers into a few. In controls access using a Tornado semaphore, and picks up keys from other requests made recently. 
""" - yield self._gather_semaphore.acquire() + await self._gather_semaphore.acquire() keys = list(self._gather_keys) self._gather_keys = None # clear state, these keys are being sent off self._gather_future = None try: if direct or local_worker: # gather directly from workers - who_has = yield self.scheduler.who_has(keys=keys) - data2, missing_keys, missing_workers = yield gather_from_workers( + who_has = await self.scheduler.who_has(keys=keys) + data2, missing_keys, missing_workers = await gather_from_workers( who_has, rpc=self.rpc, close=False ) response = {"status": "OK", "data": data2} if missing_keys: keys2 = [key for key in keys if key not in data2] - response = yield self.scheduler.gather(keys=keys2) + response = await self.scheduler.gather(keys=keys2) if response["status"] == "OK": response["data"].update(data2) else: # ask scheduler to gather data for us - response = yield self.scheduler.gather(keys=keys) + response = await self.scheduler.gather(keys=keys) finally: self._gather_semaphore.release() - raise gen.Return(response) + return response def gather(self, futures, errors="raise", direct=None, asynchronous=None): """ Gather futures from distributed memory @@ -1823,8 +1807,7 @@ def gather(self, futures, errors="raise", direct=None, asynchronous=None): asynchronous=asynchronous, ) - @gen.coroutine - def _scatter( + async def _scatter( self, data, workers=None, @@ -1841,7 +1824,7 @@ def _scatter( if isinstance(data, dict) and not all( isinstance(k, (bytes, unicode)) for k in data ): - d = yield self._scatter(keymap(tokey, data), workers, broadcast) + d = await self._scatter(keymap(tokey, data), workers, broadcast) raise gen.Return({k: d[tokey(k)] for k in data}) if isinstance(data, type(range(0))): @@ -1881,7 +1864,7 @@ def _scatter( if local_worker: # running within task local_worker.update_data(data=data, report=False) - yield self.scheduler.update_data( + await self.scheduler.update_data( who_has={key: [local_worker.address] for key in data}, nbytes=valmap(sizeof, data), client=self.id, @@ -1894,22 +1877,22 @@ def _scatter( start = time() while not nthreads: if nthreads is not None: - yield gen.sleep(0.1) + await gen.sleep(0.1) if time() > start + timeout: raise gen.TimeoutError("No valid workers found") - nthreads = yield self.scheduler.ncores(workers=workers) + nthreads = await self.scheduler.ncores(workers=workers) if not nthreads: raise ValueError("No valid workers") - _, who_has, nbytes = yield scatter_to_workers( + _, who_has, nbytes = await scatter_to_workers( nthreads, data2, report=False, rpc=self.rpc ) - yield self.scheduler.update_data( + await self.scheduler.update_data( who_has=who_has, nbytes=nbytes, client=self.id ) else: - yield self.scheduler.scatter( + await self.scheduler.scatter( data=data2, workers=workers, client=self.id, @@ -1923,7 +1906,7 @@ def _scatter( if direct and broadcast: n = None if broadcast is True else broadcast - yield self._replicate(list(out.values()), workers=workers, n=n) + await self._replicate(list(out.values()), workers=workers, n=n) if issubclass(input_type, (list, tuple, set, frozenset)): out = input_type(out[k] for k in names) @@ -1931,7 +1914,7 @@ def _scatter( if unpack: assert len(out) == 1 out = list(out.values())[0] - raise gen.Return(out) + return out def scatter( self, @@ -2030,10 +2013,9 @@ def scatter( hash=hash, ) - @gen.coroutine - def _cancel(self, futures, force=False): + async def _cancel(self, futures, force=False): keys = list({tokey(f.key) for f in futures_of(futures)}) - yield self.scheduler.cancel(keys=keys, 
client=self.id, force=force) + await self.scheduler.cancel(keys=keys, client=self.id, force=force) for k in keys: st = self.futures.pop(k, None) if st is not None: @@ -2055,10 +2037,9 @@ def cancel(self, futures, asynchronous=None, force=False): """ return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force) - @gen.coroutine - def _retry(self, futures): + async def _retry(self, futures): keys = list({tokey(f.key) for f in futures_of(futures)}) - response = yield self.scheduler.retry(keys=keys, client=self.id) + response = await self.scheduler.retry(keys=keys, client=self.id) for key in response: st = self.futures[key] st.retry() @@ -2179,15 +2160,14 @@ def list_datasets(self, **kwargs): """ return self.sync(self.scheduler.publish_list, **kwargs) - @gen.coroutine - def _get_dataset(self, name): - out = yield self.scheduler.publish_get(name=name, client=self.id) + async def _get_dataset(self, name): + out = await self.scheduler.publish_get(name=name, client=self.id) if out is None: raise KeyError("Dataset '%s' not found" % name) with temp_default_client(self): data = out["data"] - raise gen.Return(data) + return data def get_dataset(self, name, **kwargs): """ @@ -2200,15 +2180,14 @@ def get_dataset(self, name, **kwargs): """ return self.sync(self._get_dataset, name, **kwargs) - @gen.coroutine - def _run_on_scheduler(self, function, *args, wait=True, **kwargs): - response = yield self.scheduler.run_function( + async def _run_on_scheduler(self, function, *args, wait=True, **kwargs): + response = await self.scheduler.run_function( function=dumps(function), args=dumps(args), kwargs=dumps(kwargs), wait=wait ) if response["status"] == "error": six.reraise(*clean_exception(**response)) else: - raise gen.Return(response["result"]) + return response["result"] def run_on_scheduler(self, function, *args, **kwargs): """ Run a function on the scheduler process @@ -2242,9 +2221,10 @@ def run_on_scheduler(self, function, *args, **kwargs): """ return self.sync(self._run_on_scheduler, function, *args, **kwargs) - @gen.coroutine - def _run(self, function, *args, nanny=False, workers=None, wait=True, **kwargs): - responses = yield self.scheduler.broadcast( + async def _run( + self, function, *args, nanny=False, workers=None, wait=True, **kwargs + ): + responses = await self.scheduler.broadcast( msg=dict( op="run", function=dumps(function), @@ -2262,7 +2242,7 @@ def _run(self, function, *args, nanny=False, workers=None, wait=True, **kwargs): elif resp["status"] == "error": six.reraise(*clean_exception(**resp)) if wait: - raise gen.Return(results) + return results def run(self, function, *args, **kwargs): """ @@ -2822,14 +2802,13 @@ def persist( else: return result - @gen.coroutine - def _restart(self, timeout=no_default): + async def _restart(self, timeout=no_default): if timeout == no_default: timeout = self._timeout * 2 self._send_to_scheduler({"op": "restart", "timeout": timeout}) self._restart_event = Event() try: - yield self._restart_event.wait(self.loop.time() + timeout) + await self._restart_event.wait(self.loop.time() + timeout) except gen.TimeoutError: logger.error("Restart timed out after %f seconds", timeout) pass @@ -2837,7 +2816,7 @@ def _restart(self, timeout=no_default): with self._refcount_lock: self.refcount.clear() - raise gen.Return(self) + return self def restart(self, **kwargs): """ Restart the distributed network @@ -2847,12 +2826,11 @@ def restart(self, **kwargs): """ return self.sync(self._restart, **kwargs) - @gen.coroutine - def _upload_file(self, filename, 
raise_on_error=True): + async def _upload_file(self, filename, raise_on_error=True): with open(filename, "rb") as f: data = f.read() _, fn = os.path.split(filename) - d = yield self.scheduler.broadcast( + d = await self.scheduler.broadcast( msg={"op": "upload_file", "filename": fn, "data": to_serialize(data)} ) @@ -2861,21 +2839,20 @@ def _upload_file(self, filename, raise_on_error=True): if raise_on_error: raise exceptions[0] else: - raise gen.Return(exceptions[0]) + return exceptions[0] assert all(len(data) == v["nbytes"] for v in d.values()) - @gen.coroutine - def _upload_large_file(self, local_filename, remote_filename=None): + async def _upload_large_file(self, local_filename, remote_filename=None): if remote_filename is None: remote_filename = os.path.split(local_filename)[1] with open(local_filename, "rb") as f: data = f.read() - [future] = yield self._scatter([data]) + [future] = await self._scatter([data]) key = future.key - yield self._replicate(future) + await self._replicate(future) def dump_to_file(dask_worker=None): if not os.path.isabs(remote_filename): @@ -2887,7 +2864,7 @@ def dump_to_file(dask_worker=None): return len(dask_worker.data[key]) - response = yield self._run(dump_to_file) + response = await self._run(dump_to_file) assert all(len(data) == v for v in response.values()) @@ -2917,11 +2894,10 @@ def upload_file(self, filename, **kwargs): else: return result - @gen.coroutine - def _rebalance(self, futures=None, workers=None): - yield _wait(futures) + async def _rebalance(self, futures=None, workers=None): + await _wait(futures) keys = list({tokey(f.key) for f in self.futures_of(futures)}) - result = yield self.scheduler.rebalance(keys=keys, workers=workers) + result = await self.scheduler.rebalance(keys=keys, workers=workers) assert result["status"] == "OK" def rebalance(self, futures=None, workers=None, **kwargs): @@ -2944,12 +2920,11 @@ def rebalance(self, futures=None, workers=None, **kwargs): """ return self.sync(self._rebalance, futures, workers, **kwargs) - @gen.coroutine - def _replicate(self, futures, n=None, workers=None, branching_factor=2): + async def _replicate(self, futures, n=None, workers=None, branching_factor=2): futures = self.futures_of(futures) - yield _wait(futures) + await _wait(futures) keys = {tokey(f.key) for f in futures} - yield self.scheduler.replicate( + await self.scheduler.replicate( keys=list(keys), n=n, workers=workers, branching_factor=branching_factor ) @@ -3230,8 +3205,7 @@ def profile( filename=filename, ) - @gen.coroutine - def _profile( + async def _profile( self, key=None, start=None, @@ -3244,7 +3218,7 @@ def _profile( if isinstance(workers, six.string_types + (Number,)): workers = [workers] - state = yield self.scheduler.profile( + state = await self.scheduler.profile( key=key, workers=workers, merge_workers=merge_workers, @@ -3268,10 +3242,10 @@ def _profile( from bokeh.plotting import save save(figure, title="Dask Profile", filename=filename) - raise gen.Return((state, figure)) + return (state, figure) else: - raise gen.Return(state) + return state def scheduler_info(self, **kwargs): """ Basic information about the workers in the cluster @@ -3517,15 +3491,14 @@ def futures_of(self, futures): def start_ipython(self, *args, **kwargs): raise Exception("Method moved to start_ipython_workers") - @gen.coroutine - def _start_ipython_workers(self, workers): + async def _start_ipython_workers(self, workers): if workers is None: - workers = yield self.scheduler.ncores() + workers = await self.scheduler.ncores() - responses = yield 
self.scheduler.broadcast( + responses = await self.scheduler.broadcast( msg=dict(op="start_ipython"), workers=workers ) - raise gen.Return((workers, responses)) + return workers, responses def start_ipython_workers( self, workers=None, magic_names=False, qtconsole=False, qtconsole_args=None @@ -3840,11 +3813,10 @@ def get_task_stream( filename=filename, ) - @gen.coroutine - def _get_task_stream( + async def _get_task_stream( self, start=None, stop=None, count=None, plot=False, filename="task-stream.html" ): - msgs = yield self.scheduler.get_task_stream(start=start, stop=stop, count=count) + msgs = await self.scheduler.get_task_stream(start=start, stop=stop, count=count) if plot: from .diagnostics.task_stream import rectangles @@ -3857,9 +3829,9 @@ def _get_task_stream( from bokeh.plotting import save save(figure, title="Dask Task Stream", filename=filename) - raise gen.Return((msgs, figure)) + return (msgs, figure) else: - raise gen.Return(msgs) + return msgs def register_worker_callbacks(self, setup=None): """ @@ -3881,9 +3853,8 @@ def register_worker_callbacks(self, setup=None): """ return self.register_worker_plugin(_WorkerSetupPlugin(setup)) - @gen.coroutine - def _register_worker_plugin(self, plugin=None, name=None): - responses = yield self.scheduler.register_worker_plugin( + async def _register_worker_plugin(self, plugin=None, name=None): + responses = await self.scheduler.register_worker_plugin( plugin=dumps(plugin), name=name ) for response in responses.values(): @@ -3892,7 +3863,7 @@ def _register_worker_plugin(self, plugin=None, name=None): typ = type(exc) tb = response["traceback"] six.reraise(typ, exc, tb) - raise gen.Return(responses) + return responses def register_worker_plugin(self, plugin=None, name=None): """ @@ -3976,8 +3947,7 @@ def CompatibleExecutor(*args, **kwargs): FIRST_COMPLETED = "FIRST_COMPLETED" -@gen.coroutine -def _wait(fs, timeout=None, return_when=ALL_COMPLETED): +async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): if timeout is not None and not isinstance(timeout, Number): raise TypeError( "timeout= keyword received a non-numeric value.\n" @@ -3998,7 +3968,7 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): future = wait_for({f._state.wait() for f in fs}) if timeout is not None: future = gen.with_timeout(timedelta(seconds=timeout), future) - yield future + await future done, not_done = ( {fu for fu in fs if fu.status != "pending"}, @@ -4008,7 +3978,7 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): if cancelled: raise CancelledError(cancelled) - raise gen.Return(DoneAndNotDoneFutures(done, not_done)) + return DoneAndNotDoneFutures(done, not_done) def wait(fs, timeout=None, return_when=ALL_COMPLETED): @@ -4027,32 +3997,32 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): return result -@gen.coroutine -def _as_completed(fs, queue): +async def _as_completed(fs, queue): fs = futures_of(fs) groups = groupby(lambda f: f.key, fs) firsts = [v[0] for v in groups.values()] - wait_iterator = gen.WaitIterator(*[f._state.wait() for f in firsts]) + wait_iterator = gen.WaitIterator( + *map(asyncio.ensure_future, [f._state.wait() for f in firsts]) + ) while not wait_iterator.done(): - yield wait_iterator.next() + await wait_iterator.next() # TODO: handle case of restarted futures future = firsts[wait_iterator.current_index] for f in groups[future.key]: queue.put_nowait(f) -@gen.coroutine -def _first_completed(futures): +async def _first_completed(futures): """ Return a single completed future See Also: _as_completed """ q = Queue() - 
yield _as_completed(futures, q) - result = yield q.get() - raise gen.Return(result) + await _as_completed(futures, q) + result = await q.get() + return result class as_completed(object): @@ -4133,15 +4103,14 @@ def _notify(self): with self.thread_condition: self.thread_condition.notify() - @gen.coroutine - def _track_future(self, future): + async def _track_future(self, future): try: - yield _wait(future) + await _wait(future) except CancelledError: pass if self.with_results: try: - result = yield future._result(raiseit=False) + result = await future._result(raiseit=False) except CancelledError as exc: result = exc with self.lock: @@ -4212,16 +4181,15 @@ def __next__(self): self.thread_condition.wait(timeout=0.100) return self._get_and_raise() - @gen.coroutine - def __anext__(self): + async def __anext__(self): if not self.futures and self.queue.empty(): raise StopAsyncIteration while self.queue.empty(): if not self.futures: raise StopAsyncIteration - yield self.condition.wait() + await self.condition.wait() - raise gen.Return(self._get_and_raise()) + return self._get_and_raise() next = __next__ @@ -4443,13 +4411,11 @@ def __exit__(self, typ, value, traceback): L, self.figure = L self.data.extend(L) - @gen.coroutine - def __aenter__(self): - raise gen.Return(self) + async def __aenter__(self): + return self - @gen.coroutine - def __aexit__(self, typ, value, traceback): - L = yield self.client.get_task_stream( + async def __aexit__(self, typ, value, traceback): + L = await self.client.get_task_stream( start=self.start, plot=self._plot, filename=self._filename ) if self._plot: diff --git a/distributed/comm/core.py b/distributed/comm/core.py index e0b236e7b96..869cb9b377f 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -178,8 +178,7 @@ def connect(self, address, deserialize=True): """ -@gen.coroutine -def connect(addr, timeout=None, deserialize=True, connection_args=None): +async def connect(addr, timeout=None, deserialize=True, connection_args=None): """ Connect to the given address (a URI such as ``tcp://127.0.0.1:1234``) and yield a ``Comm`` object. 
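Defining ``__aenter__``/``__aexit__`` (and ``__anext__`` on ``as_completed``) as native coroutines means these client helpers can now be driven with ``async with`` and ``async for`` from asynchronous code. A hedged usage sketch, assuming an asynchronous ``Client`` instance named ``client`` and using the ``get_task_stream`` context-manager helper from ``distributed``::

    from distributed import get_task_stream

    async def collect_task_stream(client, futures):
        # record scheduler task-stream events while the futures complete
        async with get_task_stream(client=client) as ts:
            await client.gather(futures)
        return ts.data        # filled in by __aexit__, as shown above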
If the connection attempt fails, it is @@ -212,7 +211,7 @@ def _raise(error): future = connector.connect( loc, deserialize=deserialize, **(connection_args or {}) ) - comm = yield gen.with_timeout( + comm = await gen.with_timeout( timedelta(seconds=deadline - time()), future, quiet_exceptions=EnvironmentError, @@ -222,7 +221,7 @@ def _raise(error): except EnvironmentError as e: error = str(e) if time() < deadline: - yield gen.sleep(0.01) + await gen.sleep(0.01) logger.debug("sleeping on connect") else: _raise(error) @@ -231,7 +230,7 @@ def _raise(error): else: break - raise gen.Return(comm) + return comm def listen(addr, handle_comm, deserialize=True, connection_args=None): diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 7f267978d51..c9a6dc90281 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -7,7 +7,7 @@ import threading import weakref -from tornado import gen, locks +from tornado import locks from tornado.concurrent import Future from tornado.ioloop import IOLoop @@ -180,12 +180,11 @@ def local_address(self): def peer_address(self): return self._peer_addr - @gen.coroutine - def read(self, deserializers="ignored"): + async def read(self, deserializers="ignored"): if self._closed: raise CommClosedError - msg = yield self._read_q.get() + msg = await self._read_q.get() if msg is _EOF: self._closed = True self._finalizer.detach() @@ -193,20 +192,18 @@ def read(self, deserializers="ignored"): if self.deserialize: msg = nested_deserialize(msg) - raise gen.Return(msg) + return msg - @gen.coroutine - def write(self, msg, serializers=None, on_error=None): + async def write(self, msg, serializers=None, on_error=None): if self.closed(): raise CommClosedError # Ensure we feed the queue in the same thread it is read from. 
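The rewritten ``connect`` keeps its original control flow: each attempt is bounded with ``gen.with_timeout``, and failures are retried after a short ``await gen.sleep(0.01)`` until the overall deadline passes. Reduced to a generic sketch (the helper name and arguments here are illustrative)::

    from datetime import timedelta
    from tornado import gen
    from distributed.metrics import time

    async def connect_with_retry(attempt, timeout):
        """Retry the ``attempt`` coroutine function until it succeeds or
        ``timeout`` seconds have elapsed."""
        deadline = time() + timeout
        while True:
            try:
                return await gen.with_timeout(
                    timedelta(seconds=deadline - time()), attempt()
                )
            except (EnvironmentError, gen.TimeoutError):
                if time() < deadline:
                    await gen.sleep(0.01)     # brief back-off, then retry
                else:
                    raise IOError("Timed out trying to connect")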
self._write_loop.add_callback(self._write_q.put_nowait, msg) - raise gen.Return(1) + return 1 - @gen.coroutine - def close(self): + async def close(self): self.abort() def abort(self): @@ -246,10 +243,9 @@ def __init__(self, address, comm_handler, deserialize=True): self.deserialize = deserialize self.listen_q = Queue() - @gen.coroutine - def _listen(self): + async def _listen(self): while True: - conn_req = yield self.listen_q.get() + conn_req = await self.listen_q.get() if conn_req is None: break comm = InProc( @@ -262,7 +258,7 @@ def _listen(self): ) # Notify connector conn_req.c_loop.add_callback(conn_req.conn_event.set) - self.comm_handler(comm) + IOLoop.current().add_callback(self.comm_handler, comm) def connect_threadsafe(self, conn_req): self.loop.add_callback(self.listen_q.put_nowait, conn_req) @@ -289,8 +285,7 @@ class InProcConnector(Connector): def __init__(self, manager): self.manager = manager - @gen.coroutine - def connect(self, address, deserialize=True, **connection_args): + async def connect(self, address, deserialize=True, **connection_args): listener = self.manager.get_listener_for(address) if listener is None: raise IOError("no endpoint for inproc address %r" % (address,)) @@ -306,7 +301,7 @@ def connect(self, address, deserialize=True, **connection_args): # Wait for connection acknowledgement # (do not pretend we're connected if the other comm never gets # created, for example if the listener was stopped in the meantime) - yield conn_req.conn_event.wait() + await conn_req.conn_event.wait() comm = InProc( local_addr="inproc://" + conn_req.c_addr, @@ -316,7 +311,7 @@ def connect(self, address, deserialize=True, **connection_args): write_loop=listener.loop, deserialize=deserialize, ) - raise gen.Return(comm) + return comm class InProcBackend(Backend): diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index d5351c7d565..602c9a36253 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -5,6 +5,7 @@ import socket import struct import sys +from tornado import gen try: import ssl @@ -13,7 +14,7 @@ import dask import tornado -from tornado import gen, netutil +from tornado import netutil from tornado.iostream import StreamClosedError, IOStream from tornado.tcpclient import TCPClient from tornado.tcpserver import TCPServer @@ -184,16 +185,15 @@ def local_address(self): def peer_address(self): return self._peer_addr - @gen.coroutine - def read(self, deserializers=None): + async def read(self, deserializers=None): stream = self.stream if stream is None: raise CommClosedError try: - n_frames = yield stream.read_bytes(8) + n_frames = await stream.read_bytes(8) n_frames = struct.unpack("Q", n_frames)[0] - lengths = yield stream.read_bytes(8 * n_frames) + lengths = await stream.read_bytes(8 * n_frames) lengths = struct.unpack("Q" * n_frames, lengths) frames = [] @@ -201,10 +201,10 @@ def read(self, deserializers=None): if length: if PY3 and self._iostream_has_read_into: frame = bytearray(length) - n = yield stream.read_into(frame) + n = await stream.read_into(frame) assert n == length, (n, length) else: - frame = yield stream.read_bytes(length) + frame = await stream.read_bytes(length) else: frame = b"" frames.append(frame) @@ -214,14 +214,14 @@ def read(self, deserializers=None): convert_stream_closed_error(self, e) else: try: - msg = yield from_frames( + msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers ) except EOFError: # Frames possibly garbled or truncated by communication error self.abort() raise 
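The TCP comm's ``read`` above decodes a simple length-prefixed framing scheme: one unsigned 64-bit frame count, then one 64-bit length per frame, then the frame payloads. For reference, the matching layout on the sending side can be sketched as follows (an illustration of the wire format only, not the patched ``write`` method)::

    import struct

    def pack_frames(frames):
        """Pack bytes-like frames in the layout the TCP comm reads:
        count, per-frame lengths, then the payloads themselves."""
        header = struct.pack("Q", len(frames))
        header += struct.pack("Q" * len(frames), *(len(f) for f in frames))
        return header + b"".join(frames)

    # pack_frames([b"hello", b"world!"]) -> 8 + 16 header bytes + 11 payload bytes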
CommClosedError("aborted stream on truncated data") - raise gen.Return(msg) + return msg @gen.coroutine def write(self, msg, serializers=None, on_error="message"): @@ -268,16 +268,15 @@ def write(self, msg, serializers=None, on_error="message"): else: raise - raise gen.Return(sum(map(nbytes, frames))) + return sum(map(nbytes, frames)) - @gen.coroutine - def close(self): + async def close(self): stream, self.stream = self.stream, None if stream is not None and not stream.closed(): try: # Flush the stream's write buffer by waiting for a last write. if stream.writing(): - yield stream.write(b"") + await stream.write(b"") stream.socket.shutdown(socket.SHUT_RDWR) except EnvironmentError: pass @@ -348,14 +347,13 @@ class BaseTCPConnector(Connector, RequireEncryptionMixin): _resolver = None client = TCPClient(resolver=_resolver) - @gen.coroutine - def connect(self, address, deserialize=True, **connection_args): + async def connect(self, address, deserialize=True, **connection_args): self._check_encryption(address, connection_args) ip, port = parse_host_port(address) kwargs = self._get_connect_args(**connection_args) try: - stream = yield BaseTCPConnector.client.connect( + stream = await BaseTCPConnector.client.connect( ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs ) @@ -371,8 +369,8 @@ def connect(self, address, deserialize=True, **connection_args): convert_stream_closed_error(self, e) local_address = self.prefix + get_stream_address(stream) - raise gen.Return( - self.comm_class(stream, local_address, self.prefix + address, deserialize) + return self.comm_class( + stream, local_address, self.prefix + address, deserialize ) @@ -442,17 +440,16 @@ def _check_started(self): if self.tcp_server is None: raise ValueError("invalid operation on non-started TCPListener") - @gen.coroutine - def _handle_stream(self, stream, address): + async def _handle_stream(self, stream, address): address = self.prefix + unparse_host_port(*address[:2]) - stream = yield self._prepare_stream(stream, address) + stream = await self._prepare_stream(stream, address) if stream is None: # Preparation failed return logger.debug("Incoming connection from %r to %r", address, self.contact_address) local_address = self.prefix + get_stream_address(stream) comm = self.comm_class(stream, local_address, address, self.deserialize) - yield self.comm_handler(comm) + await self.comm_handler(comm) def get_host_port(self): """ @@ -490,9 +487,8 @@ class TCPListener(BaseTCPListener): def _get_server_args(self, **connection_args): return {} - @gen.coroutine - def _prepare_stream(self, stream, address): - raise gen.Return(stream) + async def _prepare_stream(self, stream, address): + return stream class TLSListener(BaseTCPListener): @@ -504,10 +500,9 @@ def _get_server_args(self, **connection_args): ctx = _expect_tls_context(connection_args) return {"ssl_options": ctx} - @gen.coroutine - def _prepare_stream(self, stream, address): + async def _prepare_stream(self, stream, address): try: - yield stream.wait_for_handshake() + await stream.wait_for_handshake() except EnvironmentError as e: # The handshake went wrong, log and ignore logger.warning( @@ -517,7 +512,7 @@ def _prepare_stream(self, stream, address): getattr(e, "real_error", None) or e, ) else: - raise gen.Return(stream) + return stream class BaseTCPBackend(Backend): diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index e761deeab86..5d52b04a137 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -800,14 
+800,14 @@ def handle_comm(comm): assert comm.closed() listener_errors.append(True) else: - comm.close() + yield comm.close() listener = listen("inproc://", handle_comm) listener.start() contact_addr = listener.contact_address comm = yield connect(contact_addr) - comm.close() + yield comm.close() assert comm.closed() start = time() while len(listener_errors) < 1: @@ -821,7 +821,7 @@ def handle_comm(comm): yield comm.write("foo") comm = yield connect(contact_addr) - comm.write("foo") + yield comm.write("foo") with pytest.raises(CommClosedError): yield comm.read() with pytest.raises(CommClosedError): @@ -829,15 +829,15 @@ def handle_comm(comm): assert comm.closed() comm = yield connect(contact_addr) - comm.write("foo") + yield comm.write("foo") start = time() while not comm.closed(): yield gen.sleep(0.01) assert time() < start + 2 - comm.close() - comm.close() + yield comm.close() + yield comm.close() # diff --git a/distributed/core.py b/distributed/core.py index 8aac2edfc33..d8a34859359 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import asyncio from collections import defaultdict, deque from concurrent.futures import CancelledError from functools import partial @@ -29,6 +30,7 @@ from . import profile from .system_monitor import SystemMonitor from .utils import ( + is_coroutine_function, get_traceback, truncate_exception, ignoring, @@ -204,7 +206,6 @@ def stop(): self.thread_id = 0 - @gen.coroutine def set_thread_ident(): self.thread_id = get_thread_identity() @@ -326,8 +327,7 @@ def listen(self, port_or_addr=None, listen_args=None): ) self.listener.start() - @gen.coroutine - def handle_comm(self, comm, shutting_down=shutting_down): + async def handle_comm(self, comm, shutting_down=shutting_down): """ Dispatch new communications to coroutine-handlers Handlers is a dictionary mapping operation names to functions or @@ -349,7 +349,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): try: while True: try: - msg = yield comm.read() + msg = await comm.read() logger.debug("Message from %r: %s", address, msg) except EnvironmentError as e: if not shutting_down(): @@ -363,7 +363,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): break except Exception as e: logger.exception(e) - yield comm.write(error_message(e, status="uncaught-error")) + await comm.write(error_message(e, status="uncaught-error")) continue if not isinstance(msg, dict): raise TypeError( @@ -384,7 +384,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): reply = msg.pop("reply", True) if op == "close": if reply: - yield comm.write("OK") + await comm.write("OK") break result = None @@ -412,9 +412,10 @@ def handle_comm(self, comm, shutting_down=shutting_down): logger.debug("Calling into handler %s", handler.__name__) try: result = handler(comm, **msg) - if type(result) is gen.Future: + if hasattr(result, "__await__"): + result = asyncio.ensure_future(result) self._ongoing_coroutines.add(result) - result = yield result + result = await result except (CommClosedError, CancelledError) as e: if self.status == "running": logger.info("Lost connection to %r: %s", address, e) @@ -425,7 +426,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): if reply and result != "dont-reply": try: - yield comm.write(result, serializers=serializers) + await comm.write(result, serializers=serializers) except (EnvironmentError, TypeError) as e: logger.debug( "Lost connection to %r while sending result for op %r: %s", @@ 
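Checking ``hasattr(result, "__await__")`` instead of ``type(result) is gen.Future`` lets ``handle_comm`` accept plain values, Tornado futures, and native coroutines from one handler table; anything awaitable is wrapped in a task first so it can also be tracked in ``_ongoing_coroutines``. The dispatch, reduced to a sketch::

    import asyncio

    async def call_handler(handler, *args, **kwargs):
        """Call a handler that may be synchronous or asynchronous and
        return its eventual result."""
        result = handler(*args, **kwargs)
        if hasattr(result, "__await__"):          # coroutine or Future
            result = await asyncio.ensure_future(result)
        return result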
-436,7 +437,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): break msg = result = None if close_desired: - yield comm.close() + await comm.close() if comm.closed(): break @@ -450,8 +451,7 @@ def handle_comm(self, comm, shutting_down=shutting_down): "Failed while closing connection to %r: %s", address, e ) - @gen.coroutine - def handle_stream(self, comm, extra=None, every_cycle=[]): + async def handle_stream(self, comm, extra=None, every_cycle=[]): extra = extra or {} logger.info("Starting established connection") @@ -459,7 +459,7 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): closed = False try: while not closed: - msgs = yield comm.read() + msgs = await comm.read() if not isinstance(msgs, (tuple, list)): msgs = (msgs,) @@ -473,9 +473,14 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): closed = True break handler = self.stream_handlers[op] - handler(**merge(extra, msg)) + if is_coroutine_function(handler): + self.loop.add_callback(handler, **merge(extra, msg)) + else: + handler(**merge(extra, msg)) else: logger.error("odd message %s", msg) + await gen.sleep(0) + for func in every_cycle: func() @@ -489,7 +494,7 @@ def handle_stream(self, comm, extra=None, every_cycle=[]): pdb.set_trace() raise finally: - yield comm.close() + await comm.close() assert comm.closed() @gen.coroutine @@ -519,8 +524,7 @@ def pingpong(comm): return b"pong" -@gen.coroutine -def send_recv(comm, reply=True, serializers=None, deserializers=None, **kwargs): +async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kwargs): """ Send and recv with a Comm. Keyword arguments turn into the message @@ -537,9 +541,9 @@ def send_recv(comm, reply=True, serializers=None, deserializers=None, **kwargs): msg["serializers"] = deserializers try: - yield comm.write(msg, serializers=serializers, on_error="raise") + await comm.write(msg, serializers=serializers, on_error="raise") if reply: - response = yield comm.read(deserializers=deserializers) + response = await comm.read(deserializers=deserializers) else: response = None except EnvironmentError: @@ -548,7 +552,7 @@ def send_recv(comm, reply=True, serializers=None, deserializers=None, **kwargs): raise finally: if please_close: - yield comm.close() + await comm.close() elif force_close: comm.abort() @@ -557,7 +561,7 @@ def send_recv(comm, reply=True, serializers=None, deserializers=None, **kwargs): six.reraise(*clean_exception(**response)) else: raise Exception(response["text"]) - raise gen.Return(response) + return response def addr_from_args(addr=None, ip=None, port=None): @@ -610,8 +614,7 @@ def __init__( self._created = weakref.WeakSet() rpc.active.add(self) - @gen.coroutine - def live_comm(self): + async def live_comm(self): """ Get an open communication Some comms to the ip/port target may be in current use by other @@ -641,7 +644,7 @@ def live_comm(self): for s in to_clear: del self.comms[s] if not open or comm.closed(): - comm = yield connect( + comm = await connect( self.address, self.timeout, deserialize=self.deserialize, @@ -649,44 +652,46 @@ def live_comm(self): ) comm.name = "rpc" self.comms[comm] = False # mark as taken - raise gen.Return(comm) + return comm def close_comms(self): @gen.coroutine def _close_comm(comm): # Make sure we tell the peer to close try: - yield comm.write({"op": "close", "reply": False}) - yield comm.close() + if not comm.closed(): + yield comm.write({"op": "close", "reply": False}) + yield comm.close() except EnvironmentError: comm.abort() for comm in list(self.comms): if comm and not 
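Stream handlers may now be coroutine functions: synchronous handlers are still called inline, while asynchronous ones are scheduled on the event loop so the read loop never blocks (and the added ``await gen.sleep(0)`` yields control so they get a chance to run each cycle). Roughly, using the standard-library ``inspect.iscoroutinefunction`` in place of the internal ``is_coroutine_function`` helper::

    import inspect
    from tornado.ioloop import IOLoop

    def dispatch_stream_message(handlers, op, msg):
        """Run the handler registered for ``op`` either inline or on the
        IOLoop, depending on whether it is a coroutine function."""
        handler = handlers[op]
        if inspect.iscoroutinefunction(handler):
            # the IOLoop schedules the returned coroutine; its result is discarded
            IOLoop.current().add_callback(handler, **msg)
        else:
            handler(**msg)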
comm.closed(): - _close_comm(comm) + # IOLoop.current().add_callback(_close_comm, comm) + task = asyncio.ensure_future(_close_comm(comm)) for comm in list(self._created): if comm and not comm.closed(): - _close_comm(comm) + # IOLoop.current().add_callback(_close_comm, comm) + task = asyncio.ensure_future(_close_comm(comm)) self.comms.clear() def __getattr__(self, key): - @gen.coroutine - def send_recv_from_rpc(**kwargs): + async def send_recv_from_rpc(**kwargs): if self.serializers is not None and kwargs.get("serializers") is None: kwargs["serializers"] = self.serializers if self.deserializers is not None and kwargs.get("deserializers") is None: kwargs["deserializers"] = self.deserializers try: - comm = yield self.live_comm() + comm = await self.live_comm() comm.name = "rpc." + key - result = yield send_recv(comm=comm, op=key, **kwargs) + result = await send_recv(comm=comm, op=key, **kwargs) except (RPCClosed, CommClosedError) as e: raise e.__class__( "%s: while trying to call remote method %r" % (e, key) ) self.comms[comm] = True # mark as open - raise gen.Return(result) + return result return send_recv_from_rpc @@ -736,21 +741,20 @@ def address(self): return self.addr def __getattr__(self, key): - @gen.coroutine - def send_recv_from_rpc(**kwargs): + async def send_recv_from_rpc(**kwargs): if self.serializers is not None and kwargs.get("serializers") is None: kwargs["serializers"] = self.serializers if self.deserializers is not None and kwargs.get("deserializers") is None: kwargs["deserializers"] = self.deserializers - comm = yield self.pool.connect(self.addr) + comm = await self.pool.connect(self.addr) name, comm.name = comm.name, "ConnectionPool." + key try: - result = yield send_recv(comm=comm, op=key, **kwargs) + result = await send_recv(comm=comm, op=key, **kwargs) finally: self.pool.reuse(self.addr, comm) comm.name = name - raise gen.Return(result) + return result return send_recv_from_rpc @@ -847,8 +851,7 @@ def __call__(self, addr=None, ip=None, port=None): addr, self, serializers=self.serializers, deserializers=self.deserializers ) - @gen.coroutine - def connect(self, addr, timeout=None): + async def connect(self, addr, timeout=None): """ Get a Comm to the given address. For internal use. 
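Both ``rpc`` and ``PooledRPCCall`` expose remote operations through ``__getattr__``: attribute access builds an async closure that borrows a comm, performs one ``send_recv`` round trip with ``op`` set to the attribute name, and then returns the comm for reuse. The same proxy idea in miniature (a toy transport, not the distributed one)::

    class AsyncProxy(object):
        """Turn ``await proxy.some_op(x=1)`` into a request dictionary
        ``{"op": "some_op", "x": 1}`` sent over an async transport."""

        def __init__(self, send_recv):
            self._send_recv = send_recv        # coroutine: dict -> response

        def __getattr__(self, op):
            async def call(**kwargs):
                return await self._send_recv({"op": op, **kwargs})
            return call

So, for example, ``await proxy.identity()`` sends ``{"op": "identity"}`` and returns whatever the peer replies.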
""" @@ -858,15 +861,15 @@ def connect(self, addr, timeout=None): comm = available.pop() if not comm.closed(): occupied.add(comm) - raise gen.Return(comm) + return comm while self.open >= self.limit: self.event.clear() self.collect() - yield self.event.wait() + await self.event.wait() try: - comm = yield connect( + comm = await connect( addr, timeout=timeout or self.timeout, deserialize=self.deserialize, @@ -882,7 +885,7 @@ def connect(self, addr, timeout=None): if self.open >= self.limit: self.event.clear() - raise gen.Return(comm) + return comm def reuse(self, addr, comm): """ diff --git a/distributed/dashboard/components.py b/distributed/dashboard/components.py index e7234e2e6f7..242a617706e 100644 --- a/distributed/dashboard/components.py +++ b/distributed/dashboard/components.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import asyncio from bisect import bisect from operator import add from time import time @@ -549,8 +550,7 @@ def update(self, state, metadata=None): @without_property_validation def trigger_update(self, update_metadata=True): - @gen.coroutine - def cb(): + async def cb(): with log_errors(): prof = self.server.get_profile( key=self.key, start=self.start, stop=self.stop @@ -560,7 +560,7 @@ def cb(): else: metadata = None if isinstance(prof, gen.Future): - prof, metadata = yield [prof, metadata] + prof, metadata = await asyncio.gather(prof, metadata) self.doc().add_next_tick_callback(lambda: self.update(prof, metadata)) self.server.loop.add_callback(cb) diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 3f119a929b9..3087f323b5f 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -3,7 +3,6 @@ from dask.utils import format_bytes import toolz from tornado import escape -from tornado import gen from ..utils import log_errors, format_time from .proxy import GlobalProxyHandler @@ -59,22 +58,20 @@ def get(self): class WorkerLogs(RequestHandler): - @gen.coroutine - def get(self, worker): + async def get(self, worker): with log_errors(): worker = escape.url_unescape(worker) - logs = yield self.server.get_worker_logs(workers=[worker]) + logs = await self.server.get_worker_logs(workers=[worker]) logs = logs[worker] self.render("logs.html", title="Logs: " + worker, logs=logs, **self.extra) class WorkerCallStacks(RequestHandler): - @gen.coroutine - def get(self, worker): + async def get(self, worker): with log_errors(): worker = escape.url_unescape(worker) keys = self.server.processing[worker] - call_stack = yield self.server.get_call_stack(keys=keys) + call_stack = await self.server.get_call_stack(keys=keys) self.render( "call-stack.html", title="Call Stacks: " + worker, @@ -84,11 +81,10 @@ def get(self, worker): class TaskCallStack(RequestHandler): - @gen.coroutine - def get(self, key): + async def get(self, key): with log_errors(): key = escape.url_unescape(key) - call_stack = yield self.server.get_call_stack(keys=[key]) + call_stack = await self.server.get_call_stack(keys=[key]) if not call_stack: self.write( "
                  Task not actively running. " diff --git a/distributed/dashboard/tests/test_components.py b/distributed/dashboard/tests/test_components.py index d441db57aec..b12780f199b 100644 --- a/distributed/dashboard/tests/test_components.py +++ b/distributed/dashboard/tests/test_components.py @@ -25,7 +25,7 @@ def test_basic(Component): assert isinstance(c.root, Model) -@gen_cluster(client=True, check_new_threads=False) +@gen_cluster(client=True, clean_kwargs={"threads": False}) def test_profile_plot(c, s, a, b): p = ProfilePlot() assert not p.source.data["left"] @@ -34,7 +34,7 @@ def test_profile_plot(c, s, a, b): assert len(p.source.data["left"]) >= 1 -@gen_cluster(client=True, check_new_threads=False) +@gen_cluster(client=True, clean_kwargs={"threads": False}) def test_profile_time_plot(c, s, a, b): from bokeh.io import curdoc diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py index f872d02dc84..b66aff02ddc 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -73,7 +73,7 @@ def test_prefix(c, s, a, b): @gen_cluster( client=True, - check_new_threads=False, + clean_kwargs={"threads": False}, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, ) def test_prometheus(c, s, a, b): @@ -98,7 +98,7 @@ def test_prometheus(c, s, a, b): @gen_cluster( client=True, - check_new_threads=False, + clean_kwargs={"threads": False}, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, ) def test_health(c, s, a, b): diff --git a/distributed/dashboard/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py index ef977127d23..d320ea24ee8 100644 --- a/distributed/dashboard/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -169,7 +169,7 @@ def test_CommunicatingStream(c, s, a, b): @gen_cluster( client=True, - check_new_threads=False, + clean_kwargs={"threads": False}, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}, ) def test_prometheus(c, s, a, b): diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 441ef10a595..1ba8e7fb213 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -272,7 +272,7 @@ def close(self, timeout=None): def __del__(self): if self.status != "closed": - self.close() + self.loop.add_callback(self.close) def __enter__(self): self.sync(self._correct_state) diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 861c5107348..e0478a9cbdb 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -255,21 +255,27 @@ def test_adapt_quickly(): d = [x for x in adapt.log[-1] if isinstance(x, dict)][0] assert 2 < d["n"] <= adapt.maximum - while len(cluster.scheduler.workers) < adapt.maximum: + while len(cluster.workers) < adapt.maximum: yield gen.sleep(0.01) del futures - while len(cluster.scheduler.workers) > 1: + while len(cluster.scheduler.tasks) > 1: + yield gen.sleep(0.01) + + yield cluster + + while len(cluster.scheduler.workers) > 1 or len(cluster.worker_spec) > 1: yield gen.sleep(0.01) # Don't scale up for large sequential computations x = yield client.scatter(1) + log = list(cluster._adaptive.log) for i in range(100): x = client.submit(slowinc, x) yield gen.sleep(0.1) - assert len(cluster.scheduler.workers) == 1 + assert len(cluster.workers) == 1 finally: yield client.close() yield 
cluster.close() diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index d489b84df0f..1c098c2b4c5 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -401,8 +401,9 @@ def test_silent_startup(): from time import sleep from distributed import LocalCluster - with LocalCluster(1, dashboard_address=None, scheduler_port=0): - sleep(1.5) + if __name__ == "__main__": + with LocalCluster(1, dashboard_address=None, scheduler_port=0): + sleep(1.5) """ out = subprocess.check_output( diff --git a/distributed/diagnostics/eventstream.py b/distributed/diagnostics/eventstream.py index a4eb0830534..b9213144d4e 100644 --- a/distributed/diagnostics/eventstream.py +++ b/distributed/diagnostics/eventstream.py @@ -2,8 +2,6 @@ import logging -from tornado import gen - from .plugin import SchedulerPlugin from ..core import connect, coerce_to_address @@ -37,8 +35,7 @@ def teardown(scheduler, es): scheduler.remove_plugin(es) -@gen.coroutine -def eventstream(address, interval): +async def eventstream(address, interval): """ Open a TCP connection to scheduler, receive batched task messages The messages coming back are lists of dicts. Each dict is of the following @@ -59,14 +56,14 @@ def eventstream(address, interval): Examples -------- - >>> stream = yield eventstream('127.0.0.1:8786', 0.100) # doctest: +SKIP - >>> print(yield read(stream)) # doctest: +SKIP + >>> stream = await eventstream('127.0.0.1:8786', 0.100) # doctest: +SKIP + >>> print(await read(stream)) # doctest: +SKIP [{'key': 'x', 'status': 'OK', 'worker': '192.168.0.1:54684', ...}, {'key': 'y', 'status': 'error', 'worker': '192.168.0.1:54684', ...}] """ address = coerce_to_address(address) - comm = yield connect(address) - yield comm.write( + comm = await connect(address) + await comm.write( { "op": "feed", "setup": dumps_function(EventStream), @@ -75,4 +72,4 @@ def eventstream(address, interval): "teardown": dumps_function(teardown), } ) - raise gen.Return(comm) + return comm diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 38638a248dd..50c4cd9fad1 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -74,12 +74,11 @@ def __init__(self, keys, scheduler, minimum=0, dt=0.1, complete=False): self.status = None self.extra = {} - @gen.coroutine - def setup(self): + async def setup(self): keys = self.keys while not keys.issubset(self.scheduler.tasks): - yield gen.sleep(0.05) + await gen.sleep(0.05) tasks = [self.scheduler.tasks[k] for k in keys] @@ -163,12 +162,11 @@ def __init__( self, keys, scheduler, minimum=minimum, dt=dt, complete=complete ) - @gen.coroutine - def setup(self): + async def setup(self): keys = self.keys while not keys.issubset(self.scheduler.tasks): - yield gen.sleep(0.05) + await gen.sleep(0.05) tasks = [self.scheduler.tasks[k] for k in keys] diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index 1630251658a..b1e3787bd5a 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -3,7 +3,6 @@ import logging from toolz import valmap, merge -from tornado import gen from .progress import AllProgress @@ -26,8 +25,7 @@ def counts(scheduler, allprogress): ) -@gen.coroutine -def progress_stream(address, interval): +async def progress_stream(address, interval): """ Open a TCP connection to scheduler, receive progress messages The messages coming back are dicts containing counts 
of key groups:: @@ -42,12 +40,12 @@ def progress_stream(address, interval): Examples -------- - >>> stream = yield eventstream('127.0.0.1:8786', 0.100) # doctest: +SKIP - >>> print(yield read(stream)) # doctest: +SKIP + >>> stream = await eventstream('127.0.0.1:8786', 0.100) # doctest: +SKIP + >>> print(await read(stream)) # doctest: +SKIP """ address = coerce_to_address(address) - comm = yield connect(address) - yield comm.write( + comm = await connect(address) + await comm.write( { "op": "feed", "setup": dumps_function(AllProgress), @@ -56,7 +54,7 @@ def progress_stream(address, interval): "teardown": dumps_function(Scheduler.remove_plugin), } ) - raise gen.Return(comm) + return comm def nbytes_bar(nbytes): diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index f25bf32a871..8d57da779c6 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -6,7 +6,6 @@ import weakref from toolz import valmap -from tornado import gen from tornado.ioloop import IOLoop from .progress import format_time, Progress, MultiProgress @@ -46,16 +45,14 @@ def __init__(self, keys, scheduler=None, interval="100ms", complete=True): def elapsed(self): return default_timer() - self._start_time - @gen.coroutine - def listen(self): + async def listen(self): complete = self.complete keys = self.keys - @gen.coroutine - def setup(scheduler): + async def setup(scheduler): p = Progress(keys, scheduler, complete=complete) - yield p.setup() - raise gen.Return(p) + await p.setup() + return p def function(scheduler, p): result = { @@ -67,13 +64,13 @@ def function(scheduler, p): result.update(p.extra) return result - self.comm = yield connect( + self.comm = await connect( self.scheduler, connection_args=self.client().connection_args if self.client else None, ) logger.debug("Progressbar Connected to scheduler") - yield self.comm.write( + await self.comm.write( { "op": "feed", "setup": dumps(setup), @@ -85,7 +82,7 @@ def function(scheduler, p): while True: try: - response = yield self.comm.read( + response = await self.comm.read( deserializers=self.client()._deserializers if self.client else None ) except CommClosedError: @@ -94,7 +91,7 @@ def function(scheduler, p): self.status = response["status"] self._draw_bar(**response) if response["status"] in ("error", "finished"): - yield self.comm.close() + await self.comm.close() self._draw_stop(**response) break @@ -240,17 +237,15 @@ def __init__( def elapsed(self): return default_timer() - self._start_time - @gen.coroutine - def listen(self): + async def listen(self): complete = self.complete keys = self.keys func = self.func - @gen.coroutine - def setup(scheduler): + async def setup(scheduler): p = MultiProgress(keys, scheduler, complete=complete, func=func) - yield p.setup() - raise gen.Return(p) + await p.setup() + return p def function(scheduler, p): result = { @@ -262,13 +257,13 @@ def function(scheduler, p): result.update(p.extra) return result - self.comm = yield connect( + self.comm = await connect( self.scheduler, connection_args=self.client().connection_args if self.client else None, ) logger.debug("Progressbar Connected to scheduler") - yield self.comm.write( + await self.comm.write( { "op": "feed", "setup": dumps(setup), @@ -278,14 +273,14 @@ def function(scheduler, p): ) while True: - response = yield self.comm.read( + response = await self.comm.read( deserializers=self.client()._deserializers if self.client else None ) self._last_response = response self.status = response["status"] 
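``eventstream``, ``progress_stream``, and the progress bars all follow the same subscription shape: open a comm, write a single ``op: feed`` request describing what the scheduler should send and how often, then read batched responses until done. A consumption sketch built on the ``eventstream`` helper converted above (assuming a reachable scheduler address)::

    from distributed.diagnostics.eventstream import eventstream

    async def watch_task_events(scheduler_address, n_batches=5):
        comm = await eventstream(scheduler_address, interval=0.100)
        try:
            for _ in range(n_batches):
                msgs = await comm.read()         # a batch of per-task dicts
                for msg in msgs:
                    print(msg.get("key"), msg.get("status"))
        finally:
            await comm.close()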
self._draw_bar(**response) if response["status"] in ("error", "finished"): - yield self.comm.close() + await self.comm.close() self._draw_stop(**response) break logger.debug("Progressbar disconnected from scheduler") diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index 3e5f0633d49..4e6ffe8c7e9 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -2,10 +2,12 @@ from time import sleep +import pytest + from distributed import Scheduler, Worker from distributed.diagnostics.progressbar import TextProgressBar, progress from distributed.metrics import time -from distributed.utils_test import inc, div, gen_cluster, gen_test +from distributed.utils_test import inc, div, gen_cluster from distributed.utils_test import client, loop, cluster_fixture # noqa: F401 @@ -40,23 +42,18 @@ def test_TextProgressBar_error(c, s, a, b): assert progress.comm.closed() -def test_TextProgressBar_empty(capsys): - @gen_test() - def f(): - s = yield Scheduler(port=0) - a, b = yield [Worker(s.address, nthreads=1), Worker(s.address, nthreads=1)] - - progress = TextProgressBar([], scheduler=s.address, start=False, interval=0.01) - yield progress.listen() - - assert progress.status == "finished" - check_bar_completed(capsys) - - yield [a.close(), b.close()] - s.close() - yield s.finished() +@pytest.mark.asyncio +async def test_TextProgressBar_empty(capsys): + async with Scheduler(port=0) as s: + async with Worker(s.address, nthreads=1) as a: + async with Worker(s.address, nthreads=1) as b: + progress = TextProgressBar( + [], scheduler=s.address, start=False, interval=0.01 + ) + await progress.listen() - f() + assert progress.status == "finished" + check_bar_completed(capsys) def check_bar_completed(capsys, width=40): diff --git a/distributed/lock.py b/distributed/lock.py index d12b1c41e15..6ad6ab607d3 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -35,8 +35,7 @@ def __init__(self, scheduler): self.scheduler.extensions["locks"] = self - @gen.coroutine - def acquire(self, stream=None, name=None, id=None, timeout=None): + async def acquire(self, stream=None, name=None, id=None, timeout=None): with log_errors(): if isinstance(name, list): name = tuple(name) @@ -50,7 +49,7 @@ def acquire(self, stream=None, name=None, id=None, timeout=None): if timeout is not None: future = gen.with_timeout(timedelta(seconds=timeout), future) try: - yield future + await future except gen.TimeoutError: result = False break @@ -62,7 +61,7 @@ def acquire(self, stream=None, name=None, id=None, timeout=None): if result: assert name not in self.ids self.ids[name] = id - raise gen.Return(result) + return result def release(self, stream=None, name=None, id=None): with log_errors(): @@ -155,14 +154,12 @@ def __enter__(self): def __exit__(self, *args, **kwargs): self.release() - @gen.coroutine - def __aenter__(self): - yield self.acquire() - raise gen.Return(self) + async def __aenter__(self): + await self.acquire() + return self - @gen.coroutine - def __aexit__(self, *args, **kwargs): - yield self.release() + async def __aexit__(self, *args, **kwargs): + await self.release() def __reduce__(self): return (Lock, (self.name,)) diff --git a/distributed/nanny.py b/distributed/nanny.py index f3bebb1dcac..6c859115242 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -191,8 +191,7 @@ def __init__( def __repr__(self): return "" % (self.worker_address, self.nthreads) - @gen.coroutine - def 
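The rewritten progress-bar test shows the testing idiom that replaces hand-rolled ``@gen_test`` bodies: ``Scheduler`` and ``Worker`` are async context managers, so a ``pytest.mark.asyncio`` test can stand up and tear down a miniature cluster with nested ``async with`` blocks. A hedged sketch of the same shape (assumes the ``pytest-asyncio`` plugin is installed)::

    import pytest
    from distributed import Client, Scheduler, Worker

    @pytest.mark.asyncio
    async def test_submit_square():
        async with Scheduler(port=0) as s:
            async with Worker(s.address, nthreads=1):
                async with Client(s.address, asynchronous=True) as c:
                    future = c.submit(lambda x: x ** 2, 10)
                    assert await future == 100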
_unregister(self, timeout=10): + async def _unregister(self, timeout=10): if self.process is None: return worker_address = self.process.worker_address @@ -206,7 +205,7 @@ def _unregister(self, timeout=10): RPCClosed, ) try: - yield gen.with_timeout( + await gen.with_timeout( timedelta(seconds=timeout), self.scheduler.unregister(address=self.worker_address), quiet_exceptions=allowed_errors, @@ -222,26 +221,24 @@ def worker_address(self): def worker_dir(self): return None if self.process is None else self.process.worker_dir - @gen.coroutine - def start(self): + async def start(self): """ Start nanny, start local process, start watching """ self.listen(self._start_address, listen_args=self.listen_args) self.ip = get_address_host(self.address) logger.info(" Start Nanny at: %r", self.address) - response = yield self.instantiate() + response = await self.instantiate() if response == "running": assert self.worker_address self.status = "running" else: - yield self.close() + await self.close() self.start_periodic_callbacks() return self - @gen.coroutine - def kill(self, comm=None, timeout=2): + async def kill(self, comm=None, timeout=2): """ Kill the local worker process Blocks until both the process is down and the scheduler is properly @@ -249,13 +246,12 @@ def kill(self, comm=None, timeout=2): """ self.auto_restart = False if self.process is None: - raise gen.Return("OK") + return "OK" deadline = self.loop.time() + timeout - yield self.process.kill(timeout=0.8 * (deadline - self.loop.time())) + await self.process.kill(timeout=0.8 * (deadline - self.loop.time())) - @gen.coroutine - def instantiate(self, comm=None): + async def instantiate(self, comm=None): """ Start a local worker process Blocks until the process is up and the scheduler is properly informed @@ -292,7 +288,7 @@ def instantiate(self, comm=None): worker_kwargs=worker_kwargs, worker_start_args=(start_arg,), silence_logs=self.silence_logs, - on_exit=self._on_exit, + on_exit=self._on_exit_sync, worker=self.Worker, env=self.env, ) @@ -300,11 +296,11 @@ def instantiate(self, comm=None): self.auto_restart = True if self.death_timeout: try: - result = yield gen.with_timeout( + result = await gen.with_timeout( timedelta(seconds=self.death_timeout), self.process.start() ) except gen.TimeoutError: - yield self.close(timeout=self.death_timeout) + await self.close(timeout=self.death_timeout) logger.exception( "Timed out connecting Nanny '%s' to scheduler '%s'", self, @@ -313,26 +309,24 @@ def instantiate(self, comm=None): raise else: - result = yield self.process.start() - raise gen.Return(result) + result = await self.process.start() + return result - @gen.coroutine - def restart(self, comm=None, timeout=2, executor_wait=True): + async def restart(self, comm=None, timeout=2, executor_wait=True): start = time() - @gen.coroutine - def _(): + async def _(): if self.process is not None: - yield self.kill() - yield self.instantiate() + await self.kill() + await self.instantiate() try: - yield gen.with_timeout(timedelta(seconds=timeout), _()) + await gen.with_timeout(timedelta(seconds=timeout), _()) except gen.TimeoutError: logger.error("Restart timed out, returning before finished") - raise gen.Return("timed out") + return "timed out" else: - raise gen.Return("OK") + return "OK" def memory_monitor(self): """ Track worker's memory. 
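``Nanny.restart`` wraps the kill-then-instantiate sequence in ``gen.with_timeout`` and reports a status string instead of raising, so the scheduler's restart broadcast can inspect every response uniformly. The pattern in isolation (helper name is illustrative)::

    from datetime import timedelta
    from tornado import gen

    async def run_with_deadline(coro_func, timeout):
        """Return "OK" if ``coro_func()`` finishes within ``timeout`` seconds,
        "timed out" otherwise, mirroring the restart status strings above."""
        try:
            await gen.with_timeout(timedelta(seconds=timeout), coro_func())
        except gen.TimeoutError:
            return "timed out"
        return "OK"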
Restart if it goes above terminate fraction """ @@ -360,21 +354,23 @@ def is_alive(self): def run(self, *args, **kwargs): return run(self, *args, **kwargs) - @gen.coroutine - def _on_exit(self, exitcode): + def _on_exit_sync(self, exitcode): + self.loop.add_callback(self._on_exit, exitcode) + + async def _on_exit(self, exitcode): if self.status not in ("closing", "closed"): try: - yield self.scheduler.unregister(address=self.worker_address) + await self.scheduler.unregister(address=self.worker_address) except (EnvironmentError, CommClosedError): if not self.reconnect: - yield self.close() + await self.close() return try: if self.status not in ("closing", "closed", "closing-gracefully"): if self.auto_restart: logger.warning("Restarting worker") - yield self.instantiate() + await self.instantiate() except Exception: logger.error( "Failed to restart worker after its process exited", exc_info=True @@ -396,13 +392,12 @@ def close_gracefully(self, comm=None): """ self.status = "closing-gracefully" - @gen.coroutine - def close(self, comm=None, timeout=5, report=None): + async def close(self, comm=None, timeout=5, report=None): """ Close the worker process, stop all comms. """ if self.status == "closing": - yield self.finished() + await self.finished() assert self.status == "closed" if self.status == "closed": @@ -413,15 +408,15 @@ def close(self, comm=None, timeout=5, report=None): self.stop() try: if self.process is not None: - yield self.kill(timeout=timeout) + await self.kill(timeout=timeout) except Exception: pass self.process = None self.rpc.close() self.status = "closed" if comm: - yield comm.write("OK") - yield ServerNode.close(self) + await comm.write("OK") + await ServerNode.close(self) class WorkerProcess(object): @@ -441,17 +436,16 @@ def __init__( self.worker_dir = None self.worker_address = None - @gen.coroutine - def start(self): + async def start(self): """ Ensure the worker process is started. """ enable_proctitle_on_children() if self.status == "running": - raise gen.Return(self.status) + return self.status if self.status == "starting": - yield self.running.wait() - raise gen.Return(self.status) + await self.running.wait() + return self.status self.init_result_q = init_q = mp_context.Queue() self.child_stop_q = mp_context.Queue() @@ -476,10 +470,16 @@ def start(self): self.running = Event() self.stopped = Event() self.status = "starting" - yield self.process.start() - msg = yield self._wait_until_connected(uid) + try: + await self.process.start() + except OSError: + logger.exception("Nanny failed to start process", exc_info=True) + self.process.terminate() + return + + msg = await self._wait_until_connected(uid) if not msg: - raise gen.Return(self.status) + return self.status self.worker_address = msg["address"] self.worker_dir = msg["dir"] assert self.worker_address @@ -488,7 +488,7 @@ def start(self): init_q.close() - raise gen.Return(self.status) + return self.status def _on_exit(self, proc): if proc is not self.process: @@ -518,7 +518,7 @@ def mark_stopped(self): assert r is not None if r != 0: msg = self._death_message(self.process.pid, r) - logger.warning(msg) + logger.info(msg) self.status = "stopped" self.stopped.set() # Release resources @@ -534,8 +534,7 @@ def mark_stopped(self): if self.on_exit is not None: self.on_exit(r) - @gen.coroutine - def kill(self, timeout=2, executor_wait=True): + async def kill(self, timeout=2, executor_wait=True): """ Ensure the worker process is stopped, waiting at most *timeout* seconds before terminating it abruptly. 
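Because the process-exit callback fires from synchronous code, the Nanny now registers ``_on_exit_sync``, which only schedules the real ``_on_exit`` coroutine on the IOLoop. That bridge is the standard way to let a plain callback trigger asynchronous cleanup; a generic sketch::

    from tornado.ioloop import IOLoop

    class ExitWatcher(object):
        """Illustrative only: react asynchronously to a synchronous exit callback."""

        def __init__(self, restart_coro, loop=None):
            self.restart = restart_coro            # coroutine function
            self.loop = loop or IOLoop.current()

        def on_exit_sync(self, exitcode):
            # safe to call from non-async code; schedules the coroutine below
            self.loop.add_callback(self.on_exit, exitcode)

        async def on_exit(self, exitcode):
            if exitcode != 0:
                await self.restart()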
@@ -546,7 +545,7 @@ def kill(self, timeout=2, executor_wait=True): if self.status == "stopped": return if self.status == "stopping": - yield self.stopped.wait() + await self.stopped.wait() return assert self.status in ("starting", "running") self.status = "stopping" @@ -562,19 +561,18 @@ def kill(self, timeout=2, executor_wait=True): self.child_stop_q.close() while process.is_alive() and loop.time() < deadline: - yield gen.sleep(0.05) + await gen.sleep(0.05) if process.is_alive(): logger.warning( "Worker process still alive after %d seconds, killing", timeout ) try: - yield process.terminate() + await process.terminate() except Exception as e: logger.error("Failed to kill worker process: %s", e) - @gen.coroutine - def _wait_until_connected(self, uid): + async def _wait_until_connected(self, uid): delay = 0.05 while True: if self.status != "starting": @@ -582,7 +580,7 @@ def _wait_until_connected(self, uid): try: msg = self.init_result_q.get_nowait() except Empty: - yield gen.sleep(delay) + await gen.sleep(delay) continue if msg["uid"] != uid: # ensure that we didn't cross queues @@ -592,10 +590,10 @@ def _wait_until_connected(self, uid): logger.error( "Failed while trying to start worker process: %s", msg["exception"] ) - yield self.process.join() + await self.process.join() raise msg else: - raise gen.Return(msg) + return msg @classmethod def _run( @@ -625,10 +623,9 @@ def _run( loop.make_current() worker = Worker(**worker_kwargs) - @gen.coroutine - def do_stop(timeout=5, executor_wait=True): + async def do_stop(timeout=5, executor_wait=True): try: - yield worker.close( + await worker.close( report=False, nanny=False, executor_wait=executor_wait, @@ -657,13 +654,12 @@ def watch_stop_q(): t.daemon = True t.start() - @gen.coroutine - def run(): + async def run(): """ Try to start worker and inform parent of outcome. 
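``_wait_until_connected`` cannot block on the multiprocessing queue without freezing the event loop, so it polls with ``get_nowait()`` and sleeps asynchronously between attempts. The generic form of that loop::

    from queue import Empty
    from tornado import gen

    async def poll_queue(q, delay=0.05):
        """Wait for an item from a multiprocessing.Queue without blocking
        the event loop (the same poll-and-sleep loop used above)."""
        while True:
            try:
                return q.get_nowait()
            except Empty:
                await gen.sleep(delay)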
""" try: - yield worker + await worker except Exception as e: logger.exception("Failed to start worker") init_result_q.put({"uid": uid, "exception": e}) @@ -674,7 +670,7 @@ def run(): {"address": worker.address, "dir": worker.local_dir, "uid": uid} ) init_result_q.close() - yield worker.wait_until_closed() + await worker.wait_until_closed() logger.info("Worker closed") try: diff --git a/distributed/process.py b/distributed/process.py index 556edae290e..e716d754db1 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -332,7 +332,7 @@ def _cleanup_dangling(): for proc in list(_dangling): if proc.is_alive(): try: - logger.warning("reaping stray process %s" % (proc,)) + logger.info("reaping stray process %s" % (proc,)) proc.terminate() except OSError: pass diff --git a/distributed/pubsub.py b/distributed/pubsub.py index f9cf1f6f7c3..f40c0b15b31 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -308,10 +308,9 @@ def __init__(self, name, worker=None, client=None): self.loop.add_callback(pubsub.publishers[name].add, self) finalize(self, pubsub.trigger_cleanup) - @gen.coroutine - def _start(self): + async def _start(self): if self.worker: - result = yield self.scheduler.pubsub_add_publisher( + result = await self.scheduler.pubsub_add_publisher( name=self.name, worker=self.worker.address ) pubsub = self.worker.extensions["pubsub"] @@ -388,8 +387,7 @@ def __init__(self, name, worker=None, client=None): finalize(self, pubsub.trigger_cleanup) - @gen.coroutine - def _get(self, timeout=None): + async def _get(self, timeout=None): if timeout is not None: timeout = datetime.timedelta(seconds=timeout) start = datetime.datetime.now() @@ -400,9 +398,9 @@ def _get(self, timeout=None): raise gen.TimeoutError() else: timeout2 = None - yield self.condition.wait(timeout=timeout2) + await self.condition.wait(timeout=timeout2) - raise gen.Return(self.buffer.popleft()) + return self.buffer.popleft() __anext__ = _get diff --git a/distributed/queues.py b/distributed/queues.py index 72f0f9fe52c..12bd15b6318 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -5,8 +5,8 @@ import logging import uuid -from tornado import gen import tornado.queues +from tornado.locks import Event from .client import Future, _get_global_client, Client from .utils import tokey, sync, thread_state @@ -49,6 +49,7 @@ def __init__(self, scheduler): self.scheduler.extensions["queues"] = self def create(self, stream=None, name=None, client=None, maxsize=0): + print("name", name) if name not in self.queues: self.queues[name] = tornado.queues.Queue(maxsize=maxsize) self.client_refcount[name] = 1 @@ -64,13 +65,11 @@ def release(self, stream=None, name=None, client=None): del self.client_refcount[name] futures = self.queues[name]._queue del self.queues[name] - self.scheduler.client_releases_keys( - keys=[d["value"] for d in futures if d["type"] == "Future"], - client="queue-%s" % name, - ) + keys = [d["value"] for d in futures if d["type"] == "Future"] + if keys: + self.scheduler.client_releases_keys(keys=keys, client="queue-%s" % name) - @gen.coroutine - def put( + async def put( self, stream=None, name=None, key=None, data=None, client=None, timeout=None ): if key is not None: @@ -81,7 +80,7 @@ def put( record = {"type": "msgpack", "value": data} if timeout is not None: timeout = datetime.timedelta(seconds=(timeout)) - yield self.queues[name].put(record, timeout=timeout) + await self.queues[name].put(record, timeout=timeout) def future_release(self, name=None, key=None, client=None): 
self.future_refcount[name, key] -= 1 @@ -89,8 +88,7 @@ def future_release(self, name=None, key=None, client=None): self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name) del self.future_refcount[name, key] - @gen.coroutine - def get(self, stream=None, name=None, client=None, timeout=None, batch=False): + async def get(self, stream=None, name=None, client=None, timeout=None, batch=False): def process(record): """ Add task status if known """ if record["type"] == "Future": @@ -111,7 +109,7 @@ def process(record): out = [] if batch is True: while not q.empty(): - record = yield q.get() + record = await q.get() out.append(record) else: if timeout is not None: @@ -121,16 +119,16 @@ def process(record): ) raise NotImplementedError(msg) for i in range(batch): - record = yield q.get() + record = await q.get() out.append(record) out = [process(o) for o in out] - raise gen.Return(out) + return out else: if timeout is not None: timeout = datetime.timedelta(seconds=timeout) - record = yield self.queues[name].get(timeout=timeout) + record = await self.queues[name].get(timeout=timeout) record = process(record) - raise gen.Return(record) + return record def qsize(self, stream=None, name=None, client=None): return self.queues[name].qsize() @@ -168,12 +166,18 @@ class Queue(object): def __init__(self, name=None, client=None, maxsize=0): self.client = client or _get_global_client() self.name = name or "queue-" + uuid.uuid4().hex + self._event_started = Event() if self.client.asynchronous or getattr( thread_state, "on_event_loop_thread", False ): - self._started = self.client.scheduler.queue_create( - name=self.name, maxsize=maxsize - ) + + async def _create_queue(): + await self.client.scheduler.queue_create( + name=self.name, maxsize=maxsize + ) + self._event_started.set() + + self.client.loop.add_callback(_create_queue) else: sync( self.client.loop, @@ -181,24 +185,22 @@ def __init__(self, name=None, client=None, maxsize=0): name=self.name, maxsize=maxsize, ) - self._started = gen.moment + self._event_started.set() def __await__(self): - @gen.coroutine - def _(): - yield self._started - raise gen.Return(self) + async def _(): + await self._event_started.wait() + return self return _().__await__() - @gen.coroutine - def _put(self, value, timeout=None): + async def _put(self, value, timeout=None): if isinstance(value, Future): - yield self.client.scheduler.queue_put( + await self.client.scheduler.queue_put( key=tokey(value.key), timeout=timeout, name=self.name ) else: - yield self.client.scheduler.queue_put( + await self.client.scheduler.queue_put( data=value, timeout=timeout, name=self.name ) @@ -224,9 +226,8 @@ def qsize(self, **kwargs): """ Current number of elements in the queue """ return self.client.sync(self._qsize, **kwargs) - @gen.coroutine - def _get(self, timeout=None, batch=False): - resp = yield self.client.scheduler.queue_get( + async def _get(self, timeout=None, batch=False): + resp = await self.client.scheduler.queue_get( timeout=timeout, name=self.name, batch=batch ) @@ -248,12 +249,11 @@ def process(d): else: result = list(map(process, resp)) - raise gen.Return(result) + return result - @gen.coroutine - def _qsize(self): - result = yield self.client.scheduler.queue_qsize(name=self.name) - raise gen.Return(result) + async def _qsize(self): + result = await self.client.scheduler.queue_qsize(name=self.name) + return result def close(self): if self.client.status == "running": # TODO: can leave zombie futures diff --git a/distributed/recreate_exceptions.py 
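``Queue.__await__`` makes the object itself awaitable: ``q = await Queue("x")`` now just waits for the creation handshake (the ``_event_started`` Event set once the scheduler-side queue exists) and hands back the same instance. The general trick, as a standalone sketch::

    from tornado.locks import Event

    class AwaitableResource(object):
        """``resource = await AwaitableResource()`` returns the instance
        once its asynchronous setup has completed."""

        def __init__(self):
            self._started = Event()
            # real code would kick off asynchronous creation here (for
            # example via loop.add_callback) and set ``self._started``
            # when the handshake finishes

        def __await__(self):
            async def _():
                await self._started.wait()
                return self
            return _().__await__()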
b/distributed/recreate_exceptions.py index 78b0f4de9ba..d02dc4d94f4 100644 --- a/distributed/recreate_exceptions.py +++ b/distributed/recreate_exceptions.py @@ -1,7 +1,6 @@ from __future__ import print_function, division, absolute_import import logging -from tornado import gen from .client import futures_of, wait from .utils import sync, tokey from .utils_comm import pack_data @@ -77,20 +76,19 @@ def __init__(self, client): def scheduler(self): return self.client.scheduler - @gen.coroutine - def _get_futures_error(self, future): + async def _get_futures_error(self, future): # only get errors for futures that errored. futures = [f for f in futures_of(future) if f.status == "error"] if not futures: raise ValueError("No errored futures passed") - out = yield self.scheduler.cause_of_failure(keys=[f.key for f in futures]) + out = await self.scheduler.cause_of_failure(keys=[f.key for f in futures]) deps, task = out["deps"], out["task"] if isinstance(task, dict): function, args, kwargs = _deserialize(**task) - raise gen.Return((function, args, kwargs, deps)) + return (function, args, kwargs, deps) else: function, args, kwargs = _deserialize(task=task) - raise gen.Return((function, args, kwargs, deps)) + return (function, args, kwargs, deps) def get_futures_error(self, future): """ @@ -122,16 +120,15 @@ def get_futures_error(self, future): """ return self.client.sync(self._get_futures_error, future) - @gen.coroutine - def _recreate_error_locally(self, future): - yield wait(future) - out = yield self._get_futures_error(future) + async def _recreate_error_locally(self, future): + await wait(future) + out = await self._get_futures_error(future) function, args, kwargs, deps = out futures = self.client._graph_to_futures({}, deps) - data = yield self.client._gather(futures) + data = await self.client._gather(futures) args = pack_data(args, data) kwargs = pack_data(kwargs, data) - raise gen.Return((function, args, kwargs)) + return (function, args, kwargs) def recreate_error_locally(self, future): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4441b815642..ace4d2483d5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import asyncio from collections import defaultdict, deque, OrderedDict from datetime import timedelta from functools import partial @@ -43,6 +44,7 @@ from . 
import profile from .metrics import time from .node import ServerNode +from .preloading import preload_modules from .proctitle import setproctitle from .security import Security from .utils import ( @@ -842,6 +844,8 @@ def __init__( port=0, protocol=None, dashboard_address=None, + preload=None, + preload_argv=(), **kwargs ): self._setup_logging(logger) @@ -874,6 +878,13 @@ def __init__( self.time_started = time() self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) + if not preload: + preload = dask.config.get("distributed.scheduler.preload") + if not preload_argv: + preload_argv = dask.config.get("distributed.scheduler.preload-argv") + self.preload = preload + self.preload_argv = preload_argv + self.security = security or Security() assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("scheduler") @@ -1175,8 +1186,7 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): else: return ws.host, port - @gen.coroutine - def start(self): + async def start(self): """ Clear out old state and restart all running coroutines """ enable_gc_diagnosis() @@ -1219,14 +1229,14 @@ def del_scheduler_file(): finalize(self, del_scheduler_file) + preload_modules(self.preload, parameter=self, argv=self.preload_argv) + self.start_periodic_callbacks() setproctitle("dask-scheduler [%s]" % (self.address,)) - return self - @gen.coroutine - def close(self, comm=None, fast=False, close_workers=False): + async def close(self, comm=None, fast=False, close_workers=False): """ Send cleanup signal to all coroutines then wait until finished See Also @@ -1234,7 +1244,7 @@ def close(self, comm=None, fast=False, close_workers=False): Scheduler.cleanup """ if self.status.startswith("clos"): - yield self.finished() + await self.finished() return self.status = "closing" @@ -1242,12 +1252,12 @@ def close(self, comm=None, fast=False, close_workers=False): setproctitle("dask-scheduler [closing]") if close_workers: - self.broadcast(msg={"op": "close_gracefully"}, nanny=True) + await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) for worker in self.workers: self.worker_send(worker, {"op": "close"}) for i in range(20): # wait a second for send signals to clear if self.workers: - yield gen.sleep(0.05) + await gen.sleep(0.05) else: break @@ -1269,8 +1279,8 @@ def close(self, comm=None, fast=False, close_workers=False): with ignoring(AttributeError): futures.append(comm.close()) - for future in futures: - yield future + for future in futures: # TODO: do all at once + await future for comm in self.client_comms.values(): comm.abort() @@ -1279,13 +1289,12 @@ def close(self, comm=None, fast=False, close_workers=False): self.status = "closed" self.stop() - yield super(Scheduler, self).close() + await super(Scheduler, self).close() setproctitle("dask-scheduler [closed]") disable_gc_diagnosis() - @gen.coroutine - def close_worker(self, stream=None, worker=None, safe=None): + async def close_worker(self, stream=None, worker=None, safe=None): """ Remove a worker from the cluster This both removes the worker from our local state and also sends a @@ -1305,7 +1314,6 @@ def close_worker(self, stream=None, worker=None, safe=None): # Stimuli # ########### - @gen.coroutine def heartbeat_worker( self, comm=None, @@ -1361,8 +1369,7 @@ def heartbeat_worker( "heartbeat-interval": heartbeat_interval(len(self.workers)), } - @gen.coroutine - def add_worker( + async def add_worker( self, comm=None, address=None, @@ -1409,7 +1416,8 @@ def add_worker( 
"message": "name taken, %s" % name, "time": time(), } - yield comm.write(msg) + if comm: + await comm.write(msg) return if "addresses" not in self.host_info[host]: @@ -1473,15 +1481,16 @@ def add_worker( self.log_event("all", {"action": "add-worker", "worker": address}) logger.info("Register %s", str(address)) - yield comm.write( - { - "status": "OK", - "time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), - "worker-plugins": self.worker_plugins, - } - ) - yield self.handle_worker(comm=comm, worker=address) + if comm: + await comm.write( + { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(self.workers)), + "worker-plugins": self.worker_plugins, + } + ) + await self.handle_worker(comm=comm, worker=address) def update_graph( self, @@ -1930,7 +1939,6 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): if not self.workers: logger.info("Lost all workers") - @gen.coroutine def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events if address not in self.workers and address in self.events: @@ -2190,8 +2198,7 @@ def report(self, msg, ts=None, client=None): if self.status == "running": logger.critical("Tried writing to closed comm: %s", msg) - @gen.coroutine - def add_client(self, comm, client=None): + async def add_client(self, comm, client=None): """ Add client to network We listen to all future messages from this Comm. @@ -2208,7 +2215,7 @@ def add_client(self, comm, client=None): bcomm.send({"op": "stream-start"}) try: - yield self.handle_stream(comm=comm, extra={"client": client}) + await self.handle_stream(comm=comm, extra={"client": client}) finally: self.remove_client(client=client) logger.debug("Finished handling client %s", client) @@ -2217,7 +2224,7 @@ def add_client(self, comm, client=None): self.client_comms[client].send({"op": "stream-closed"}) try: if not shutting_down(): - yield self.client_comms[client].close() + await self.client_comms[client].close() del self.client_comms[client] if self.status == "running": logger.info("Close client connection: %s", client) @@ -2240,7 +2247,6 @@ def remove_client(self, client=None): ) del self.clients[client] - @gen.coroutine def remove_client_from_events(): # If the client isn't registered anymore after the delay, remove from events if client not in self.clients and client in self.events: @@ -2384,8 +2390,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws.processing[ts] = 0 self.check_idle_saturated(ws) - @gen.coroutine - def handle_worker(self, comm=None, worker=None): + async def handle_worker(self, comm=None, worker=None): """ Listen to responses from a single worker @@ -2400,7 +2405,7 @@ def handle_worker(self, comm=None, worker=None): worker_comm.start(comm) logger.info("Starting worker compute stream, %s", worker) try: - yield self.handle_stream(comm=comm, extra={"worker": worker}) + await self.handle_stream(comm=comm, extra={"worker": worker}) finally: if worker in self.stream_comms: worker_comm.abort() @@ -2439,8 +2444,7 @@ def worker_send(self, worker, msg): # Less common interactions # ############################ - @gen.coroutine - def scatter( + async def scatter( self, comm=None, data=None, @@ -2457,7 +2461,7 @@ def scatter( """ start = time() while not self.workers: - yield gen.sleep(0.2) + await gen.sleep(0.2) if time() > start + timeout: raise gen.TimeoutError("No workers found") @@ -2469,7 +2473,7 @@ def scatter( assert isinstance(data, dict) - keys, who_has, 
nbytes = yield scatter_to_workers( + keys, who_has, nbytes = await scatter_to_workers( nthreads, data, rpc=self.rpc, report=False ) @@ -2480,15 +2484,14 @@ def scatter( n = len(nthreads) else: n = broadcast - yield self.replicate(keys=keys, workers=workers, n=n) + await self.replicate(keys=keys, workers=workers, n=n) self.log_event( [client, "all"], {"action": "scatter", "client": client, "count": len(data)} ) - raise gen.Return(keys) + return keys - @gen.coroutine - def gather(self, comm=None, keys=None, serializers=None): + async def gather(self, comm=None, keys=None, serializers=None): """ Collect data in from workers """ keys = list(keys) who_has = {} @@ -2499,7 +2502,7 @@ def gather(self, comm=None, keys=None, serializers=None): else: who_has[key] = [] - data, missing_keys, missing_workers = yield gather_from_workers( + data, missing_keys, missing_workers = await gather_from_workers( who_has, rpc=self.rpc, close=False, serializers=serializers ) if not missing_keys: @@ -2537,7 +2540,7 @@ def gather(self, comm=None, keys=None, serializers=None): self.transitions({key: "released"}) self.log_event("all", {"action": "gather", "count": len(keys)}) - raise gen.Return(result) + return result def clear_task_state(self): # XXX what about nested state such as ClientState.wants_what @@ -2546,8 +2549,7 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() - @gen.coroutine - def restart(self, client=None, timeout=3): + async def restart(self, client=None, timeout=3): """ Restart all workers. Reset local state. """ with log_errors(): @@ -2596,7 +2598,7 @@ def restart(self, client=None, timeout=3): for nanny in nannies ] ) - resps = yield gen.with_timeout(timedelta(seconds=timeout), resps) + resps = await gen.with_timeout(timedelta(seconds=timeout), resps) if not all(resp == "OK" for resp in resps): logger.error( "Not all workers responded positively: %s", resps, exc_info=True @@ -2610,17 +2612,16 @@ def restart(self, client=None, timeout=3): for nanny in nannies: nanny.close_rpc() - self.start() + await self.start() self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() while time() < start + 10 and len(self.workers) < n_workers: - yield gen.sleep(0.01) + await gen.sleep(0.01) self.report({"op": "restart"}) - @gen.coroutine - def broadcast( + async def broadcast( self, comm=None, msg=None, @@ -2646,28 +2647,26 @@ def broadcast( else: addresses = workers - @gen.coroutine - def send_message(addr): - comm = yield connect( + async def send_message(addr): + comm = await connect( addr, deserialize=self.deserialize, connection_args=self.connection_args ) comm.name = "Scheduler Broadcast" - resp = yield send_recv(comm, close=True, serializers=serializers, **msg) - raise gen.Return(resp) + resp = await send_recv(comm, close=True, serializers=serializers, **msg) + return resp - results = yield All( + results = await All( [send_message(address) for address in addresses if address is not None] ) - raise Return(dict(zip(workers, results))) + return dict(zip(workers, results)) - @gen.coroutine - def proxy(self, comm=None, msg=None, worker=None, serializers=None): + async def proxy(self, comm=None, msg=None, worker=None, serializers=None): """ Proxy a communication through the scheduler to some other worker """ - d = yield self.broadcast( + d = await self.broadcast( comm=comm, msg=msg, workers=[worker], serializers=serializers ) - raise gen.Return(d[worker]) + return d[worker] @gen.coroutine def rebalance(self, comm=None, keys=None, 
workers=None): @@ -2686,7 +2685,7 @@ def rebalance(self, comm=None, keys=None, workers=None): tasks = {self.tasks[k] for k in keys} missing_data = [ts.key for ts in tasks if not ts.who_has] if missing_data: - raise Return({"status": "missing-data", "keys": missing_data}) + return {"status": "missing-data", "keys": missing_data} else: tasks = set(self.tasks.values()) @@ -3016,8 +3015,7 @@ def key(group): return result - @gen.coroutine - def retire_workers( + async def retire_workers( self, comm=None, workers=None, remove=True, close_workers=False, **kwargs ): """ Gracefully retire workers from cluster @@ -3053,7 +3051,7 @@ def retire_workers( try: workers = self.workers_to_close(**kwargs) if workers: - workers = yield self.retire_workers( + workers = await self.retire_workers( workers=workers, remove=remove, close_workers=close_workers, @@ -3073,18 +3071,20 @@ def retire_workers( other_workers = set(self.workers.values()) - workers if keys: if other_workers: - yield self.replicate( + await self.replicate( keys=keys, workers=[ws.address for ws in other_workers], n=1, delete=False, ) else: - raise gen.Return([]) + return [] worker_keys = {ws.address: ws.identity() for ws in workers} if close_workers and worker_keys: - yield [self.close_worker(worker=w, safe=True) for w in worker_keys] + await asyncio.gather( + *[self.close_worker(worker=w, safe=True) for w in worker_keys] + ) if remove: for w in worker_keys: self.remove_worker(address=w, safe=True) @@ -3099,7 +3099,7 @@ def retire_workers( ) self.log_event(list(worker_keys), {"action": "retired"}) - raise gen.Return(worker_keys) + return worker_keys def add_keys(self, comm=None, worker=None, keys=()): """ @@ -3188,8 +3188,7 @@ def report_on_key(self, key=None, ts=None, client=None): client=client, ) - @gen.coroutine - def feed( + async def feed( self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs ): """ @@ -3217,16 +3216,16 @@ def feed( if teardown: teardown = pickle.loads(teardown) state = setup(self) if setup else None - if isinstance(state, gen.Future): - state = yield state + if hasattr(state, "__await__"): + state = await state try: while self.status == "running": if state is None: response = function(self) else: response = function(self, state) - yield comm.write(response) - yield gen.sleep(interval) + await comm.write(response) + await gen.sleep(interval) except (EnvironmentError, CommClosedError): pass finally: @@ -3391,15 +3390,14 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None): ts = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] return ts.collect(start=start, stop=stop, count=count) - @gen.coroutine - def register_worker_plugin(self, comm, plugin, name=None): + async def register_worker_plugin(self, comm, plugin, name=None): """ Registers a setup function, and call it on every worker """ self.worker_plugins.append(plugin) - responses = yield self.broadcast( + responses = await self.broadcast( msg=dict(op="plugin-add", plugin=plugin, name=name) ) - raise gen.Return(responses) + return responses ##################### # State Transitions # @@ -4603,12 +4601,11 @@ def get_profile_metadata( raise gen.Return({"counts": counts, "keys": keys}) - @gen.coroutine - def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False): - results = yield self.broadcast( + async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False): + results = await self.broadcast( msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny ) - raise gen.Return(results) + 
return results ########### # Cleanup # diff --git a/distributed/tests/py3_test_pubsub.py b/distributed/tests/py3_test_pubsub.py index 0cedbb3bd31..294ecfb90c8 100644 --- a/distributed/tests/py3_test_pubsub.py +++ b/distributed/tests/py3_test_pubsub.py @@ -1,6 +1,7 @@ from distributed import Pub, Sub from distributed.utils_test import gen_cluster +import asyncio import toolz from tornado import gen import pytest @@ -22,7 +23,7 @@ def f(_): sub = Sub("a") return list(toolz.take(5, sub)) - c.run(publish, workers=[a.address]) + asyncio.ensure_future(c.run(publish, workers=[a.address])) tasks = [c.submit(f, i) for i in range(4)] results = yield c.gather(tasks) diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index a584025ad03..aa53b9b993a 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -238,7 +238,7 @@ def test_as_completed_with_results_no_raise_async(c, s, a, b): z = c.submit(inc, 1) ac = as_completed([x, y, z], with_results=True, raise_errors=False) - y.cancel() + c.loop.add_callback(y.cancel) first = yield ac.__anext__() second = yield ac.__anext__() third = yield ac.__anext__() diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index 23d8e677774..af281aff8c3 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -134,7 +134,7 @@ def test_close_closed(): b.start(comm) b.send(123) - comm.close() # external closing + yield comm.close() # external closing yield b.close() assert "closed" in repr(b) @@ -185,7 +185,7 @@ def recv(): yield All([send(), recv()]) assert L == list(range(0, 10000, 1)) - comm.close() + yield comm.close() @gen.coroutine @@ -222,7 +222,7 @@ def run_traffic_jam(nsends, nbytes): assert results == list(range(nsends)) - comm.close() # external closing + yield comm.close() # external closing yield b.close() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 1f9678583b0..50415971e20 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1,11 +1,11 @@ from __future__ import print_function, division, absolute_import -from operator import add - +import asyncio from collections import deque from concurrent.futures import CancelledError import gc import logging +from operator import add import os import pickle import random @@ -3311,11 +3311,11 @@ def test_get_foo_lost_keys(c, s, u, v, w): @gen_cluster( client=True, Worker=Nanny, - check_new_threads=False, worker_kwargs={"death_timeout": "500ms"}, + clean_kwargs={"processes": False, "threads": False}, ) def test_bad_tasks_fail(c, s, a, b): - f = c.submit(sys.exit, 1) + f = c.submit(sys.exit, 0) with pytest.raises(KilledWorker) as info: yield f @@ -3486,7 +3486,7 @@ def test_scatter_raises_if_no_workers(c, s): @pytest.mark.slow def test_reconnect(loop): w = Worker("127.0.0.1", 9393, loop=loop) - w.start() + loop.add_callback(w.start) scheduler_cli = [ "dask-scheduler", @@ -4031,7 +4031,7 @@ def f(x, y=0): assert len(b.data) > 2 * len(a.data) -@gen_cluster(client=True, check_new_threads=False) +@gen_cluster(client=True, clean_kwargs={"threads": False}) def test_add_done_callback(c, s, a, b): S = set() @@ -4616,9 +4616,9 @@ def f(_): from concurrent.futures import ThreadPoolExecutor - e = ThreadPoolExecutor(30) - results = list(e.map(f, range(30))) - assert results and all(results) + with ThreadPoolExecutor(30) as e: + results = list(e.map(f, range(30))) + assert results and all(results) @pytest.mark.slow @@ 
-5343,13 +5343,13 @@ def test_de_serialization_none(s, a, b): @gen_cluster() def test_client_repr_closed(s, a, b): - c = yield Client(s.address, asynchronous=True) + c = yield Client(s.address, asynchronous=True, dashboard_address=None) yield c.close() c._repr_html_() def test_client_repr_closed_sync(loop): - with Client(loop=loop, processes=False) as c: + with Client(loop=loop, processes=False, dashboard_address=None) as c: c.close() c._repr_html_() @@ -5498,7 +5498,7 @@ def f(x): assert result == 101 -@gen_cluster(client=True, check_new_threads=False) +@gen_cluster(client=True, clean_kwargs={"threads": False}) def test_profile_bokeh(c, s, a, b): pytest.importorskip("bokeh.plotting") from bokeh.model import Model @@ -5578,7 +5578,7 @@ def test_instances(c, s, a, b): @gen_cluster(client=True) def test_wait_for_workers(c, s, a, b): - future = c.wait_for_workers(n_workers=3) + future = asyncio.ensure_future(c.wait_for_workers(n_workers=3)) yield gen.sleep(0.22) # 2 chances assert not future.done() diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index 226feec4faf..9fa9a73787a 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -13,7 +13,7 @@ @gen_cluster(client=True, nthreads=[("127.0.0.1", 8)] * 2) def test_lock(c, s, a, b): - c.set_metadata("locked", False) + yield c.set_metadata("locked", False) def f(x): client = get_client() diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 7722476a2c5..579af8dbc2c 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -12,6 +12,7 @@ import pytest from toolz import valmap, first from tornado import gen +from tornado.ioloop import IOLoop import dask from distributed import Nanny, rpc, Scheduler, Worker @@ -19,7 +20,13 @@ from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.utils import ignoring, tmpfile -from distributed.utils_test import gen_cluster, gen_test, inc, captured_logger +from distributed.utils_test import ( # noqa: F401 + gen_cluster, + gen_test, + inc, + captured_logger, + cleanup, +) @gen_cluster(nthreads=[]) @@ -82,7 +89,7 @@ def test_nanny_process_failure(c, s): pid = n.pid assert pid is not None with ignoring(CommClosedError): - yield c._run(os._exit, 0, workers=[n.worker_address]) + yield c.run(os._exit, 0, workers=[n.worker_address]) start = time() while n.pid == pid: # wait while process dies and comes back @@ -90,6 +97,7 @@ def test_nanny_process_failure(c, s): assert time() - start < 5 start = time() + yield gen.sleep(1) while not n.is_alive(): # wait while process comes back yield gen.sleep(0.01) assert time() - start < 5 @@ -259,7 +267,7 @@ def test_nanny_timeout(c, s, a): Worker=Nanny, worker_kwargs={"memory_limit": 1e8}, timeout=20, - check_new_threads=False, + clean_kwargs={"threads": False}, ) def test_nanny_terminate(c, s, a): from time import sleep @@ -319,7 +327,7 @@ def test_scheduler_address_config(c, s): def test_wait_for_scheduler(): with captured_logger("distributed") as log: w = Nanny("127.0.0.1:44737") - w.start() + IOLoop.current().add_callback(w.start) yield gen.sleep(6) yield w.close() @@ -378,3 +386,15 @@ def pool_worker(world_size): p.map(_noop, range(world_size)) yield c.submit(pool_worker, 4) + + +@pytest.mark.asyncio +async def test_nanny_closes_cleanly(cleanup): + async with Scheduler() as s: + n = await Nanny(s.address) + assert n.process.pid + proc = n.process.process + await n.close() + assert not n.process + assert not 
proc.is_alive() + assert proc.exitcode == 0 diff --git a/distributed/tests/test_priorities.py b/distributed/tests/test_priorities.py index 6258c4e16a7..ae96517f1ac 100644 --- a/distributed/tests/test_priorities.py +++ b/distributed/tests/test_priorities.py @@ -6,32 +6,34 @@ from dask import delayed, persist from distributed.utils_test import gen_cluster, inc, slowinc, slowdec -from distributed import wait +from distributed import wait, Worker from distributed.utils import tokey -@gen_cluster(client=True) -def test_submit(c, s, a, b): +@gen_cluster(client=True, nthreads=[]) +async def test_submit(c, s): low = c.submit(inc, 1, priority=-1) futures = c.map(slowinc, range(10), delay=0.1) high = c.submit(inc, 2, priority=1) - yield wait(high) - assert all(s.processing.values()) - assert s.tasks[low.key].state == "processing" + async with Worker(s.address, nthreads=1): + await wait(high) + assert all(s.processing.values()) + assert s.tasks[low.key].state == "processing" -@gen_cluster(client=True) -def test_map(c, s, a, b): +@gen_cluster(client=True, nthreads=[]) +async def test_map(c, s): low = c.map(inc, [1, 2, 3], priority=-1) futures = c.map(slowinc, range(10), delay=0.1) high = c.map(inc, [4, 5, 6], priority=1) - yield wait(high) - assert all(s.processing.values()) - assert s.tasks[low[0].key].state == "processing" + async with Worker(s.address, nthreads=1): + await wait(high) + assert all(s.processing.values()) + assert s.tasks[low[0].key].state == "processing" -@gen_cluster(client=True) -def test_compute(c, s, a, b): +@gen_cluster(client=True, nthreads=[]) +async def test_compute(c, s): da = pytest.importorskip("dask.array") x = da.random.random((10, 10), chunks=(5, 5)) y = da.random.random((10, 10), chunks=(5, 5)) @@ -39,13 +41,14 @@ def test_compute(c, s, a, b): low = c.compute(x, priority=-1) futures = c.map(slowinc, range(10), delay=0.1) high = c.compute(y, priority=1) - yield wait(high) - assert all(s.processing.values()) - assert s.tasks[tokey(low.key)].state in ("processing", "waiting") + async with Worker(s.address, nthreads=1): + await wait(high) + assert all(s.processing.values()) + assert s.tasks[tokey(low.key)].state in ("processing", "waiting") -@gen_cluster(client=True) -def test_persist(c, s, a, b): +@gen_cluster(client=True, nthreads=[]) +async def test_persist(c, s): da = pytest.importorskip("dask.array") x = da.random.random((10, 10), chunks=(5, 5)) y = da.random.random((10, 10), chunks=(5, 5)) @@ -53,12 +56,13 @@ def test_persist(c, s, a, b): low = x.persist(priority=-1) futures = c.map(slowinc, range(10), delay=0.1) high = y.persist(priority=1) - yield wait(high) - assert all(s.processing.values()) - assert all( - s.tasks[tokey(k)].state in ("processing", "waiting") - for k in flatten(low.__dask_keys__()) - ) + async with Worker(s.address, nthreads=1): + await wait(high) + assert all(s.processing.values()) + assert all( + s.tasks[tokey(k)].state in ("processing", "waiting") + for k in flatten(low.__dask_keys__()) + ) @gen_cluster(client=True) diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 2e7702171ad..a28d1e29082 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -223,7 +223,7 @@ def test_Future_knows_status_immediately(c, s, a, b): @gen_cluster(client=True) def test_erred_future(c, s, a, b): future = c.submit(div, 1, 0) - q = Queue() + q = yield Queue() yield q.put(future) yield gen.sleep(0.1) future2 = yield q.get() @@ -236,10 +236,7 @@ def test_erred_future(c, s, a, b): 
@gen_cluster(client=True) def test_close(c, s, a, b): - q = Queue() - - while q.name not in s.extensions["queues"].queues: - yield gen.sleep(0.01) + q = yield Queue() q.close() q.close() @@ -250,7 +247,7 @@ def test_close(c, s, a, b): @gen_cluster(client=True) def test_timeout(c, s, a, b): - q = Queue("v", maxsize=1) + q = yield Queue("v", maxsize=1) start = time() with pytest.raises(gen.TimeoutError): diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 1985d44e2a3..b3f5db36a76 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -380,8 +380,8 @@ def test_full_collections(c, s, a, b): def test_collections_get(client, optimize_graph, s, a, b): da = pytest.importorskip("dask.array") - def f(dask_worker): - dask_worker.set_resources(**{"A": 1}) + async def f(dask_worker): + await dask_worker.set_resources(**{"A": 1}) client.run(f, workers=[a["address"]]) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 20de5e7b7fd..4f1b2808102 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -110,7 +110,7 @@ def test_decide_worker_with_many_independent_leaves(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) def test_decide_worker_with_restrictions(client, s, a, b, c): x = client.submit(inc, 1, workers=[a.address, b.address]) - yield wait(x) + yield x assert x.key in a.data or x.key in b.data @@ -621,14 +621,6 @@ def test_update_graph_culls(s, a, b): assert "z" not in s.dependencies -@gen_cluster(nthreads=[]) -def test_add_worker_is_idempotent(s): - s.add_worker(address=alice, nthreads=1, resolve_address=False) - nthreads = dict(s.nthreads) - s.add_worker(address=alice, resolve_address=False) - assert s.nthreads == s.nthreads - - def test_io_loop(loop): s = Scheduler(loop=loop, validate=True) assert s.io_loop is loop @@ -956,7 +948,7 @@ def test_worker_breaks_and_returns(c, s, a): yield wait(future) - a.batched_stream.comm.close() + yield a.batched_stream.comm.close() yield gen.sleep(0.1) start = time() @@ -1146,11 +1138,13 @@ def test_scheduler_file(): @pytest.mark.xfail(reason="") @gen_cluster(client=True, nthreads=[]) -def test_non_existent_worker(c, s): +async def test_non_existent_worker(c, s): with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - s.add_worker(address="127.0.0.1:5738", nthreads=2, nbytes={}, host_info={}) + await s.add_worker( + address="127.0.0.1:5738", nthreads=2, nbytes={}, host_info={} + ) futures = c.map(inc, range(10)) - yield gen.sleep(0.300) + await gen.sleep(0.300) assert not s.workers assert all(ts.state == "no-worker" for ts in s.tasks.values()) @@ -1317,19 +1311,19 @@ def test_retries(c, s, a, b): @pytest.mark.xfail(reason="second worker also errant for some reason") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3, timeout=5) -def test_mising_data_errant_worker(c, s, w1, w2, w3): +async def test_mising_data_errant_worker(c, s, w1, w2, w3): with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): np = pytest.importorskip("numpy") x = c.submit(np.random.random, 10000000, workers=w1.address) - yield wait(x) - yield c.replicate(x, workers=[w1.address, w2.address]) + await wait(x) + await c.replicate(x, workers=[w1.address, w2.address]) y = c.submit(len, x, workers=w3.address) while not w3.tasks: - yield gen.sleep(0.001) - w1.close() - yield wait(y) + await gen.sleep(0.001) + await w1.close() + await wait(y) @gen_cluster(client=True) diff --git 
a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index d233fc28388..45a110bbecf 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -234,7 +234,7 @@ def test_dont_steal_host_restrictions(c, s, a, b): yield future futures = c.map(slowinc, range(100), delay=0.1, workers="127.0.0.1") - while len(a.task_state) < 10: + while len(a.task_state) + len(b.task_state) < 100: yield gen.sleep(0.01) assert len(a.task_state) == 100 assert len(b.task_state) == 0 @@ -254,7 +254,7 @@ def test_dont_steal_resource_restrictions(c, s, a, b): yield future futures = c.map(slowinc, range(100), delay=0.1, resources={"A": 1}) - while len(a.task_state) < 10: + while len(a.task_state) + len(b.task_state) < 100: yield gen.sleep(0.01) assert len(a.task_state) == 100 assert len(b.task_state) == 0 diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 6a5dbe72736..b5f51359239 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -14,7 +14,7 @@ from distributed import Client, wait, Nanny from distributed.config import config from distributed.metrics import time -from distributed.utils import All +from distributed.utils import All, ignoring from distributed.utils_test import ( gen_cluster, cluster, @@ -126,7 +126,7 @@ def test_stress_scatter_death(c, s, *workers): s.allowed_failures = 1000 np = pytest.importorskip("numpy") L = yield c.scatter([np.random.random(10000) for i in range(len(workers))]) - yield c._replicate(L, n=2) + yield c.replicate(L, n=2) adds = [ delayed(slowadd, pure=True)( @@ -166,27 +166,10 @@ def test_stress_scatter_death(c, s, *workers): yield w.close() alive.remove(w) - try: - yield gen.with_timeout(timedelta(seconds=25), c._gather(futures)) - except gen.TimeoutError: - ws = {w.address: w for w in workers if w.status != "closed"} - print(s.processing) - print(ws) - print(futures) - try: - worker = [w for w in ws.values() if w.waiting_for_data][0] - except Exception: - pass - if config.get("log-on-err"): - import pdb - - pdb.set_trace() - else: - raise - except CancelledError: - pass - finally: - futures = None + with ignoring(CancelledError): + yield c.gather(futures) + + futures = None def vsum(*args): diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 7d097e28112..43c8c667bf4 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -19,8 +19,8 @@ def test_Queue(c, s, a, b): assert s.address.startswith("tls://") - x = Queue("x") - y = Queue("y") + x = yield Queue("x") + y = yield Queue("y") size = yield x.qsize() assert size == 0 diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index eac2ec71529..05b8066c707 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -13,6 +13,7 @@ from distributed.core import rpc from distributed.metrics import time from distributed.utils_test import ( # noqa: F401 + cleanup, cluster, gen_cluster, inc, @@ -175,10 +176,10 @@ def test_tls_cluster(tls_client): assert tls_client.security -def test_tls_scheduler(security, loop): - s = yield Scheduler(security=security, loop=loop, host="localhost") - assert s.address.startswith("tls") - yield s.close() +@pytest.mark.asyncio +async def test_tls_scheduler(security, cleanup): + async with Scheduler(security=security, host="localhost") as s: + assert s.address.startswith("tls") if sys.version_info >= (3, 5): diff --git 
a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2e6af6e0bdc..1c1e70fcba7 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -501,7 +501,8 @@ def test_memory_limit_auto(): assert isinstance(a.memory_limit, Number) assert isinstance(b.memory_limit, Number) - assert a.memory_limit < b.memory_limit + if multiprocessing.cpu_count() > 1: + assert a.memory_limit < b.memory_limit assert c.memory_limit == d.memory_limit @@ -1050,7 +1051,14 @@ def test_statistical_profiling(c, s, a, b): @pytest.mark.slow @nodebug -@gen_cluster(client=True, timeout=20) +@gen_cluster( + client=True, + timeout=30, + config={ + "distributed.worker.profile.interval": "1ms", + "distributed.worker.profile.cycle": "100ms", + }, +) def test_statistical_profiling_2(c, s, a, b): da = pytest.importorskip("dask.array") while True: diff --git a/distributed/tests/test_worker_plugins.py b/distributed/tests/test_worker_plugins.py index bbba39943fb..02db9419d4e 100644 --- a/distributed/tests/test_worker_plugins.py +++ b/distributed/tests/test_worker_plugins.py @@ -27,7 +27,7 @@ def test_create_with_client(c, s): assert worker._my_plugin_status == "setup" assert worker._my_plugin_data == 123 - yield worker._close() + yield worker.close() assert worker._my_plugin_status == "teardown" diff --git a/distributed/utils.py b/distributed/utils.py index 2f4657439cf..6e8769979fb 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import asyncio import atexit from collections import deque from contextlib import contextmanager @@ -200,8 +201,7 @@ def ignore_exceptions(coroutines, *exceptions): raise gen.Return(results) -@gen.coroutine -def All(args, quiet_exceptions=()): +async def All(args, quiet_exceptions=()): """ Wait on many tasks at the same time Err once any of the tasks err. @@ -214,11 +214,11 @@ def All(args, quiet_exceptions=()): quiet_exceptions: tuple, Exception Exception types to avoid logging if they fail """ - tasks = gen.WaitIterator(*args) + tasks = gen.WaitIterator(*map(asyncio.ensure_future, args)) results = [None for _ in args] while not tasks.done(): try: - result = yield tasks.next() + result = await tasks.next() except Exception: @gen.coroutine @@ -237,13 +237,11 @@ def quiet(): quiet() raise - results[tasks.current_index] = result - raise gen.Return(results) + return results -@gen.coroutine -def Any(args, quiet_exceptions=()): +async def Any(args, quiet_exceptions=()): """ Wait on many tasks at the same time and return when any is finished Err once any of the tasks err. 
@@ -254,11 +252,11 @@ def Any(args, quiet_exceptions=()): quiet_exceptions: tuple, Exception Exception types to avoid logging if they fail """ - tasks = gen.WaitIterator(*args) + tasks = gen.WaitIterator(*map(asyncio.ensure_future, args)) results = [None for _ in args] while not tasks.done(): try: - result = yield tasks.next() + result = await tasks.next() except Exception: @gen.coroutine @@ -280,7 +278,7 @@ def quiet(): results[tasks.current_index] = result break - raise gen.Return(results) + return results def sync(loop, func, *args, callback_timeout=None, **kwargs): @@ -1397,7 +1395,6 @@ def reset_logger_locks(): ) if not jupyter_event_loop_initialized: - import asyncio import tornado.platform.asyncio asyncio.set_event_loop_policy( @@ -1489,3 +1486,7 @@ def format_dashboard_link(host, port): else: scheme = "http" return template.format(scheme=scheme, host=host, port=port, **os.environ) + + +def is_coroutine_function(f): + return asyncio.iscoroutinefunction(f) or gen.is_coroutine_function(f) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index e6f5235afe0..af393cbd79e 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,12 +1,10 @@ from __future__ import print_function, division, absolute_import +import asyncio from collections import defaultdict from itertools import cycle import random -from tornado import gen -from tornado.gen import Return - from dask.optimization import SubgraphCallable from toolz import merge, concat, groupby, drop @@ -14,8 +12,7 @@ from .utils import All, tokey -@gen.coroutine -def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): +async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): """ Gather data directly from peers Parameters @@ -59,20 +56,22 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): rpcs = {addr: rpc(addr) for addr in d} try: coroutines = { - address: get_data_from_worker( - rpc, - keys, - address, - who=who, - serializers=serializers, - max_connections=False, + address: asyncio.ensure_future( + get_data_from_worker( + rpc, + keys, + address, + who=who, + serializers=serializers, + max_connections=False, + ) ) for address, keys in d.items() } response = {} for worker, c in coroutines.items(): try: - r = yield c + r = await c except EnvironmentError: missing_workers.add(worker) else: @@ -85,7 +84,7 @@ def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): results.update(response) bad_keys = {k: list(original_who_has[k]) for k in all_bad_keys} - raise Return((results, bad_keys, list(missing_workers))) + return (results, bad_keys, list(missing_workers)) class WrappedKey(object): @@ -109,8 +108,7 @@ def __repr__(self): _round_robin_counter = [0] -@gen.coroutine -def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None): +async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None): """ Scatter data directly to workers This distributes data in a round-robin fashion to a set of workers based on @@ -134,7 +132,7 @@ def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None): rpcs = {addr: rpc(addr) for addr in d} try: - out = yield All( + out = await All( [ rpcs[address].update_data( data=v, report=report, serializers=serializers @@ -150,7 +148,7 @@ def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None): who_has = {k: [w for w, _, _ in v] for k, v in groupby(1, L).items()} - raise Return((names, who_has, nbytes)) + 
return (names, who_has, nbytes) collection_types = (tuple, list, set, frozenset) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 5f3dff548cf..d2932857c40 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import asyncio import collections from contextlib import contextmanager import copy @@ -110,10 +111,9 @@ def invalid_python_script(tmpdir_factory): return local_file -@gen.coroutine -def cleanup_global_workers(): +async def cleanup_global_workers(): for worker in Worker._instances: - worker.close(report=False, executor_wait=False) + await worker.close(report=False, executor_wait=False) @pytest.fixture @@ -399,10 +399,9 @@ def apply(func, *args, **kwargs): return apply, list(map(varying, itemslists)) -@gen.coroutine -def geninc(x, delay=0.02): - yield gen.sleep(delay) - raise gen.Return(x + 1) +async def geninc(x, delay=0.02): + await gen.sleep(delay) + return x + 1 def compile_snippet(code, dedent=True): @@ -429,8 +428,7 @@ async def asyncinc(x, delay=0.02): _readone_queues = {} -@gen.coroutine -def readone(comm): +async def readone(comm): """ Read one message at a time from a comm that reads lists of messages. @@ -440,11 +438,10 @@ def readone(comm): except KeyError: q = _readone_queues[comm] = queues.Queue() - @gen.coroutine - def background_read(): + async def background_read(): while True: try: - messages = yield comm.read() + messages = await comm.read() except CommClosedError: break for msg in messages: @@ -454,11 +451,11 @@ def background_read(): background_read() - msg = yield q.get() + msg = await q.get() if msg is None: raise CommClosedError else: - raise gen.Return(msg) + return msg def run_scheduler(q, nputs, port=0, **kwargs): @@ -467,13 +464,17 @@ def run_scheduler(q, nputs, port=0, **kwargs): # On Python 2.7 and Unix, fork() is used to spawn child processes, # so avoid inheriting the parent's IO loop. 
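A pattern repeated throughout this patch: a @gen.coroutine that yielded a bare list of futures becomes an async def that wraps each awaitable in asyncio.ensure_future and awaits asyncio.gather (or feeds the tasks to gen.WaitIterator, as in All/Any and gather_from_workers above). A self-contained sketch of that conversion; fetch and fetch_all are hypothetical helpers, not part of distributed:

    import asyncio

    async def fetch(i):
        # stand-in for a coroutine such as get_data_from_worker()
        await asyncio.sleep(0.01)
        return i * 2

    async def fetch_all(n):
        # before (tornado): results = yield [fetch(i) for i in range(n)]
        # after: schedule each coroutine as a Task and gather them explicitly
        return await asyncio.gather(*(asyncio.ensure_future(fetch(i)) for i in range(n)))

    if __name__ == "__main__":
        print(asyncio.get_event_loop().run_until_complete(fetch_all(3)))  # [0, 2, 4]

asyncio.gather preserves the input order, matching the behaviour of the yield-a-list form it replaces.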
with pristine_loop() as loop: - scheduler = Scheduler(validate=True, host="127.0.0.1", port=port, **kwargs) - done = scheduler.start() - for i in range(nputs): - q.put(scheduler.address) + async def _(): + scheduler = await Scheduler( + validate=True, host="127.0.0.1", port=port, **kwargs + ) + for i in range(nputs): + q.put(scheduler.address) + await scheduler.finished() + try: - loop.start() + loop.run_sync(_) finally: loop.close(all_fds=True) @@ -485,16 +486,14 @@ def run_worker(q, scheduler_q, **kwargs): with log_errors(): with pristine_loop() as loop: scheduler_addr = scheduler_q.get() - worker = Worker(scheduler_addr, validate=True, **kwargs) - loop.run_sync(worker.start) - q.put(worker.address) - try: - @gen.coroutine - def wait_until_closed(): - yield worker._closed.wait() + async def _(): + worker = await Worker(scheduler_addr, validate=True, **kwargs) + q.put(worker.address) + await worker.finished() - loop.run_sync(wait_until_closed) + try: + loop.run_sync(_) finally: loop.close(all_fds=True) @@ -503,13 +502,15 @@ def run_nanny(q, scheduler_q, **kwargs): with log_errors(): with pristine_loop() as loop: scheduler_addr = scheduler_q.get() - worker = Nanny(scheduler_addr, validate=True, **kwargs) - loop.run_sync(worker.start) - q.put(worker.address) + + async def _(): + worker = await Nanny(scheduler_addr, validate=True, **kwargs) + q.put(worker.address) + await worker.finished() + try: - loop.start() + loop.run_sync(_) finally: - loop.run_sync(worker.close) loop.close(all_fds=True) @@ -533,9 +534,8 @@ def fail(): "some RPCs left active by test: %s" % (set(rpc.active) - active_before) ) - @gen.coroutine - def wait(): - yield async_wait_for( + async def wait(): + await async_wait_for( lambda: len(set(rpc.active) - active_before) == 0, timeout=active_rpc_timeout, fail_func=fail, @@ -738,23 +738,20 @@ def cluster( assert time() < start + 5, ("Workers still around after five seconds", text) -@gen.coroutine -def disconnect(addr, timeout=3, rpc_kwargs=None): +async def disconnect(addr, timeout=3, rpc_kwargs=None): rpc_kwargs = rpc_kwargs or {} - @gen.coroutine - def do_disconnect(): + async def do_disconnect(): with ignoring(EnvironmentError, CommClosedError): with rpc(addr, **rpc_kwargs) as w: - yield w.terminate(close=True) + await w.terminate(close=True) with ignoring(TimeoutError): - yield gen.with_timeout(timedelta(seconds=timeout), do_disconnect()) + await gen.with_timeout(timedelta(seconds=timeout), do_disconnect()) -@gen.coroutine -def disconnect_all(addresses, timeout=3, rpc_kwargs=None): - yield [disconnect(addr, timeout, rpc_kwargs) for addr in addresses] +async def disconnect_all(addresses, timeout=3, rpc_kwargs=None): + await asyncio.gather(*[disconnect(addr, timeout, rpc_kwargs) for addr in addresses]) def gen_test(timeout=10): @@ -783,8 +780,7 @@ def test_func(): from .worker import Worker -@gen.coroutine -def start_cluster( +async def start_cluster( nthreads, scheduler_addr, loop, @@ -793,7 +789,7 @@ def start_cluster( scheduler_kwargs={}, worker_kwargs={}, ): - s = Scheduler( + s = await Scheduler( loop=loop, validate=True, security=security, @@ -801,7 +797,6 @@ def start_cluster( host=scheduler_addr, **scheduler_kwargs ) - done = s.start() workers = [ Worker( s.address, @@ -818,31 +813,29 @@ def start_cluster( # for w in workers: # w.rpc = workers[0].rpc - yield workers + await asyncio.gather(*workers) start = time() while len(s.workers) < len(nthreads) or any( comm.comm is None for comm in s.stream_comms.values() ): - yield gen.sleep(0.01) + await gen.sleep(0.01) if 
time() - start > 5: - yield [w.close(timeout=1) for w in workers] - yield s.close(fast=True) + await asyncio.gather(*[w.close(timeout=1) for w in workers]) + await s.close(fast=True) raise Exception("Cluster creation timeout") - raise gen.Return((s, workers)) + return s, workers -@gen.coroutine -def end_cluster(s, workers): +async def end_cluster(s, workers): logger.debug("Closing out test cluster") - @gen.coroutine - def end_worker(w): + async def end_worker(w): with ignoring(TimeoutError, CommClosedError, EnvironmentError): - yield w.close(report=False) + await w.close(report=False) - yield [end_worker(w) for w in workers] - yield s.close() # wait until scheduler stops completely + await asyncio.gather(*[end_worker(w) for w in workers]) + await s.close() # wait until scheduler stops completely s.stop() @@ -859,7 +852,7 @@ def gen_cluster( client_kwargs={}, active_rpc_timeout=1, config={}, - check_new_threads=True, + clean_kwargs={}, ): from distributed import Client @@ -874,7 +867,7 @@ def test_foo(scheduler, worker1, worker2): end """ if ncores is not None: - warnings.warn("ncores= has moved to nthreads=") + warnings.warn("ncores= has moved to nthreads=", stacklevel=2) nthreads = ncores worker_kwargs = merge( @@ -888,15 +881,14 @@ def _(func): def test_func(): result = None workers = [] - with clean(threads=check_new_threads, timeout=active_rpc_timeout) as loop: + with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop: - @gen.coroutine - def coro(): + async def coro(): with dask.config.set(config): s = False for i in range(5): try: - s, ws = yield start_cluster( + s, ws = await start_cluster( nthreads, scheduler, loop, @@ -917,7 +909,7 @@ def coro(): if s is False: raise Exception("Could not start cluster") if client: - c = yield Client( + c = await Client( s.address, loop=loop, security=security, @@ -931,36 +923,36 @@ def coro(): future = gen.with_timeout( timedelta(seconds=timeout), future ) - result = yield future + result = await future if s.validate: s.validate_state() finally: if client and c.status not in ("closing", "closed"): - yield c._close(fast=s.status == "closed") - yield end_cluster(s, workers) - yield gen.with_timeout( + await c._close(fast=s.status == "closed") + await end_cluster(s, workers) + await gen.with_timeout( timedelta(seconds=1), cleanup_global_workers() ) try: - c = yield default_client() + c = await default_client() except ValueError: pass else: - yield c._close(fast=True) + await c._close(fast=True) for i in range(5): if all(c.closed() for c in Comm._instances): break else: - yield gen.sleep(0.05) + await gen.sleep(0.05) else: L = [c for c in Comm._instances if not c.closed()] Comm._instances.clear() # raise ValueError("Unclosed Comms", L) print("Unclosed Comms", L) - raise gen.Return(result) + return result result = loop.run_sync( coro, timeout=timeout * 2 if timeout else timeout @@ -1074,11 +1066,10 @@ def wait_for(predicate, timeout, fail_func=None, period=0.001): pytest.fail("condition not reached until %s seconds" % (timeout,)) -@gen.coroutine -def async_wait_for(predicate, timeout, fail_func=None, period=0.001): +async def async_wait_for(predicate, timeout, fail_func=None, period=0.001): deadline = time() + timeout while not predicate(): - yield gen.sleep(period) + await gen.sleep(period) if time() > deadline: if fail_func is not None: fail_func() @@ -1118,20 +1109,18 @@ def requires_ipv6(test_func): requires_ipv6 = pytest.mark.skip("ipv6 required") -@gen.coroutine -def assert_can_connect(addr, timeout=None, connection_args=None): +async def 
assert_can_connect(addr, timeout=None, connection_args=None): """ Check that it is possible to connect to the distributed *addr* within the given *timeout*. """ if timeout is None: timeout = 0.5 - comm = yield connect(addr, timeout=timeout, connection_args=connection_args) + comm = await connect(addr, timeout=timeout, connection_args=connection_args) comm.abort() -@gen.coroutine -def assert_cannot_connect( +async def assert_cannot_connect( addr, timeout=None, connection_args=None, exception_class=EnvironmentError ): """ @@ -1141,12 +1130,11 @@ def assert_cannot_connect( if timeout is None: timeout = 0.5 with pytest.raises(exception_class): - comm = yield connect(addr, timeout=timeout, connection_args=connection_args) + comm = await connect(addr, timeout=timeout, connection_args=connection_args) comm.abort() -@gen.coroutine -def assert_can_connect_from_everywhere_4_6( +async def assert_can_connect_from_everywhere_4_6( port, timeout=None, connection_args=None, protocol="tcp" ): """ @@ -1162,11 +1150,10 @@ def assert_can_connect_from_everywhere_4_6( assert_can_connect("%s://[::1]:%d" % (protocol, port), *args), assert_can_connect("%s://[%s]:%d" % (protocol, get_ipv6(), port), *args), ] - yield futures + await asyncio.gather(*futures) -@gen.coroutine -def assert_can_connect_from_everywhere_4( +async def assert_can_connect_from_everywhere_4( port, timeout=None, connection_args=None, protocol="tcp" ): """ @@ -1182,11 +1169,10 @@ def assert_can_connect_from_everywhere_4( assert_cannot_connect("%s://[::1]:%d" % (protocol, port), *args), assert_cannot_connect("%s://[%s]:%d" % (protocol, get_ipv6(), port), *args), ] - yield futures + await asyncio.gather(*futures) -@gen.coroutine -def assert_can_connect_locally_4(port, timeout=None, connection_args=None): +async def assert_can_connect_locally_4(port, timeout=None, connection_args=None): """ Check that the local *port* is only reachable from local IPv4 addresses. """ @@ -1199,11 +1185,12 @@ def assert_can_connect_locally_4(port, timeout=None, connection_args=None): assert_cannot_connect("tcp://[::1]:%d" % port, *args), assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args), ] - yield futures + await asyncio.gather(*futures) -@gen.coroutine -def assert_can_connect_from_everywhere_6(port, timeout=None, connection_args=None): +async def assert_can_connect_from_everywhere_6( + port, timeout=None, connection_args=None +): """ Check that the local *port* is reachable from all IPv6 addresses. """ @@ -1215,11 +1202,10 @@ def assert_can_connect_from_everywhere_6(port, timeout=None, connection_args=Non assert_can_connect("tcp://[::1]:%d" % port, *args), assert_can_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args), ] - yield futures + await asyncio.gather(*futures) -@gen.coroutine -def assert_can_connect_locally_6(port, timeout=None, connection_args=None): +async def assert_can_connect_locally_6(port, timeout=None, connection_args=None): """ Check that the local *port* is only reachable from local IPv6 addresses. """ @@ -1232,7 +1218,7 @@ def assert_can_connect_locally_6(port, timeout=None, connection_args=None): ] if get_ipv6() != "::1": # No outside IPv6 connectivity? 
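Several tests in this patch (for example test_nanny_closes_cleanly and test_tls_scheduler above) drop the gen_cluster machinery in favour of plain pytest-asyncio tests that use the servers as async context managers together with the cleanup fixture imported from distributed.utils_test. A minimal sketch of that style; the test name and the submitted lambda are illustrative:

    import pytest
    from distributed import Scheduler, Worker, Client
    from distributed.utils_test import cleanup  # noqa: F401

    @pytest.mark.asyncio
    async def test_simple_roundtrip(cleanup):
        async with Scheduler(port=0) as s:       # port=0 lets the OS pick a free port
            async with Worker(s.address):
                async with Client(s.address, asynchronous=True) as c:
                    future = c.submit(lambda x: x + 1, 10)
                    assert (await future) == 11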
futures += [assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args)] - yield futures + await asyncio.gather(*futures) @contextmanager @@ -1509,9 +1495,9 @@ def check_instances(): for w in Worker._instances: with ignoring(RuntimeError): # closed IOLoop - w.close(report=False, executor_wait=False) + w.loop.add_callback(w.close, report=False, executor_wait=False) if w.status == "running": - w.close() + w.loop.add_callback(w.close) Worker._instances.clear() for i in range(5): diff --git a/distributed/variable.py b/distributed/variable.py index 7b775d3327a..30ffc5bf72d 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import asyncio from collections import defaultdict import logging import uuid @@ -58,15 +59,14 @@ def set(self, stream=None, name=None, key=None, data=None, client=None): pass else: if old["type"] == "Future" and old["value"] != key: - self.release(old["value"], name) + asyncio.ensure_future(self.release(old["value"], name)) if name not in self.variables: self.started.notify_all() self.variables[name] = record - @gen.coroutine - def release(self, key, name): + async def release(self, key, name): while self.waiting[key, name]: - yield self.waiting_conditions[name].wait() + await self.waiting_conditions[name].wait() self.scheduler.client_releases_keys(keys=[key], client="variable-%s" % name) del self.waiting[key, name] @@ -76,8 +76,7 @@ def future_release(self, name=None, key=None, token=None, client=None): if not self.waiting[key, name]: self.waiting_conditions[name].notify_all() - @gen.coroutine - def get(self, stream=None, name=None, client=None, timeout=None): + async def get(self, stream=None, name=None, client=None, timeout=None): start = time() while name not in self.variables: if timeout is not None: @@ -86,7 +85,7 @@ def get(self, stream=None, name=None, client=None, timeout=None): left = None if left and left < 0: raise gen.TimeoutError() - yield self.started.wait(timeout=left) + await self.started.wait(timeout=left) record = self.variables[name] if record["type"] == "Future": key = record["value"] @@ -99,10 +98,9 @@ def get(self, stream=None, name=None, client=None, timeout=None): msg["traceback"] = ts.exception_blame.traceback record = merge(record, msg) self.waiting[key, name].add(token) - raise gen.Return(record) + return record - @gen.coroutine - def delete(self, stream=None, name=None, client=None): + async def delete(self, stream=None, name=None, client=None): with log_errors(): try: old = self.variables[name] @@ -110,7 +108,7 @@ def delete(self, stream=None, name=None, client=None): pass else: if old["type"] == "Future": - yield self.release(old["value"], name) + await self.release(old["value"], name) del self.waiting_conditions[name] del self.variables[name] @@ -151,14 +149,13 @@ def __init__(self, name=None, client=None, maxsize=0): self.client = client or _get_global_client() self.name = name or "variable-" + uuid.uuid4().hex - @gen.coroutine - def _set(self, value): + async def _set(self, value): if isinstance(value, Future): - yield self.client.scheduler.variable_set( + await self.client.scheduler.variable_set( key=tokey(value.key), name=self.name ) else: - yield self.client.scheduler.variable_set(data=value, name=self.name) + await self.client.scheduler.variable_set(data=value, name=self.name) def set(self, value, **kwargs): """ Set the value of this variable @@ -170,9 +167,8 @@ def set(self, value, **kwargs): """ return self.client.sync(self._set, 
value, **kwargs) - @gen.coroutine - def _get(self, timeout=None): - d = yield self.client.scheduler.variable_get( + async def _get(self, timeout=None): + d = await self.client.scheduler.variable_get( timeout=timeout, name=self.name, client=self.client.id ) if d["type"] == "Future": @@ -189,7 +185,7 @@ def _get(self, timeout=None): ) else: value = d["value"] - raise gen.Return(value) + return value def get(self, timeout=None, **kwargs): """ Get the value of this variable """ diff --git a/distributed/worker.py b/distributed/worker.py index b052dd05799..e65c1cbccdc 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import +import asyncio import bisect from collections import defaultdict, deque from datetime import timedelta @@ -25,7 +26,6 @@ from cytoolz import pluck, partial, merge, first except ImportError: from toolz import pluck, partial, merge, first -from tornado.gen import Return from tornado import gen from tornado.ioloop import IOLoop from tornado.locks import Event @@ -48,9 +48,9 @@ from .sizeof import safe_sizeof as sizeof from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede from .utils import ( + get_ip, funcname, typename, - get_ip, has_arg, _maybe_complex, log_errors, @@ -704,8 +704,7 @@ def identity(self, comm=None): # External Services # ##################### - @gen.coroutine - def _register_with_scheduler(self): + async def _register_with_scheduler(self): self.periodic_callbacks["heartbeat"].stop() start = time() if self.contact_address is None: @@ -717,21 +716,21 @@ def _register_with_scheduler(self): "Timed out when connecting to scheduler '%s'", self.scheduler.address, ) - yield self.close(timeout=1) + await self.close(timeout=1) raise gen.TimeoutError( "Timed out connecting to scheduler '%s'" % self.scheduler.address ) if self.status in ("closed", "closing"): - raise gen.Return + return try: _start = time() types = {k: typename(v) for k, v in self.data.items()} - comm = yield connect( + comm = await connect( self.scheduler.address, connection_args=self.connection_args ) comm.name = "Worker->Scheduler" comm._server = weakref.ref(self) - yield comm.write( + await comm.write( dict( op="register-worker", reply=False, @@ -758,7 +757,7 @@ def _register_with_scheduler(self): if diff < 0: continue future = gen.with_timeout(timedelta(seconds=diff), future) - response = yield future + response = await future _end = time() middle = (_start + _end) / 2 self.latency = (_end - start) * 0.05 + self.latency * 0.95 @@ -767,15 +766,18 @@ def _register_with_scheduler(self): break except EnvironmentError: logger.info("Waiting to connect to: %26s", self.scheduler.address) - yield gen.sleep(0.1) + await gen.sleep(0.1) except gen.TimeoutError: logger.info("Timed out when connecting to scheduler") if response["status"] != "OK": raise ValueError("Unexpected response from register: %r" % (response,)) else: - yield [ - self.plugin_add(plugin=plugin) for plugin in response["worker-plugins"] - ] + await asyncio.gather( + *[ + self.plugin_add(plugin=plugin) + for plugin in response["worker-plugins"] + ] + ) logger.info(" Registered to: %26s", self.scheduler.address) logger.info("-" * 49) @@ -785,21 +787,20 @@ def _register_with_scheduler(self): self.periodic_callbacks["heartbeat"].start() self.loop.add_callback(self.handle_scheduler, comm) - @gen.coroutine - def heartbeat(self): + async def heartbeat(self): if not self.heartbeat_active: self.heartbeat_active = True logger.debug("Heartbeat: %s" % 
self.address) try: start = time() - response = yield self.scheduler.heartbeat_worker( + response = await self.scheduler.heartbeat_worker( address=self.contact_address, now=time(), metrics=self.get_metrics() ) end = time() middle = (start + end) / 2 if response["status"] == "missing": - yield self._register_with_scheduler() + await self._register_with_scheduler() return self.scheduler_delay = response["time"] - middle self.periodic_callbacks["heartbeat"].callback_time = ( @@ -812,10 +813,9 @@ def heartbeat(self): else: logger.debug("Heartbeat skipped: channel busy") - @gen.coroutine - def handle_scheduler(self, comm): + async def handle_scheduler(self, comm): try: - yield self.handle_stream( + await self.handle_stream( comm, every_cycle=[self.ensure_communicating, self.ensure_computing] ) except Exception as e: @@ -826,7 +826,7 @@ def handle_scheduler(self, comm): logger.info("Connection to scheduler broken. Reconnecting...") self.loop.add_callback(self._register_with_scheduler) else: - yield self.close(report=False) + await self.close(report=False) def start_ipython(self, comm): """Start an IPython kernel @@ -841,8 +841,7 @@ def start_ipython(self, comm): ) return self._ipython_kernel.get_connection_info() - @gen.coroutine - def upload_file(self, comm, filename=None, data=None, load=True): + async def upload_file(self, comm, filename=None, data=None, load=True): out_filename = os.path.join(self.local_dir, filename) def func(data): @@ -856,28 +855,27 @@ def func(data): if len(data) < 10000: data = func(data) else: - data = yield offload(func, data) + data = await offload(func, data) if load: try: import_file(out_filename) except Exception as e: logger.exception(e) - raise gen.Return({"status": "error", "exception": to_serialize(e)}) + return {"status": "error", "exception": to_serialize(e)} - raise gen.Return({"status": "OK", "nbytes": len(data)}) + return {"status": "OK", "nbytes": len(data)} def keys(self, comm=None): return list(self.data) - @gen.coroutine - def gather(self, comm=None, who_has=None): + async def gather(self, comm=None, who_has=None): who_has = { k: [coerce_to_address(addr) for addr in v] for k, v in who_has.items() if k not in self.data } - result, missing_keys, missing_workers = yield gather_from_workers( + result, missing_keys, missing_workers = await gather_from_workers( who_has, rpc=self.rpc, who=self.address ) if missing_keys: @@ -887,18 +885,17 @@ def gather(self, comm=None, who_has=None): missing_workers, who_has, ) - raise Return({"status": "missing-data", "keys": missing_keys}) + return {"status": "missing-data", "keys": missing_keys} else: self.update_data(data=result, report=False) - raise Return({"status": "OK"}) + return {"status": "OK"} ############# # Lifecycle # ############# - @gen.coroutine - def start(self): - assert self.status is None + async def start(self): + assert self.status is None, self.status enable_gc_diagnosis() thread_state.on_event_loop_thread = True @@ -938,10 +935,12 @@ def start(self): setproctitle("dask-worker [%s]" % self.address) - yield [self.plugin_add(plugin=plugin) for plugin in self._pending_plugins] + await asyncio.gather( + *[self.plugin_add(plugin=plugin) for plugin in self._pending_plugins] + ) self._pending_plugins = () - yield self._register_with_scheduler() + await self._register_with_scheduler() self.start_periodic_callbacks() return self @@ -950,11 +949,10 @@ def _close(self, *args, **kwargs): warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) return self.close(*args, **kwargs) - @gen.coroutine - 
def close(self, report=True, timeout=10, nanny=True, executor_wait=True): + async def close(self, report=True, timeout=10, nanny=True, executor_wait=True): with log_errors(): if self.status in ("closed", "closing"): - yield self.finished() + await self.finished() return disable_gc_diagnosis() @@ -963,25 +961,29 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): logger.info("Stopping worker at %s", self.address) except ValueError: # address not available if already closed logger.info("Stopping worker") + if self.status != "running": + logger.info("Closed worker has not yet started: %s", self.status) self.status = "closing" if nanny and self.nanny: with self.rpc(self.nanny) as r: - yield r.close_gracefully() + await r.close_gracefully() setproctitle("dask-worker [closing]") - yield [ + teardowns = [ plugin.teardown(self) for plugin in self.plugins.values() if hasattr(plugin, "teardown") ] + await asyncio.gather(*[td for td in teardowns if hasattr(td, "__await__")]) + for pc in self.periodic_callbacks.values(): pc.stop() with ignoring(EnvironmentError, gen.TimeoutError): if report: - yield gen.with_timeout( + await gen.with_timeout( timedelta(seconds=timeout), self.scheduler.unregister(address=self.contact_address), ) @@ -1006,25 +1008,23 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True): if nanny and self.nanny: with self.rpc(self.nanny) as r: - yield r.terminate() + await r.terminate() self.stop() self.rpc.close() self._closed.set() self.status = "closed" - yield ServerNode.close(self) + await ServerNode.close(self) setproctitle("dask-worker [closed]") - @gen.coroutine - def terminate(self, comm, report=True): - yield self.close(report=report) - raise Return("OK") + async def terminate(self, comm, report=True): + await self.close(report=report) + return "OK" - @gen.coroutine - def wait_until_closed(self): - yield self._closed.wait() + async def wait_until_closed(self): + await self._closed.wait() assert self.status == "closed" ################ @@ -1036,13 +1036,12 @@ def send_to_worker(self, address, msg): bcomm = BatchedSend(interval="1ms", loop=self.loop) self.stream_comms[address] = bcomm - @gen.coroutine - def batched_send_connect(): - comm = yield connect( + async def batched_send_connect(): + comm = await connect( address, connection_args=self.connection_args # TODO, serialization ) comm.name = "Worker->Worker" - yield comm.write({"op": "connection_stream"}) + await comm.write({"op": "connection_stream"}) bcomm.start(comm) @@ -1050,8 +1049,7 @@ def batched_send_connect(): self.stream_comms[address].send(msg) - @gen.coroutine - def get_data( + async def get_data( self, comm, keys=None, who=None, serializers=None, max_connections=None ): start = time() @@ -1071,7 +1069,7 @@ def get_data( max_connections is not False and self.outgoing_current_count > max_connections ): - raise gen.Return({"status": "busy"}) + return {"status": "busy"} self.outgoing_current_count += 1 data = {k: self.data[k] for k in keys if k in self.data} @@ -1091,8 +1089,8 @@ def get_data( start = time() try: - compressed = yield comm.write(msg, serializers=serializers) - response = yield comm.read(deserializers=serializers) + compressed = await comm.write(msg, serializers=serializers) + response = await comm.read(deserializers=serializers) assert response == "OK", response except EnvironmentError: logger.exception( @@ -1124,7 +1122,7 @@ def get_data( } ) - raise gen.Return("dont-reply") + return "dont-reply" ################### # Local Execution # @@ -1152,8 +1150,7 
@@ def update_data(self, comm=None, data=None, report=True, serializers=None): info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} return info - @gen.coroutine - def delete_data(self, comm=None, keys=None, report=True): + async def delete_data(self, comm=None, keys=None, report=True): if keys: for key in list(keys): self.log.append((key, "delete")) @@ -1167,13 +1164,12 @@ def delete_data(self, comm=None, keys=None, report=True): if report: logger.debug("Reporting loss of keys to scheduler") # TODO: this route seems to not exist? - yield self.scheduler.remove_keys( + await self.scheduler.remove_keys( address=self.contact_address, keys=list(keys) ) - raise Return("OK") + return "OK" - @gen.coroutine - def set_resources(self, **resources): + async def set_resources(self, **resources): for r, quantity in resources.items(): if r in self.total_resources: self.available_resources[r] += quantity - self.total_resources[r] @@ -1181,7 +1177,7 @@ def set_resources(self, **resources): self.available_resources[r] = quantity self.total_resources[r] = quantity - yield self.scheduler.set_resources( + await self.scheduler.set_resources( resources=self.total_resources, worker=self.contact_address ) @@ -1786,8 +1782,7 @@ def select_keys_for_gather(self, worker, dep): return deps, total_bytes - @gen.coroutine - def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): + async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): if self.status != "running": return with log_errors(): @@ -1804,7 +1799,7 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): logger.debug("Request %d keys", len(deps)) start = time() - response = yield get_data_from_worker( + response = await get_data_from_worker( self.rpc, deps, worker, who=self.address ) stop = time() @@ -1896,10 +1891,10 @@ def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): else: # Exponential backoff to avoid hammering scheduler/worker self.repetitively_busy += 1 - yield gen.sleep(0.100 * 1.5 ** self.repetitively_busy) + await gen.sleep(0.100 * 1.5 ** self.repetitively_busy) # See if anyone new has the data - yield self.query_who_has(dep) + await self.query_who_has(dep) self.ensure_communicating() def bad_dep(self, dep): @@ -1911,8 +1906,7 @@ def bad_dep(self, dep): self.transition(key, "error") self.release_dep(dep) - @gen.coroutine - def handle_missing_dep(self, *deps, **kwargs): + async def handle_missing_dep(self, *deps, **kwargs): original_deps = list(deps) self.log.append(("handle-missing", deps)) try: @@ -1935,7 +1929,7 @@ def handle_missing_dep(self, *deps, **kwargs): self.suspicious_deps[dep], ) - who_has = yield self.scheduler.who_has(keys=list(deps)) + who_has = await self.scheduler.who_has(keys=list(deps)) who_has = {k: v for k, v in who_has.items() if v} self.update_who_has(who_has) for dep in deps: @@ -1955,7 +1949,7 @@ def handle_missing_dep(self, *deps, **kwargs): retries = kwargs.get("retries", 5) self.log.append(("handle-missing-failed", retries, deps)) if retries > 0: - yield self.handle_missing_dep(self, *deps, retries=retries - 1) + await self.handle_missing_dep(self, *deps, retries=retries - 1) else: raise finally: @@ -1967,12 +1961,11 @@ def handle_missing_dep(self, *deps, **kwargs): self.ensure_communicating() - @gen.coroutine - def query_who_has(self, *deps): + async def query_who_has(self, *deps): with log_errors(): - response = yield self.scheduler.who_has(keys=deps) + response = await self.scheduler.who_has(keys=deps) self.update_who_has(response) - raise 
gen.Return(response) + return response def update_who_has(self, who_has): try: @@ -2180,8 +2173,7 @@ def run(self, comm, function, args=(), wait=True, kwargs=None): def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) - @gen.coroutine - def plugin_add(self, comm=None, plugin=None, name=None): + async def plugin_add(self, comm=None, plugin=None, name=None): with log_errors(pdb=False): if isinstance(plugin, bytes): plugin = pickle.loads(plugin) @@ -2201,16 +2193,17 @@ def plugin_add(self, comm=None, plugin=None, name=None): logger.info("Starting Worker plugin %s" % name) try: result = plugin.setup(worker=self) - if isinstance(result, gen.Future): - result = yield result + if hasattr(result, "__await__"): + result = await result except Exception as e: msg = error_message(e) return msg else: return {"status": "OK"} - @gen.coroutine - def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={}): + async def actor_execute( + self, comm=None, actor=None, function=None, args=(), kwargs={} + ): separate_thread = kwargs.pop("separate_thread", True) key = actor actor = self.actors[key] @@ -2218,9 +2211,9 @@ def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={} name = key_split(key) + "." + function if iscoroutinefunction(func): - result = yield func(*args, **kwargs) + result = await func(*args, **kwargs) elif separate_thread: - result = yield self.executor_submit( + result = await self.executor_submit( name, apply_function_actor, args=( @@ -2236,7 +2229,7 @@ def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={} ) else: result = func(*args, **kwargs) - raise gen.Return({"status": "OK", "result": to_serialize(result)}) + return {"status": "OK", "result": to_serialize(result)} def actor_attribute(self, comm=None, actor=None, attribute=None): value = getattr(self.actors[actor], attribute) @@ -2277,8 +2270,7 @@ def ensure_computing(self): pdb.set_trace() raise - @gen.coroutine - def execute(self, key, report=False): + async def execute(self, key, report=False): executor_error = None if self.status in ("closing", "closed"): return @@ -2312,7 +2304,7 @@ def execute(self, key, report=False): "Execute key: %s worker: %s", key, self.address ) # TODO: comment out? try: - result = yield self.executor_submit( + result = await self.executor_submit( key, apply_function, args=( @@ -2391,8 +2383,7 @@ def execute(self, key, report=False): # Administrative # ################## - @gen.coroutine - def memory_monitor(self): + async def memory_monitor(self): """ Track this process's memory usage and act accordingly If we rise above 70% memory use, start dumping data to disk. 
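# --- Editor's illustration, not part of the patch: the thresholds that this
# --- memory_monitor acts on come from Dask's configuration.  Assuming the
# --- standard ``distributed.worker.memory.*`` keys used by the worker, they
# --- can be tuned before starting a worker, e.g.:
#
# import dask
#
# dask.config.set({
#     "distributed.worker.memory.target": 0.60,     # try to keep memory below this fraction
#     "distributed.worker.memory.spill": 0.70,      # above this, spill data to disk
#     "distributed.worker.memory.pause": 0.80,      # above this, pause execution of new tasks
#     "distributed.worker.memory.terminate": 0.95,  # above this, the nanny terminates the worker
# })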
@@ -2458,7 +2449,7 @@ def memory_monitor(self): del k, v total += weight count += 1 - yield gen.moment + await gen.sleep(0) memory = proc.memory_info().rss if total > need and memory > target: # Issue a GC to ensure that the evicted data is actually @@ -2474,7 +2465,7 @@ def memory_monitor(self): ) self._memory_monitoring = False - raise gen.Return(total) + return total def cycle_profile(self): now = time() + self.scheduler_delay @@ -2962,8 +2953,7 @@ def parse_memory_limit(memory_limit, nthreads, total_cores=multiprocessing.cpu_c return memory_limit -@gen.coroutine -def get_data_from_worker( +async def get_data_from_worker( rpc, keys, worker, @@ -2988,10 +2978,10 @@ def get_data_from_worker( if deserializers is None: deserializers = rpc.deserializers - comm = yield rpc.connect(worker) + comm = await rpc.connect(worker) comm.name = "Ephemeral Worker->Worker for gather" try: - response = yield send_recv( + response = await send_recv( comm, serializers=serializers, deserializers=deserializers, @@ -3006,11 +2996,11 @@ def get_data_from_worker( raise ValueError("Unexpected response", response) else: if status == "OK": - yield comm.write("OK") + await comm.write("OK") finally: rpc.reuse(worker, comm) - raise gen.Return(response) + return response job_counter = [0] @@ -3266,8 +3256,7 @@ def weight(k, v): return sizeof(v) -@gen.coroutine -def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=True): +async def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=True): function = pickle.loads(function) if is_coro is None: is_coro = iscoroutinefunction(function) @@ -3291,7 +3280,7 @@ def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=True): result = function(*args, **kwargs) else: if wait: - result = yield function(*args, **kwargs) + result = await function(*args, **kwargs) else: server.loop.add_callback(function, *args, **kwargs) result = None @@ -3308,7 +3297,7 @@ def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=True): response = error_message(e) else: response = {"status": "OK", "result": to_serialize(result)} - raise Return(response) + return response _global_workers = Worker._instances diff --git a/docs/source/adaptive.rst b/docs/source/adaptive.rst index 0e3a4bdefbc..774ae21e4a6 100644 --- a/docs/source/adaptive.rst +++ b/docs/source/adaptive.rst @@ -60,8 +60,7 @@ the correct times. .. code-block:: python class MyCluster(object): - @gen.coroutine - def scale_up(self, n, **kwargs): + async def scale_up(self, n, **kwargs): """ Bring the total count of workers up to ``n`` @@ -72,8 +71,7 @@ the correct times. """ raise NotImplementedError() - @gen.coroutine - def scale_down(self, workers): + async def scale_down(self, workers): """ Remove ``workers`` from the cluster diff --git a/docs/source/client.rst b/docs/source/client.rst index b2b520d9614..b955d5ab504 100644 --- a/docs/source/client.rst +++ b/docs/source/client.rst @@ -149,8 +149,8 @@ keyword argument. In this case keys are randomly generated (by ``uuid4``.) .. _pure: https://toolz.readthedocs.io/en/latest/purity.html -Tornado Coroutines ------------------- +Async/await Operation +--------------------- If we are operating in an asynchronous environment then the blocking functions listed above become asynchronous equivalents. You must start your client @@ -159,11 +159,10 @@ functions. .. 
code-block:: python - @gen.coroutine - def f(): - client = yield Client(asynchronous=True) + async def f(): + client = await Client(asynchronous=True) future = client.submit(func, *args) - result = yield future + result = await future return result If you want to reuse the same client in asynchronous and synchronous @@ -174,10 +173,9 @@ call. client = Client() # normal blocking client - @gen.coroutine - def f(): + async def f(): futures = client.map(func, L) - results = yield client.gather(futures, asynchronous=True) + results = await client.gather(futures, asynchronous=True) return results See the :doc:`Asynchronous ` documentation for more information. diff --git a/docs/source/foundations.rst b/docs/source/foundations.rst index 7b351f8f972..62253433763 100644 --- a/docs/source/foundations.rst +++ b/docs/source/foundations.rst @@ -86,25 +86,23 @@ Server Side .. code-block:: python - from tornado import gen - from tornado.ioloop import IOLoop + import asyncio from distributed.core import Server def add(comm, x=None, y=None): # simple handler, just a function return x + y - @gen.coroutine - def stream_data(comm, interval=1): # complex handler, multiple responses + async def stream_data(comm, interval=1): # complex handler, multiple responses data = 0 while True: - yield gen.sleep(interval) + await asyncio.sleep(interval) data += 1 - yield comm.write(data) + await comm.write(data) s = Server({'add': add, 'stream_data': stream_data}) s.listen('tcp://:8888') # listen on TCP port 8888 - IOLoop.current().start() + asyncio.get_event_loop().run_forever() Client Side @@ -112,30 +110,27 @@ Client Side .. code-block:: python - from tornado import gen - from tornado.ioloop import IOLoop + import asyncio from distributed.core import connect - @gen.coroutine - def f(): - comm = yield connect('tcp://127.0.0.1:8888') - yield comm.write({'op': 'add', 'x': 1, 'y': 2}) - result = yield comm.read() - yield comm.close() + async def f(): + comm = await connect('tcp://127.0.0.1:8888') + await comm.write({'op': 'add', 'x': 1, 'y': 2}) + result = await comm.read() + await comm.close() print(result) - >>> IOLoop().run_sync(f) + >>> asyncio.get_event_loop().run_until_complete(g()) 3 - @gen.coroutine - def g(): - comm = yield connect('tcp://127.0.0.1:8888') - yield comm.write({'op': 'stream_data', 'interval': 1}) + async def g(): + comm = await connect('tcp://127.0.0.1:8888') + await comm.write({'op': 'stream_data', 'interval': 1}) while True: - result = yield comm.read() + result = await comm.read() print(result) - >>> IOLoop().run_sync(g) + >>> asyncio.get_event_loop().run_until_complete(g()) 1 2 3 @@ -152,21 +147,17 @@ with the stream data case above. .. 
code-block:: python - from tornado import gen - from tornado.ioloop import IOLoop + import asyncio from distributed.core import rpc - @gen.coroutine - def f(): - # comm = yield connect('tcp://127.0.0.1', 8888) - # yield comm.write({'op': 'add', 'x': 1, 'y': 2}) - # result = yield comm.read() - r = rpc('tcp://127.0.0.1:8888') - result = yield r.add(x=1, y=2) - r.close_comms() + async def f(): + # comm = await connect('tcp://127.0.0.1', 8888) + # await comm.write({'op': 'add', 'x': 1, 'y': 2}) + # result = await comm.read() + with rpc('tcp://127.0.0.1:8888') as r: + result = await r.add(x=1, y=2) print(result) - >>> IOLoop().run_sync(f) + >>> asyncio.get_event_loop().run_until_complete(f()) 3 - From 0c861360099fae9d352baa92538905b785d70a04 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 24 Jul 2019 15:11:52 -0700 Subject: [PATCH 0368/1550] Forcefully kill all processes before each test (#2882) This should hopefully help with some intermittent testing failures --- distributed/tests/test_client.py | 2 +- distributed/utils_test.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 50415971e20..bf05875fdee 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3312,7 +3312,7 @@ def test_get_foo_lost_keys(c, s, u, v, w): client=True, Worker=Nanny, worker_kwargs={"death_timeout": "500ms"}, - clean_kwargs={"processes": False, "threads": False}, + clean_kwargs={"threads": False}, ) def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 0) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index d2932857c40..51f1e907236 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1456,12 +1456,13 @@ def check_thread_leak(): @contextmanager def check_process_leak(): - start_children = set(mp_context.active_children()) + for proc in mp_context.active_children(): + proc.terminate() yield - for i in range(50): - if not set(mp_context.active_children()) - start_children: + for i in range(100): + if not set(mp_context.active_children()): break else: sleep(0.2) @@ -1524,14 +1525,14 @@ def check_instances(): @contextmanager -def clean(threads=not WINDOWS, processes=True, instances=True, timeout=1): +def clean(threads=not WINDOWS, instances=True, timeout=1): @contextmanager def null(): yield with check_thread_leak() if threads else null(): with pristine_loop() as loop: - with check_process_leak() if processes else null(): + with check_process_leak(): with check_instances() if instances else null(): with check_active_rpc(loop, timeout): reset_config() From 909a943b67b6b472a2d77afa13a8caa61f25f972 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Thu, 25 Jul 2019 11:11:35 -0500 Subject: [PATCH 0369/1550] Cleanup Security class and configuration (#2873) The previous implementation's constructor was a bit confusing, and wasn't loading from the configuration file properly. This PR: - Simplifies the class implementation - Ensures all fields are set to the appropriate defaults in the config files (previously the defaults were in the code, not in the config). - Documents the security class - Updates tests to ensure parameters are loaded from the appropriate configuration fields. 
* Add Security to API docs --- distributed/cli/dask_scheduler.py | 10 +- distributed/cli/dask_worker.py | 10 +- distributed/distributed.yaml | 31 ++--- distributed/security.py | 177 ++++++++++++++--------------- distributed/tests/test_security.py | 131 ++++++++++----------- distributed/utils_test.py | 23 ++-- docs/source/tls.rst | 9 ++ 7 files changed, 206 insertions(+), 185 deletions(-) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index d41b98eb310..2e6220b7d81 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -170,7 +170,15 @@ def main( port = 8786 sec = Security( - tls_ca_file=tls_ca_file, tls_scheduler_cert=tls_cert, tls_scheduler_key=tls_key + **{ + k: v + for k, v in [ + ("tls_ca_file", tls_ca_file), + ("tls_scheduler_cert", tls_cert), + ("tls_scheduler_key", tls_key), + ] + if v is not None + } ) if not host and (tls_ca_file or tls_cert or tls_key): diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 8752cd52448..073c7c9c922 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -238,7 +238,15 @@ def main( dashboard = bokeh sec = Security( - tls_ca_file=tls_ca_file, tls_worker_cert=tls_cert, tls_worker_key=tls_key + **{ + k: v + for k, v in [ + ("tls_ca_file", tls_ca_file), + ("tls_worker_cert", tls_cert), + ("tls_worker_key", tls_key), + ] + if v is not None + } ) if nprocs > 1 and worker_port != 0: diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index c6c3e3d1ba2..c3f14f114f1 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -72,21 +72,22 @@ distributed: connect: 10s # time before connecting fails tcp: 30s # time before calling an unresponsive connection dead - # require-encryption: False # whether to require encryption on non-local comms - # - # tls: - # ca-file: xxx.pem - # scheduler: - # key: xxx.pem - # cert: xxx.pem - # worker: - # key: xxx.pem - # cert: xxx.pem - # client: - # key: xxx.pem - # cert: xxx.pem - # ciphers: - # ECDHE-ECDSA-AES128-GCM-SHA256 + require-encryption: False # Whether to require encryption on non-local comms + + tls: + ciphers: null # Allowed ciphers, specified as an OpenSSL cipher string. + ca-file: null # Path to a CA file, in pem format, optional + scheduler: + cert: null # Path to certificate file for scheduler. + key: null # Path to key file for scheduler. Alternatively, the key + # can be appended to the cert file above, and this field + # left blank. + worker: + key: null + cert: null + client: + key: null + cert: null ################### diff --git a/distributed/security.py b/distributed/security.py index e86c0602860..a42cbeef646 100644 --- a/distributed/security.py +++ b/distributed/security.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - try: import ssl except ImportError: @@ -8,80 +6,85 @@ import dask -_roles = ["client", "scheduler", "worker"] - -_tls_per_role_fields = ["key", "cert"] - -_tls_fields = ["ca_file", "ciphers"] - -_misc_fields = ["require_encryption"] - -_fields = set( - _misc_fields - + ["tls_%s" % field for field in _tls_fields] - + ["tls_%s_%s" % (role, field) for role in _roles for field in _tls_per_role_fields] -) - - -def _field_to_config_key(field): - return field.replace("_", "-") +__all__ = ("Security",) class Security(object): - """ - An object to gather and pass around security configuration. 
- Default values are gathered from the global ``config`` object and - can be overriden by constructor args. - - Supported fields: - - require_encryption - - tls_ca_file - - tls_ciphers - - tls_client_key - - tls_client_cert - - tls_scheduler_key - - tls_scheduler_cert - - tls_worker_key - - tls_worker_cert + """Security configuration for a Dask cluster. + + Default values are loaded from Dask's configuration files, and can be + overridden in the constructor. + + Parameters + ---------- + require_encryption : bool, optional + Whether TLS encryption is required for all connections. + tls_ca_file : str, optional + Path to a CA certificate file encoded in PEM format. + tls_ciphers : str, optional + An OpenSSL cipher string of allowed ciphers. If not provided, the + system defaults will be used. + tls_client_cert : str, optional + Path to a certificate file for the client, encoded in PEM format. + tls_client_key : str, optional + Path to a key file for the client, encoded in PEM format. + Alternatively, the key may be appended to the cert file, and this + parameter be omitted. + tls_scheduler_cert : str, optional + Path to a certificate file for the scheduler, encoded in PEM format. + tls_scheduler_key : str, optional + Path to a key file for the scheduler, encoded in PEM format. + Alternatively, the key may be appended to the cert file, and this + parameter be omitted. + tls_worker_cert : str, optional + Path to a certificate file for a worker, encoded in PEM format. + tls_worker_key : str, optional + Path to a key file for a worker, encoded in PEM format. + Alternatively, the key may be appended to the cert file, and this + parameter be omitted. """ - __slots__ = tuple(_fields) + __slots__ = ( + "require_encryption", + "tls_ca_file", + "tls_ciphers", + "tls_client_key", + "tls_client_cert", + "tls_scheduler_key", + "tls_scheduler_cert", + "tls_worker_key", + "tls_worker_cert", + ) def __init__(self, **kwargs): - self._init_from_dict(dask.config.config) - for k, v in kwargs.items(): - if v is not None: - setattr(self, k, v) - for k in _fields: - if not hasattr(self, k): - setattr(self, k, None) - - def _init_from_dict(self, d): - """ - Initialize Security from nested dict. 
- """ - self._init_fields_from_dict(d, "", _misc_fields, {}) - self._init_fields_from_dict(d, "tls", _tls_fields, _tls_per_role_fields) + extra = set(kwargs).difference(self.__slots__) + if extra: + raise TypeError("Unknown parameters: %r" % sorted(extra)) + self._set_field( + kwargs, "require_encryption", "distributed.comm.require-encryption" + ) + self._set_field(kwargs, "tls_ciphers", "distributed.comm.tls.ciphers") + self._set_field(kwargs, "tls_ca_file", "distributed.comm.tls.ca-file") + self._set_field(kwargs, "tls_client_key", "distributed.comm.tls.client.key") + self._set_field(kwargs, "tls_client_cert", "distributed.comm.tls.client.cert") + self._set_field( + kwargs, "tls_scheduler_key", "distributed.comm.tls.scheduler.key" + ) + self._set_field( + kwargs, "tls_scheduler_cert", "distributed.comm.tls.scheduler.cert" + ) + self._set_field(kwargs, "tls_worker_key", "distributed.comm.tls.worker.key") + self._set_field(kwargs, "tls_worker_cert", "distributed.comm.tls.worker.cert") - def _init_fields_from_dict(self, d, category, fields, per_role_fields): - if category: - d = d.get(category, {}) - category_prefix = category + "_" + def _set_field(self, kwargs, field, config_name): + if field in kwargs: + out = kwargs[field] else: - category_prefix = "" - for field in fields: - k = _field_to_config_key(field) - if k in d: - setattr(self, "%s%s" % (category_prefix, field), d[k]) - for role in _roles: - dd = d.get(role, {}) - for field in per_role_fields: - k = _field_to_config_key(field) - if k in dd: - setattr(self, "%s%s_%s" % (category_prefix, role, field), dd[k]) + out = dask.config.get(config_name) + setattr(self, field, out) def __repr__(self): - items = sorted((k, getattr(self, k)) for k in _fields) + items = sorted((k, getattr(self, k)) for k in self.__slots__) return ( "Security(" + ", ".join("%s=%r" % (k, v) for k, v in items if v is not None) @@ -92,26 +95,18 @@ def get_tls_config_for_role(self, role): """ Return the TLS configuration for the given role, as a flat dict. """ - return self._get_config_for_role("tls", role, _tls_fields, _tls_per_role_fields) - - def _get_config_for_role(self, category, role, fields, per_role_fields): - if role not in _roles: + if role not in {"client", "scheduler", "worker"}: raise ValueError("unknown role %r" % (role,)) - d = {} - for field in fields: - k = "%s_%s" % (category, field) - d[field] = getattr(self, k) - for field in per_role_fields: - k = "%s_%s_%s" % (category, role, field) - d[field] = getattr(self, k) - return d + return { + "ca_file": self.tls_ca_file, + "ciphers": self.tls_ciphers, + "cert": getattr(self, "tls_%s_cert" % role), + "key": getattr(self, "tls_%s_key" % role), + } def _get_tls_context(self, tls, purpose): if tls.get("ca_file") and tls.get("cert"): - try: - ctx = ssl.create_default_context(purpose=purpose, cafile=tls["ca_file"]) - except AttributeError: - raise RuntimeError("TLS functionality requires Python 2.7.9+") + ctx = ssl.create_default_context(purpose=purpose, cafile=tls["ca_file"]) ctx.verify_mode = ssl.CERT_REQUIRED # We expect a dedicated CA for the cluster and people using # IP addresses rather than hostnames @@ -126,23 +121,19 @@ def get_connection_args(self, role): Get the *connection_args* argument for a connect() call with the given *role*. 
""" - d = {} tls = self.get_tls_config_for_role(role) - # Ensure backwards compatibility (ssl.Purpose is Python 2.7.9+ only) - purpose = ssl.Purpose.SERVER_AUTH if hasattr(ssl, "Purpose") else None - d["ssl_context"] = self._get_tls_context(tls, purpose) - d["require_encryption"] = self.require_encryption - return d + return { + "ssl_context": self._get_tls_context(tls, ssl.Purpose.SERVER_AUTH), + "require_encryption": self.require_encryption, + } def get_listen_args(self, role): """ Get the *connection_args* argument for a listen() call with the given *role*. """ - d = {} tls = self.get_tls_config_for_role(role) - # Ensure backwards compatibility (ssl.Purpose is Python 2.7.9+ only) - purpose = ssl.Purpose.CLIENT_AUTH if hasattr(ssl, "Purpose") else None - d["ssl_context"] = self._get_tls_context(tls, purpose) - d["require_encryption"] = self.require_encryption - return d + return { + "ssl_context": self._get_tls_context(tls, ssl.Purpose.CLIENT_AUTH), + "require_encryption": self.require_encryption, + } diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 8e82db1308e..7f144625b04 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -13,8 +13,9 @@ from distributed.comm import connect, listen from distributed.security import Security -from distributed.utils_test import new_config, get_cert, gen_test +from distributed.utils_test import get_cert, gen_test +import dask ca_file = get_cert("tls-ca-cert.pem") @@ -35,8 +36,7 @@ def test_defaults(): - with new_config({}): - sec = Security() + sec = Security() assert sec.require_encryption in (None, False) assert sec.tls_ca_file is None assert sec.tls_ciphers is None @@ -48,6 +48,13 @@ def test_defaults(): assert sec.tls_worker_cert is None +def test_constructor_errors(): + with pytest.raises(TypeError) as exc: + Security(unknown_keyword="bar") + + assert "unknown_keyword" in str(exc.value) + + def test_attribute_error(): sec = Security() assert hasattr(sec, "tls_ca_file") @@ -59,16 +66,17 @@ def test_attribute_error(): def test_from_config(): c = { - "tls": { - "ca-file": "ca.pem", - "scheduler": {"key": "skey.pem", "cert": "scert.pem"}, - "worker": {"cert": "wcert.pem"}, - "ciphers": FORCED_CIPHER, - }, - "require-encryption": True, + "distributed.comm.tls.ca-file": "ca.pem", + "distributed.comm.tls.scheduler.key": "skey.pem", + "distributed.comm.tls.scheduler.cert": "scert.pem", + "distributed.comm.tls.worker.cert": "wcert.pem", + "distributed.comm.tls.ciphers": FORCED_CIPHER, + "distributed.comm.require-encryption": True, } - with new_config(c): + + with dask.config.set(c): sec = Security() + assert sec.require_encryption is True assert sec.tls_ca_file == "ca.pem" assert sec.tls_ciphers == FORCED_CIPHER @@ -82,18 +90,16 @@ def test_from_config(): def test_kwargs(): c = { - "tls": { - "ca-file": "ca.pem", - "scheduler": {"key": "skey.pem", "cert": "scert.pem"}, - } + "distributed.comm.tls.ca-file": "ca.pem", + "distributed.comm.tls.scheduler.key": "skey.pem", + "distributed.comm.tls.scheduler.cert": "scert.pem", } - with new_config(c): + with dask.config.set(c): sec = Security( tls_scheduler_cert="newcert.pem", require_encryption=True, tls_ca_file=None ) assert sec.require_encryption is True - # None value didn't override default - assert sec.tls_ca_file == "ca.pem" + assert sec.tls_ca_file is None assert sec.tls_ciphers is None assert sec.tls_client_key is None assert sec.tls_client_cert is None @@ -104,24 +110,22 @@ def test_kwargs(): def test_repr(): - with 
new_config({}): - sec = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem") - assert ( - repr(sec) - == "Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')" - ) + sec = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem") + assert ( + repr(sec) + == "Security(require_encryption=False, tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')" + ) def test_tls_config_for_role(): c = { - "tls": { - "ca-file": "ca.pem", - "scheduler": {"key": "skey.pem", "cert": "scert.pem"}, - "worker": {"cert": "wcert.pem"}, - "ciphers": FORCED_CIPHER, - } + "distributed.comm.tls.ca-file": "ca.pem", + "distributed.comm.tls.scheduler.key": "skey.pem", + "distributed.comm.tls.scheduler.cert": "scert.pem", + "distributed.comm.tls.worker.cert": "wcert.pem", + "distributed.comm.tls.ciphers": FORCED_CIPHER, } - with new_config(c): + with dask.config.set(c): sec = Security() t = sec.get_tls_config_for_role("scheduler") assert t == { @@ -158,13 +162,12 @@ def many_ciphers(ctx): assert len(ctx.get_ciphers()) > 2 # Most likely c = { - "tls": { - "ca-file": ca_file, - "scheduler": {"key": key1, "cert": cert1}, - "worker": {"cert": keycert1}, - } + "distributed.comm.tls.ca-file": ca_file, + "distributed.comm.tls.scheduler.key": key1, + "distributed.comm.tls.scheduler.cert": cert1, + "distributed.comm.tls.worker.cert": keycert1, } - with new_config(c): + with dask.config.set(c): sec = Security() d = sec.get_connection_args("scheduler") @@ -183,10 +186,10 @@ def many_ciphers(ctx): assert d.get("ssl_context") is None # With more settings - c["tls"]["ciphers"] = FORCED_CIPHER - c["require-encryption"] = True + c["distributed.comm.tls.ciphers"] = FORCED_CIPHER + c["distributed.comm.require-encryption"] = True - with new_config(c): + with dask.config.set(c): sec = Security() d = sec.get_listen_args("scheduler") @@ -212,13 +215,12 @@ def many_ciphers(ctx): assert len(ctx.get_ciphers()) > 2 # Most likely c = { - "tls": { - "ca-file": ca_file, - "scheduler": {"key": key1, "cert": cert1}, - "worker": {"cert": keycert1}, - } + "distributed.comm.tls.ca-file": ca_file, + "distributed.comm.tls.scheduler.key": key1, + "distributed.comm.tls.scheduler.cert": cert1, + "distributed.comm.tls.worker.cert": keycert1, } - with new_config(c): + with dask.config.set(c): sec = Security() d = sec.get_listen_args("scheduler") @@ -237,10 +239,10 @@ def many_ciphers(ctx): assert d.get("ssl_context") is None # With more settings - c["tls"]["ciphers"] = FORCED_CIPHER - c["require-encryption"] = True + c["distributed.comm.tls.ciphers"] = FORCED_CIPHER + c["distributed.comm.require-encryption"] = True - with new_config(c): + with dask.config.set(c): sec = Security() d = sec.get_listen_args("scheduler") @@ -270,17 +272,16 @@ def handle_comm(comm): yield comm.close() c = { - "tls": { - "ca-file": ca_file, - "scheduler": {"key": key1, "cert": cert1}, - "worker": {"cert": keycert1}, - } + "distributed.comm.tls.ca-file": ca_file, + "distributed.comm.tls.scheduler.key": key1, + "distributed.comm.tls.scheduler.cert": cert1, + "distributed.comm.tls.worker.cert": keycert1, } - with new_config(c): + with dask.config.set(c): sec = Security() - c["tls"]["ciphers"] = FORCED_CIPHER - with new_config(c): + c["distributed.comm.tls.ciphers"] = FORCED_CIPHER + with dask.config.set(c): forced_cipher_sec = Security() with listen( @@ -321,16 +322,16 @@ def handle_comm(comm): comm.abort() c = { - "tls": { - "ca-file": ca_file, - "scheduler": {"key": key1, "cert": cert1}, - "worker": {"cert": keycert1}, - } + "distributed.comm.tls.ca-file": 
ca_file, + "distributed.comm.tls.scheduler.key": key1, + "distributed.comm.tls.scheduler.cert": cert1, + "distributed.comm.tls.worker.cert": keycert1, } - with new_config(c): + with dask.config.set(c): sec = Security() - c["require-encryption"] = True - with new_config(c): + + c["distributed.comm.require-encryption"] = True + with dask.config.set(c): sec2 = Security() for listen_addr in ["inproc://", "tls://"]: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 51f1e907236..8055aec3f34 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1265,10 +1265,10 @@ def new_config(new_config): from .config import defaults config = dask.config.config - orig_config = config.copy() + orig_config = copy.deepcopy(config) try: config.clear() - config.update(defaults.copy()) + config.update(copy.deepcopy(defaults)) dask.config.update(config, new_config) initialize_logging(config) yield @@ -1332,15 +1332,18 @@ def tls_config(): ca_file = get_cert("tls-ca-cert.pem") keycert = get_cert("tls-key-cert.pem") - c = { - "tls": { - "ca-file": ca_file, - "client": {"cert": keycert}, - "scheduler": {"cert": keycert}, - "worker": {"cert": keycert}, + return { + "distributed": { + "comm": { + "tls": { + "ca-file": ca_file, + "client": {"cert": keycert}, + "scheduler": {"cert": keycert}, + "worker": {"cert": keycert}, + } + } } } - return c def tls_only_config(): @@ -1349,7 +1352,7 @@ def tls_only_config(): plain TCP communications. """ c = tls_config() - c["require-encryption"] = True + c["distributed"]["comm"]["require-encryption"] = True return c diff --git a/docs/source/tls.rst b/docs/source/tls.rst index d367dabbf7b..0c635b85761 100644 --- a/docs/source/tls.rst +++ b/docs/source/tls.rst @@ -96,3 +96,12 @@ very large data over very high speed network links. `A study of AES-NI acceleration `_ shows recent x86 CPUs can AES-encrypt more than 1 GB per second on each CPU core. + + +API +--- + +.. currentmodule:: distributed + +.. autoclass:: distributed.security.Security + :members: From 2cfaca1eeff3682239b9869daaa829ff0e002f0e Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 26 Jul 2019 16:56:30 +0100 Subject: [PATCH 0370/1550] Remove unused variable in SpecCluster scale down (#2870) --- distributed/deploy/spec.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 1ba8e7fb213..912ccc79302 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -318,8 +318,6 @@ def new_worker_spec(self): async def scale_down(self, workers): workers = set(workers) - # TODO: this is linear cost. We should be indexing by name or something - to_close = [w for w in self.workers.values() if w.address in workers] for k, v in self.workers.items(): if v.worker_address in workers: del self.worker_spec[k] From 208e0bc2d313809b41e0998b75b0996cb39390b2 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 26 Jul 2019 16:57:09 +0100 Subject: [PATCH 0371/1550] Add ProcessInterface (#2874) Added ProcessInterface class which is an interface for custom schedulers and workers to inherit from for use in SpecCluster. 
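Not part of the patch itself, but for orientation: a minimal sketch of how a custom process class might implement this interface, assuming the ``dask-worker`` CLI is on the PATH and a scheduler address is supplied. The ``LocalProcess`` name is made up for illustration; the real SSH-based subclasses follow in the diff below.

.. code-block:: python

    import asyncio

    from distributed.deploy import ProcessInterface


    class LocalProcess(ProcessInterface):
        """Hypothetical example: run ``dask-worker`` as a local subprocess."""

        def __init__(self, scheduler_address, **kwargs):
            self.scheduler_address = scheduler_address
            self.proc = None
            super().__init__(**kwargs)

        async def start(self):
            # Launch the worker process and then let the base class mark
            # this instance as "running"
            self.proc = await asyncio.create_subprocess_exec(
                "dask-worker", self.scheduler_address
            )
            await super().start()

        async def close(self):
            # Tear the process down, then let the base class mark this
            # instance as "closed"
            self.proc.terminate()
            await self.proc.wait()
            await super().close()

Awaiting an instance (``await LocalProcess(...)``) calls ``start`` exactly once and returns the running object, matching the ``__await__`` behaviour defined by ``ProcessInterface``.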
--- distributed/deploy/__init__.py | 2 +- distributed/deploy/spec.py | 40 ++++++++++++++++++- distributed/deploy/ssh2.py | 28 +++++-------- distributed/deploy/tests/test_spec_cluster.py | 12 +++++- 4 files changed, 62 insertions(+), 20 deletions(-) diff --git a/distributed/deploy/__init__.py b/distributed/deploy/__init__.py index 9b5e478c303..24a86e6d6d2 100644 --- a/distributed/deploy/__init__.py +++ b/distributed/deploy/__init__.py @@ -4,7 +4,7 @@ from .cluster import Cluster from .local import LocalCluster -from .spec import SpecCluster +from .spec import SpecCluster, ProcessInterface from .adaptive import Adaptive with ignoring(ImportError): diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 912ccc79302..a6701807fa5 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -11,6 +11,44 @@ from ..security import Security +class ProcessInterface: + """ An interface for Scheduler and Worker processes for use in SpecCluster + + Parameters + ---------- + loop: + A pointer to the running loop. + + """ + + def __init__(self, loop=None): + self.address = None + self.loop = loop + self.lock = asyncio.Lock() + self.status = "created" + + def __await__(self): + async def _(): + async with self.lock: + if self.status == "created": + await self.start() + assert self.status == "running" + return self + + return _().__await__() + + async def start(self): + """ Start the process. """ + self.status = "running" + + async def close(self): + """ Close the process. """ + self.status = "closed" + + def __repr__(self): + return "<%s: status=%s>" % (type(self).__name__, self.status) + + class SpecCluster(Cluster): """ Cluster that requires a full specification of workers @@ -319,7 +357,7 @@ async def scale_down(self, workers): workers = set(workers) for k, v in self.workers.items(): - if v.worker_address in workers: + if getattr(v, "worker_address", v.address) in workers: del self.worker_spec[k] await self diff --git a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py index 0f8823cdab8..bf471f1dee6 100644 --- a/distributed/deploy/ssh2.py +++ b/distributed/deploy/ssh2.py @@ -1,4 +1,3 @@ -import asyncio import logging import sys import warnings @@ -6,7 +5,7 @@ import asyncssh -from .spec import SpecCluster +from .spec import SpecCluster, ProcessInterface logger = logging.getLogger(__name__) @@ -16,7 +15,7 @@ ) -class Process: +class Process(ProcessInterface): """ A superclass for SSH Workers and Nannies See Also @@ -25,27 +24,20 @@ class Process: Scheduler """ - def __init__(self): - self.lock = asyncio.Lock() + def __init__(self, **kwargs): self.connection = None self.proc = None - self.status = "created" + super().__init__(**kwargs) - def __await__(self): - async def _(): - async with self.lock: - if not self.connection: - await self.start() - assert self.connection - weakref.finalize(self, self.proc.terminate) - return self - - return _().__await__() + async def start(self): + assert self.connection + weakref.finalize(self, self.proc.terminate) + await super().start() async def close(self): self.proc.terminate() self.connection.close() - self.status = "closed" + await super().close() def __repr__(self): return "" % (type(self).__name__, self.status) @@ -97,6 +89,7 @@ async def start(self): self.status = "running" break logger.debug("%s", line) + await super().start() class Scheduler(Process): @@ -135,6 +128,7 @@ async def start(self): self.address = line.split("Scheduler at:")[1].strip() break logger.debug("%s", line) + await super().start() def 
SSHCluster(hosts, connect_kwargs, **kwargs): diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index bb992f8b7c7..bef5a2e554c 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,5 +1,5 @@ from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny -from distributed.deploy.spec import close_clusters +from distributed.deploy.spec import close_clusters, ProcessInterface from distributed.utils_test import loop, cleanup # noqa: F401 import pytest @@ -164,3 +164,13 @@ async def test_nanny_port(): scheduler=scheduler, workers=workers, asynchronous=True ) as cluster: pass + + +@pytest.mark.asyncio +async def test_spec_process(): + proc = ProcessInterface() + assert proc.status == "created" + await proc + assert proc.status == "running" + await proc.close() + assert proc.status == "closed" From 594589e7dd7112c557661a4cb890440cceaacb9c Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 26 Jul 2019 19:28:27 +0100 Subject: [PATCH 0372/1550] Add Log(str) and Logs(dict) classes for nice HTML reprs (#2875) --- distributed/tests/test_utils.py | 12 ++++++++++++ distributed/utils.py | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index df98bbe59e1..9e3a1d90c4b 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -7,6 +7,7 @@ import sys from time import sleep import traceback +import xml.etree.ElementTree import numpy as np import pytest @@ -18,6 +19,8 @@ from distributed.metrics import time from distributed.utils import ( All, + Log, + Logs, sync, is_kernel, ensure_ip, @@ -548,3 +551,12 @@ def test_warn_on_duration(): def test_format_bytes_compat(): # moved to dask, but exported here for compatibility from distributed.utils import format_bytes # noqa + + +def test_logs(): + d = Logs({"123": Log("Hello"), "456": Log("World!")}) + text = d._repr_html_() + for line in text.split("\n"): + assert xml.etree.ElementTree.fromstring(line) is not None + assert "Hello" in text + assert "456" in text diff --git a/distributed/utils.py b/distributed/utils.py index 6e8769979fb..227406da76b 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1490,3 +1490,23 @@ def format_dashboard_link(host, port): def is_coroutine_function(f): return asyncio.iscoroutinefunction(f) or gen.is_coroutine_function(f) + + +class Log(str): + """ A container for logs """ + + def _repr_html_(self): + return "
<pre><code>{log}</code></pre>".format(log=self) + + +class Logs(dict): + """ A container for multiple logs """ + + def _repr_html_(self): + summaries = [ + "<details><summary>{title}</summary>{log}</details>
                  ".format( + title=title, log=log._repr_html_() + ) + for title, log in self.items() + ] + return "\n".join(summaries) From 65001f2d1c796f652bf77bbf222cf900bf1062b8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 26 Jul 2019 12:41:25 -0700 Subject: [PATCH 0373/1550] Pass Client._asynchronous to Cluster._asynchronous (#2890) Previously when starting a client/cluster with `Client()` the underlying cluster would always be started with `asynchronous=True`. This could be troublesome in some cases. Now we pass through the `asynchronous=` value that was originally passed to the Client object. --- distributed/client.py | 4 +++- distributed/deploy/tests/test_local.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index c4b2f51426a..a830b147fb8 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -932,7 +932,9 @@ async def _start(self, timeout=no_default, **kwargs): try: self.cluster = await LocalCluster( - loop=self.loop, asynchronous=True, **self._startup_kwargs + loop=self.loop, + asynchronous=self._asynchronous, + **self._startup_kwargs ) except (OSError, socket.error) as e: if e.errno != errno.EADDRINUSE: diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 1c098c2b4c5..91a792272e8 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -856,3 +856,10 @@ def test_dont_select_closed_worker(): cluster2.close() c2.close() + + +def test_client_cluster_synchronous(loop): + with clean(threads=False): + with Client(loop=loop, processes=False) as c: + assert not c.asynchronous + assert not c.cluster.asynchronous From 3e50887d3d24fda7ec08364deb14090e72b4d484 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 26 Jul 2019 14:23:57 -0700 Subject: [PATCH 0374/1550] Add default logs method to Spec Cluster (#2889) This gathers logs through the scheduler using existing methods, and then returns them as nicely rendered summary/details outputs. --- distributed/deploy/spec.py | 35 ++++++++++++++++++- distributed/deploy/tests/test_spec_cluster.py | 31 ++++++++++++++++ distributed/scheduler.py | 2 +- 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index a6701807fa5..747ecc19a41 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -6,7 +6,7 @@ from .cluster import Cluster from ..core import rpc, CommClosedError -from ..utils import LoopRunner, silence_logging, ignoring +from ..utils import LoopRunner, silence_logging, ignoring, Log, Logs from ..scheduler import Scheduler from ..security import Security @@ -371,6 +371,39 @@ def __repr__(self): len(self.workers), ) + async def _logs(self, scheduler=True, workers=True): + logs = Logs() + + if scheduler: + L = await self.scheduler_comm.logs() + logs["Scheduler"] = Log("\n".join(line for level, line in L)) + + if workers: + d = await self.scheduler_comm.worker_logs(workers=workers) + for k, v in d.items(): + logs[k] = Log("\n".join(line for level, line in v)) + + return logs + + def logs(self, scheduler=True, workers=True): + """ Return logs for the scheduler and workers + + Parameters + ---------- + scheduler : boolean + Whether or not to collect logs for the scheduler + workers : boolean or Iterable[str], optional + A list of worker addresses to select. 
+ Defaults to all workers if `True` or no workers if `False` + + Returns + ------- + logs: Dict[str] + A dictionary of logs, with one item for the scheduler and one for + each worker + """ + return self.sync(self._logs, scheduler=scheduler, workers=workers) + @atexit.register def close_clusters(): diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index bef5a2e554c..51094dca2d5 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,6 +1,7 @@ from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny from distributed.deploy.spec import close_clusters, ProcessInterface from distributed.utils_test import loop, cleanup # noqa: F401 +import toolz import pytest @@ -174,3 +175,33 @@ async def test_spec_process(): assert proc.status == "running" await proc.close() assert proc.status == "closed" + + +@pytest.mark.asyncio +async def test_logs(cleanup): + worker = {"cls": Worker, "options": {"nthreads": 1}} + async with SpecCluster( + asynchronous=True, scheduler=scheduler, worker=worker + ) as cluster: + cluster.scale(2) + await cluster + + logs = await cluster.logs() + assert "Scheduler" in logs + for worker in cluster.scheduler.workers: + assert worker in logs + + assert "Registered" in str(logs) + + logs = await cluster.logs(scheduler=True, workers=False) + assert list(logs) == ["Scheduler"] + + logs = await cluster.logs(scheduler=False, workers=False) + assert list(logs) == [] + + logs = await cluster.logs(scheduler=False, workers=True) + assert set(logs) == set(cluster.scheduler.workers) + + w = toolz.first(cluster.scheduler.workers) + logs = await cluster.logs(scheduler=False, workers=[w]) + assert set(logs) == {w} diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ace4d2483d5..98a66bcc55f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2631,7 +2631,7 @@ async def broadcast( serializers=None, ): """ Broadcast message to workers, return all results """ - if workers is None: + if workers is None or workers is True: if hosts is None: workers = list(self.workers) else: From c291175a975dfb9376724ececba10fcc9e2e43c0 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 26 Jul 2019 17:15:07 -0700 Subject: [PATCH 0375/1550] Add processes keyword back into clean (#2891) This resolves some intermittent testing failures --- distributed/tests/test_failed_workers.py | 19 +++++++++++-------- distributed/tests/test_scheduler.py | 2 +- distributed/utils_test.py | 4 ++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 5465a7dd5f0..1f27e067058 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -266,23 +266,26 @@ def test_fast_kill(c, s, a, b): @gen_cluster(Worker=Nanny, timeout=60) def test_multiple_clients_restart(s, a, b): - e1 = yield Client(s.address, asynchronous=True) - e2 = yield Client(s.address, asynchronous=True) + c1 = yield Client(s.address, asynchronous=True) + c2 = yield Client(s.address, asynchronous=True) - x = e1.submit(inc, 1) - y = e2.submit(inc, 2) + x = c1.submit(inc, 1) + y = c2.submit(inc, 2) xx = yield x yy = yield y assert xx == 2 assert yy == 3 - yield e1._restart() + yield c1.restart() assert x.cancelled() - assert y.cancelled() + start = time() + while not y.cancelled(): + yield gen.sleep(0.01) + assert time() < start + 5 - yield 
e1._close(fast=True) - yield e2._close(fast=True) + yield c1.close() + yield c2.close() @gen_cluster(Worker=Nanny, timeout=60) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4f1b2808102..300a1a5b2b0 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1234,7 +1234,7 @@ def test_cancel_fire_and_forget(c, s, a, b): assert not s.tasks -@gen_cluster(client=True, Worker=Nanny) +@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False}) def test_log_tasks_during_restart(c, s, a, b): future = c.submit(sys.exit, 0) yield wait(future) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 8055aec3f34..52cc54b639d 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1528,14 +1528,14 @@ def check_instances(): @contextmanager -def clean(threads=not WINDOWS, instances=True, timeout=1): +def clean(threads=not WINDOWS, instances=True, timeout=1, processes=True): @contextmanager def null(): yield with check_thread_leak() if threads else null(): with pristine_loop() as loop: - with check_process_leak(): + with check_process_leak() if processes else null(): with check_instances() if instances else null(): with check_active_rpc(loop, timeout): reset_config() From 741ffb60b94b15d2f243fc4ad4a849df76c46092 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 09:06:45 -0700 Subject: [PATCH 0376/1550] Update black (#2901) This helps to resolve `**kwargs,` issues in Python 3.5 --- .pre-commit-config.yaml | 2 +- distributed/comm/tests/test_comms.py | 4 ++-- distributed/protocol/tests/test_numpy.py | 8 ++++---- distributed/queues.py | 2 +- distributed/scheduler.py | 2 +- distributed/tests/test_scheduler.py | 5 ++--- distributed/tests/test_security.py | 2 +- distributed/versions.py | 8 ++++---- distributed/worker.py | 4 ++-- 9 files changed, 18 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5939ad63655..6be2fcaa3bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/ambv/black - rev: stable + rev: cad4138050b86d1c8570b926883e32f7465c2880 hooks: - id: black language_version: python3.7 diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 5d52b04a137..f2bf7778221 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -969,7 +969,7 @@ def check_out_false(out_value): assert deserialize(ser.header, ser.frames) == 456 assert isinstance(to_ser, list) - to_ser, = to_ser + (to_ser,) = to_ser # The to_serialize() value could have been actually serialized # or not (it's a transport-specific optimization) if isinstance(to_ser, Serialized): @@ -1021,7 +1021,7 @@ def check_out(deserialize_flag, out_value): assert isinstance(ser, Serialized) assert deserialize(ser.header, ser.frames) == _uncompressible assert isinstance(to_ser, list) - to_ser, = to_ser + (to_ser,) = to_ser # The to_serialize() value could have been actually serialized # or not (it's a transport-specific optimization) if isinstance(to_ser, Serialized): diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index ede0eded3cf..eb39b57c351 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -107,7 +107,7 @@ def test_dumps_serialize_numpy(x): ], ) def test_serialize_numpy_ma_masked_array(x): - y, = 
loads(dumps([to_serialize(x)])) + (y,) = loads(dumps([to_serialize(x)])) assert x.data.dtype == y.data.dtype np.testing.assert_equal(x.data, y.data) np.testing.assert_equal(x.mask, y.mask) @@ -115,7 +115,7 @@ def test_serialize_numpy_ma_masked_array(x): def test_serialize_numpy_ma_masked(): - y, = loads(dumps([to_serialize(np.ma.masked)])) + (y,) = loads(dumps([to_serialize(np.ma.masked)])) assert y is np.ma.masked @@ -126,8 +126,8 @@ def test_dumps_serialize_numpy_custom_dtype(): rational = test_rational.rational try: builtins.rational = ( - rational - ) # Work around https://github.com/numpy/numpy/issues/9160 + rational # Work around https://github.com/numpy/numpy/issues/9160 + ) x = np.array([1], dtype=rational) header, frames = serialize(x) y = deserialize(header, frames) diff --git a/distributed/queues.py b/distributed/queues.py index 12bd15b6318..b97c317ac58 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -79,7 +79,7 @@ async def put( else: record = {"type": "msgpack", "value": data} if timeout is not None: - timeout = datetime.timedelta(seconds=(timeout)) + timeout = datetime.timedelta(seconds=timeout) await self.queues[name].put(record, timeout=timeout) def future_release(self, name=None, key=None, client=None): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 98a66bcc55f..999d4802730 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1643,7 +1643,7 @@ def update_graph( for key in set(priority) & touched_keys: ts = self.tasks[key] if ts.priority is None: - ts.priority = (-user_priority.get(key, 0), generation, priority[key]) + ts.priority = (-(user_priority.get(key, 0)), generation, priority[key]) # Ensure all runnables have a priority runnables = [ts for ts in touched_tasks if ts.run_spec] diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 300a1a5b2b0..6401cdd4b94 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -433,11 +433,10 @@ def test_filtered_communication(s, a, b): "keys": ["z"], } ) - - msg, = yield c.read() + (msg,) = yield c.read() assert msg["op"] == "key-in-memory" assert msg["key"] == "y" - msg, = yield f.read() + (msg,) = yield f.read() assert msg["op"] == "key-in-memory" assert msg["key"] == "z" diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 7f144625b04..28438c6f359 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -306,7 +306,7 @@ def handle_comm(comm): listener.contact_address, connection_args=forced_cipher_sec.get_connection_args("worker"), ) - cipher, _, _, = comm.extra_info["cipher"] + cipher, _, _ = comm.extra_info["cipher"] assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS comm.abort() diff --git a/distributed/versions.py b/distributed/versions.py index 2baa47a1d8f..d6a44096796 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -53,10 +53,10 @@ def get_system_info(): host = [ ("python", "%d.%d.%d.%s.%s" % sys.version_info[:]), ("python-bits", struct.calcsize("P") * 8), - ("OS", "%s" % (sysname)), - ("OS-release", "%s" % (release)), - ("machine", "%s" % (machine)), - ("processor", "%s" % (processor)), + ("OS", "%s" % sysname), + ("OS-release", "%s" % release), + ("machine", "%s" % machine), + ("processor", "%s" % processor), ("byteorder", "%s" % sys.byteorder), ("LC_ALL", "%s" % os.environ.get("LC_ALL", "None")), ("LANG", "%s" % os.environ.get("LANG", "None")), diff --git a/distributed/worker.py 
b/distributed/worker.py index e65c1cbccdc..2b762a4e751 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -396,9 +396,9 @@ def __init__( ("flight", "memory"): self.transition_dep_flight_memory, } - self.incoming_transfer_log = deque(maxlen=(100000)) + self.incoming_transfer_log = deque(maxlen=100000) self.incoming_count = 0 - self.outgoing_transfer_log = deque(maxlen=(100000)) + self.outgoing_transfer_log = deque(maxlen=100000) self.outgoing_count = 0 self.outgoing_current_count = 0 self.repetitively_busy = 0 From 5437975b9c8f5bccb8ee0d39de80526412f83903 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 11:10:32 -0700 Subject: [PATCH 0377/1550] Move Worker.local_dir attribute to Worker.local_directory (#2900) This matches the term elsewhere, including the Scheduler and the dask-worker CLI --- distributed/cli/dask_worker.py | 2 +- distributed/nanny.py | 23 +++++++++++++--- distributed/tests/test_worker.py | 46 ++++++++++++++++---------------- distributed/worker.py | 43 ++++++++++++++++++----------- 4 files changed, 70 insertions(+), 44 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 073c7c9c922..08a1f47d1eb 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -357,7 +357,7 @@ def del_pid_file(): resources=resources, memory_limit=memory_limit, reconnect=reconnect, - local_dir=local_directory, + local_directory=local_directory, death_timeout=death_timeout, preload=preload, preload_argv=preload_argv, diff --git a/distributed/nanny.py b/distributed/nanny.py index 6c859115242..228e37c2839 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -66,7 +66,8 @@ def __init__( nthreads=None, ncores=None, loop=None, - local_dir="dask-worker-space", + local_dir=None, + local_directory="dask-worker-space", services=None, name=None, memory_limit="auto", @@ -135,7 +136,11 @@ def __init__( "distributed.worker.memory.terminate" ) - self.local_dir = local_dir + if local_dir is not None: + warnings.warn("The local_dir keyword has moved to local_directory") + local_directory = local_dir + + self.local_directory = local_directory self.services = services self.name = name @@ -221,6 +226,12 @@ def worker_address(self): def worker_dir(self): return None if self.process is None else self.process.worker_dir + @property + def local_dir(self): + """ For API compatibility with Nanny """ + warnings.warn("The local_dir attribute has moved to local_directory") + return self.local_directory + async def start(self): """ Start nanny, start local process, start watching """ self.listen(self._start_address, listen_args=self.listen_args) @@ -268,7 +279,7 @@ async def instantiate(self, comm=None): worker_kwargs = dict( scheduler_ip=self.scheduler_addr, nthreads=self.nthreads, - local_dir=self.local_dir, + local_directory=self.local_directory, services=self.services, nanny=self.address, name=self.name, @@ -667,7 +678,11 @@ async def run(): else: assert worker.address init_result_q.put( - {"address": worker.address, "dir": worker.local_dir, "uid": uid} + { + "address": worker.address, + "dir": worker.local_directory, + "uid": uid, + } ) init_result_q.close() await worker.wait_until_closed() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 1c1e70fcba7..96c673bf69a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -65,7 +65,7 @@ def test_worker_nthreads(): try: assert w.executor._max_workers == multiprocessing.cpu_count() finally: - 
shutil.rmtree(w.local_dir) + shutil.rmtree(w.local_directory) @gen_cluster() @@ -179,9 +179,9 @@ def dont_test_delete_data_with_missing_worker(c, a, b): @gen_cluster(client=True) def test_upload_file(c, s, a, b): - assert not os.path.exists(os.path.join(a.local_dir, "foobar.py")) - assert not os.path.exists(os.path.join(b.local_dir, "foobar.py")) - assert a.local_dir != b.local_dir + assert not os.path.exists(os.path.join(a.local_directory, "foobar.py")) + assert not os.path.exists(os.path.join(b.local_directory, "foobar.py")) + assert a.local_directory != b.local_directory with rpc(a.address) as aa, rpc(b.address) as bb: yield [ @@ -189,8 +189,8 @@ def test_upload_file(c, s, a, b): bb.upload_file(filename="foobar.py", data="x = 123"), ] - assert os.path.exists(os.path.join(a.local_dir, "foobar.py")) - assert os.path.exists(os.path.join(b.local_dir, "foobar.py")) + assert os.path.exists(os.path.join(a.local_directory, "foobar.py")) + assert os.path.exists(os.path.join(b.local_directory, "foobar.py")) def g(): import foobar @@ -203,7 +203,7 @@ def g(): yield c.close() yield s.close(close_workers=True) - assert not os.path.exists(os.path.join(a.local_dir, "foobar.py")) + assert not os.path.exists(os.path.join(a.local_directory, "foobar.py")) @pytest.mark.skip(reason="don't yet support uploading pyc files") @@ -239,14 +239,14 @@ def g(): def test_upload_egg(c, s, a, b): eggname = "testegg-1.0.0-py3.4.egg" local_file = __file__.replace("test_worker.py", eggname) - assert not os.path.exists(os.path.join(a.local_dir, eggname)) - assert not os.path.exists(os.path.join(b.local_dir, eggname)) - assert a.local_dir != b.local_dir + assert not os.path.exists(os.path.join(a.local_directory, eggname)) + assert not os.path.exists(os.path.join(b.local_directory, eggname)) + assert a.local_directory != b.local_directory yield c.upload_file(filename=local_file) - assert os.path.exists(os.path.join(a.local_dir, eggname)) - assert os.path.exists(os.path.join(b.local_dir, eggname)) + assert os.path.exists(os.path.join(a.local_directory, eggname)) + assert os.path.exists(os.path.join(b.local_directory, eggname)) def g(x): import testegg @@ -261,21 +261,21 @@ def g(x): yield s.close() yield a.close() yield b.close() - assert not os.path.exists(os.path.join(a.local_dir, eggname)) + assert not os.path.exists(os.path.join(a.local_directory, eggname)) @gen_cluster(client=True) def test_upload_pyz(c, s, a, b): pyzname = "mytest.pyz" local_file = __file__.replace("test_worker.py", pyzname) - assert not os.path.exists(os.path.join(a.local_dir, pyzname)) - assert not os.path.exists(os.path.join(b.local_dir, pyzname)) - assert a.local_dir != b.local_dir + assert not os.path.exists(os.path.join(a.local_directory, pyzname)) + assert not os.path.exists(os.path.join(b.local_directory, pyzname)) + assert a.local_directory != b.local_directory yield c.upload_file(filename=local_file) - assert os.path.exists(os.path.join(a.local_dir, pyzname)) - assert os.path.exists(os.path.join(b.local_dir, pyzname)) + assert os.path.exists(os.path.join(a.local_directory, pyzname)) + assert os.path.exists(os.path.join(b.local_directory, pyzname)) def g(x): from mytest import mytest @@ -290,7 +290,7 @@ def g(x): yield s.close() yield a.close() yield b.close() - assert not os.path.exists(os.path.join(a.local_dir, pyzname)) + assert not os.path.exists(os.path.join(a.local_directory, pyzname)) @pytest.mark.xfail(reason="Still lose time to network I/O") @@ -805,7 +805,7 @@ def test_heartbeats(c, s, a, b): def test_worker_dir(worker): with 
tmpfile() as fn: - @gen_cluster(client=True, worker_kwargs={"local_dir": fn}) + @gen_cluster(client=True, worker_kwargs={"local_directory": fn}) def test_worker_dir(c, s, a, b): directories = [w.local_directory for w in s.workers.values()] assert all(d.startswith(fn) for d in directories) @@ -1414,12 +1414,12 @@ def __init__(self, x, y): @gen_cluster(nthreads=[]) -def test_local_dir(s): +def test_local_directory(s): with tmpfile() as fn: with dask.config.set(temporary_directory=fn): w = yield Worker(s.address) - assert w.local_dir.startswith(fn) - assert "dask-worker-space" in w.local_dir + assert w.local_directory.startswith(fn) + assert "dask-worker-space" in w.local_directory @pytest.mark.skipif( diff --git a/distributed/worker.py b/distributed/worker.py index 2b762a4e751..08927eb741d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -118,7 +118,7 @@ class Worker(ServerNode): Number of nthreads used by this worker process * **executor:** ``concurrent.futures.ThreadPoolExecutor``: Executor used to perform computation - * **local_dir:** ``path``: + * **local_directory:** ``path``: Path on local machine to store temporary files * **scheduler:** ``rpc``: Location of scheduler. See ``.ip/.port`` attributes. @@ -233,7 +233,7 @@ class Worker(ServerNode): The object to use for storage, builds a disk-backed LRU dict by default nthreads: int, optional loop: tornado.ioloop.IOLoop - local_dir: str, optional + local_directory: str, optional Directory where we place local resources name: str, optional memory_limit: int, float, string @@ -282,6 +282,7 @@ def __init__( nthreads=None, loop=None, local_dir=None, + local_directory=None, services=None, service_ports=None, service_kwargs=None, @@ -457,11 +458,15 @@ def __init__( if silence_logs: silence_logging(level=silence_logs) - if local_dir is None: - local_dir = dask.config.get("temporary-directory") or os.getcwd() - if not os.path.exists(local_dir): - os.mkdir(local_dir) - local_dir = os.path.join(local_dir, "dask-worker-space") + if local_dir is not None: + warnings.warn("The local_dir keyword has moved to local_directory") + local_directory = local_dir + + if local_directory is None: + local_directory = dask.config.get("temporary-directory") or os.getcwd() + if not os.path.exists(local_directory): + os.mkdir(local_directory) + local_directory = os.path.join(local_directory, "dask-worker-space") with warn_on_duration( "1s", @@ -470,9 +475,9 @@ def __init__( "Consider specifying a local-directory to point workers to write " "scratch data to a local disk.", ): - self._workspace = WorkSpace(os.path.abspath(local_dir)) + self._workspace = WorkSpace(os.path.abspath(local_directory)) self._workdir = self._workspace.new_work_dir(prefix="worker-") - self.local_dir = self._workdir.dir_path + self.local_directory = self._workdir.dir_path self.security = security or Security() assert isinstance(self.security, Security) @@ -515,7 +520,7 @@ def __init__( from zict import Buffer, File, Func except ImportError: raise ImportError("Please `pip install zict` for spill-to-disk workers") - path = os.path.join(self.local_dir, "storage") + path = os.path.join(self.local_directory, "storage") storage = Func( partial(serialize_bytelist, on_error="raise"), deserialize_bytes, @@ -545,8 +550,8 @@ def __init__( self.heartbeat_active = False self._ipython_kernel = None - if self.local_dir not in sys.path: - sys.path.insert(0, self.local_dir) + if self.local_directory not in sys.path: + sys.path.insert(0, self.local_directory) self.services = {} 
self.service_specs = services or {} @@ -678,6 +683,12 @@ def worker_address(self): """ For API compatibility with Nanny """ return self.address + @property + def local_dir(self): + """ For API compatibility with Nanny """ + warnings.warn("The local_dir attribute has moved to local_directory") + return self.local_directory + def get_metrics(self): core = dict( executing=len(self.executing), @@ -743,7 +754,7 @@ async def _register_with_scheduler(self): now=time(), resources=self.total_resources, memory_limit=self.memory_limit, - local_directory=self.local_dir, + local_directory=self.local_directory, services=self.service_ports, nanny=self.nanny, pid=os.getpid(), @@ -842,7 +853,7 @@ def start_ipython(self, comm): return self._ipython_kernel.get_connection_info() async def upload_file(self, comm, filename=None, data=None, load=True): - out_filename = os.path.join(self.local_dir, filename) + out_filename = os.path.join(self.local_directory, filename) def func(data): if isinstance(data, unicode): @@ -909,7 +920,7 @@ async def start(self): preload_modules( self.preload, parameter=self, - file_dir=self.local_dir, + file_dir=self.local_directory, argv=self.preload_argv, ) # Services listen on all addresses @@ -931,7 +942,7 @@ async def start(self): logger.info(" Threads: %26d", self.nthreads) if self.memory_limit: logger.info(" Memory: %26s", format_bytes(self.memory_limit)) - logger.info(" Local Directory: %26s", self.local_dir) + logger.info(" Local Directory: %26s", self.local_directory) setproctitle("dask-worker [%s]" % self.address) From 7810b731664ec179dc082657b6be725505d50225 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 11:11:42 -0700 Subject: [PATCH 0378/1550] Link from TapTools to worker info pages in dashboard (#2894) --- distributed/dashboard/scheduler.py | 26 +++++-------------- .../dashboard/tests/test_scheduler_bokeh.py | 2 +- 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 013edb39ace..e41862335cd 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -132,8 +132,7 @@ def __init__(self, scheduler, **kwargs): "y": [1, 2], "ms": [1, 2], "color": ["red", "blue"], - "dashboard_port": ["", ""], - "dashboard_host": ["", ""], + "escaped_worker": ["a", "b"], } ) @@ -155,9 +154,7 @@ def __init__(self, scheduler, **kwargs): # fig.xaxis[0].formatter = NumeralTickFormatter(format='0.0s') fig.x_range.start = 0 - tap = TapTool( - callback=OpenURL(url="./proxy/@dashboard_port/@dashboard_host/status") - ) + tap = TapTool(callback=OpenURL(url="./info/worker/@escaped_worker.html")) hover = HoverTool() hover.tooltips = "@worker : @occupancy s." 
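A minimal sketch (with illustrative addresses) of what the new "escaped_worker" column holds: worker addresses are URL-escaped so that the OpenURL template "./info/worker/@escaped_worker.html" resolves to a valid per-worker info page.

    from tornado import escape

    addresses = ["tcp://127.0.0.1:36325", "tcp://127.0.0.1:40991"]  # illustrative
    # url_escape makes ":" and "/" safe to substitute into the dashboard URL template
    escaped = [escape.url_escape(addr) for addr in addresses]
    # e.g. "tcp%3A%2F%2F127.0.0.1%3A36325" -> "./info/worker/tcp%3A%2F%2F127.0.0.1%3A36325.html"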
@@ -171,9 +168,6 @@ def update(self): with log_errors(): workers = list(self.scheduler.workers.values()) - dashboard_host = [ws.host for ws in workers] - dashboard_port = [ws.services.get("dashboard", "") for ws in workers] - y = list(range(len(workers))) occupancy = [ws.occupancy for ws in workers] ms = [occ * 1000 for occ in occupancy] @@ -202,8 +196,7 @@ def update(self): "worker": [ws.address for ws in workers], "ms": ms, "color": color, - "dashboard_host": dashboard_host, - "dashboard_port": dashboard_port, + "escaped_worker": [escape.url_escape(ws.address) for ws in workers], "x": x, "y": y, } @@ -321,8 +314,7 @@ def __init__(self, scheduler, width=600, **kwargs): "worker": ["a", "b"], "y": [1, 2], "nbytes-color": ["blue", "blue"], - "dashboard_port": ["", ""], - "dashboard_host": ["", ""], + "escaped_worker": ["a", "b"], } ) @@ -374,9 +366,7 @@ def __init__(self, scheduler, width=600, **kwargs): fig.ygrid.visible = False tap = TapTool( - callback=OpenURL( - url="./proxy/@dashboard_port/@dashboard_host/status" - ) + callback=OpenURL(url="./info/worker/@escaped_worker.html") ) fig.add_tools(tap) @@ -404,9 +394,6 @@ def update(self): with log_errors(): workers = list(self.scheduler.workers.values()) - dashboard_host = [ws.host for ws in workers] - dashboard_port = [ws.services.get("dashboard", "") for ws in workers] - y = list(range(len(workers))) nprocessing = [len(ws.processing) for ws in workers] processing_color = [] @@ -449,9 +436,8 @@ def update(self): "nbytes-half": [nb / 2 for nb in nbytes], "nbytes-color": nbytes_color, "nbytes_text": nbytes_text, - "dashboard_host": dashboard_host, - "dashboard_port": dashboard_port, "worker": [ws.address for ws in workers], + "escaped_worker": [escape.url_escape(ws.address) for ws in workers], "y": y, } diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 3c7f85dc89a..8544e72d9f4 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -92,7 +92,7 @@ def test_basic(c, s, a, b): data = ss.source.data assert len(first(data.values())) if component is Occupancy: - assert all(addr == "127.0.0.1" for addr in data["dashboard_host"]) + assert all("127.0.0.1" in addr for addr in data["escaped_worker"]) @gen_cluster(client=True) From d1263246298a1978a3dd19f009b806fe24baaf77 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 11:11:56 -0700 Subject: [PATCH 0379/1550] Avoid exception in Client._ensure_connected if closed (#2893) --- distributed/client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index a830b147fb8..c84a8160ad6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1026,6 +1026,11 @@ async def _ensure_connected(self, timeout=None): await comm.write( {"op": "register-client", "client": self.id, "reply": False} ) + except Exception as e: + if self.status == "closed": + return + else: + raise finally: self._connecting_to_scheduler = False if timeout is not None: From e5f6d48db6c70080833293050e850afb7079fb73 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 11:13:01 -0700 Subject: [PATCH 0380/1550] Convert Pythonic kwargs to CLI Keywords for SSHCluster (#2898) --- distributed/deploy/ssh2.py | 79 +++++++++++++++++++++------ distributed/deploy/tests/test_ssh2.py | 24 +++++++- distributed/utils.py | 43 ++++++++++++++- 3 files changed, 126 insertions(+), 20 deletions(-) diff --git 
a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py index bf471f1dee6..5f9bb4f9a64 100644 --- a/distributed/deploy/ssh2.py +++ b/distributed/deploy/ssh2.py @@ -6,6 +6,9 @@ import asyncssh from .spec import SpecCluster, ProcessInterface +from ..utils import cli_keywords +from ..scheduler import Scheduler as _Scheduler +from ..worker import Worker as _Worker logger = logging.getLogger(__name__) @@ -54,15 +57,25 @@ class Worker(Process): The hostname where we should run this worker connect_kwargs: dict kwargs to be passed to asyncssh connections - kwargs: - TODO + kwargs: dict + These will be passed through the dask-worker CLI to the + dask.distributed.Worker class """ - def __init__(self, scheduler: str, address: str, connect_kwargs: dict, **kwargs): + def __init__( + self, + scheduler: str, + address: str, + connect_kwargs: dict, + kwargs: dict, + loop=None, + name=None, + ): self.address = address self.scheduler = scheduler self.connect_kwargs = connect_kwargs self.kwargs = kwargs + self.name = name super().__init__() @@ -75,15 +88,19 @@ async def start(self): "-m", "distributed.cli.dask_worker", self.scheduler, - "--name", # we need to have name for SpecCluster - str(self.kwargs["name"]), + "--name", + str(self.name), ] + + cli_keywords(self.kwargs, cls=_Worker) ) ) # We watch stderr in order to get the address, then we return while True: line = await self.proc.stderr.readline() + if not line.strip(): + raise Exception("Worker failed to start") + logger.info(line.strip()) if "worker at" in line: self.address = line.split("worker at:")[1].strip() self.status = "running" @@ -101,11 +118,12 @@ class Scheduler(Process): The hostname where we should run this worker connect_kwargs: dict kwargs to be passed to asyncssh connections - kwargs: - TODO + kwargs: dict + These will be passed through the dask-scheduler CLI to the + dask.distributed.Scheduler class """ - def __init__(self, address: str, connect_kwargs: dict, **kwargs): + def __init__(self, address: str, connect_kwargs: dict, kwargs: dict, loop=None): self.address = address self.kwargs = kwargs self.connect_kwargs = connect_kwargs @@ -118,12 +136,18 @@ async def start(self): self.connection = await asyncssh.connect(self.address, **self.connect_kwargs) self.proc = await self.connection.create_process( - " ".join([sys.executable, "-m", "distributed.cli.dask_scheduler"]) + " ".join( + [sys.executable, "-m", "distributed.cli.dask_scheduler"] + + cli_keywords(self.kwargs, cls=_Scheduler) + ) ) # We watch stderr in order to get the address, then we return while True: line = await self.proc.stderr.readline() + if not line.strip(): + raise Exception("Worker failed to start") + logger.info(line.strip()) if "Scheduler at" in line: self.address = line.split("Scheduler at:")[1].strip() break @@ -131,7 +155,9 @@ async def start(self): await super().start() -def SSHCluster(hosts, connect_kwargs, **kwargs): +def SSHCluster( + hosts, connect_kwargs={}, worker_kwargs={}, scheduler_kwargs={}, **kwargs +): """ Deploy a Dask cluster using SSH Parameters @@ -140,25 +166,44 @@ def SSHCluster(hosts, connect_kwargs, **kwargs): List of hostnames or addresses on which to launch our cluster The first will be used for the scheduler and the rest for workers connect_kwargs: + Keywords to pass through to asyncssh.connect known_hosts: List[str] or None The list of keys which will be used to validate the server host key presented during the SSH handshake. If this is not specified, the keys will be looked up in the file .ssh/known_hosts. 
If this is explicitly set to None, server host key validation will be disabled. - TODO - kwargs: - TODO - ---- - This doesn't handle any keyword arguments yet. It is a proof of concept + scheduler_kwargs: + Keywords to pass on to dask-scheduler + worker_kwargs: + Keywords to pass on to dask-worker + + Examples + -------- + >>> from dask.distributed import Client + >>> from distributed.deploy.ssh2 import SSHCluster # experimental for now + >>> cluster = SSHCluster( + ... ["localhost"] * 4, + ... connect_kwargs={"known_hosts": None}, + ... worker_kwargs={"nthreads": 2}, + ... scheduler_kwargs={"port": 0, "dashboard_address": ":8797"}) + >>> client = Client(cluster) """ scheduler = { "cls": Scheduler, - "options": {"address": hosts[0], "connect_kwargs": connect_kwargs}, + "options": { + "address": hosts[0], + "connect_kwargs": connect_kwargs, + "kwargs": scheduler_kwargs, + }, } workers = { i: { "cls": Worker, - "options": {"address": host, "connect_kwargs": connect_kwargs}, + "options": { + "address": host, + "connect_kwargs": connect_kwargs, + "kwargs": worker_kwargs, + }, } for i, host in enumerate(hosts[1:]) } diff --git a/distributed/deploy/tests/test_ssh2.py b/distributed/deploy/tests/test_ssh2.py index beb1c6ef91e..07415ed47e6 100644 --- a/distributed/deploy/tests/test_ssh2.py +++ b/distributed/deploy/tests/test_ssh2.py @@ -9,9 +9,31 @@ @pytest.mark.asyncio async def test_basic(): async with SSHCluster( - ["127.0.0.1"] * 3, connect_kwargs=dict(known_hosts=None), asynchronous=True + ["127.0.0.1"] * 3, + connect_kwargs=dict(known_hosts=None), + asynchronous=True, + scheduler_kwargs={"port": 0}, ) as cluster: assert len(cluster.workers) == 2 async with Client(cluster, asynchronous=True) as client: result = await client.submit(lambda x: x + 1, 10) assert result == 11 + + +@pytest.mark.asyncio +async def test_keywords(): + async with SSHCluster( + ["127.0.0.1"] * 3, + connect_kwargs=dict(known_hosts=None), + asynchronous=True, + worker_kwargs={"nthreads": 2, "memory_limit": "2 GiB"}, + scheduler_kwargs={"idle_timeout": "5s", "port": 0}, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + assert ( + await client.run_on_scheduler( + lambda dask_scheduler: dask_scheduler.idle_timeout + ) + ) == 5 + d = client.scheduler_info()["workers"] + assert all(v["nthreads"] == 2 for v in d.values()) diff --git a/distributed/utils.py b/distributed/utils.py index 227406da76b..71e4a09d2a1 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -803,7 +803,7 @@ def tokey(o): -------- >>> tokey(b'x') - 'x' + b'x' >>> tokey('x') 'x' >>> tokey(1) @@ -1210,7 +1210,7 @@ def parse_timedelta(s, default="seconds"): >>> parse_timedelta('300ms') 0.3 >>> parse_timedelta(timedelta(seconds=3)) # also supports timedeltas - 3 + 3.0 """ if s is None: return None @@ -1510,3 +1510,42 @@ def _repr_html_(self): for title, log in self.items() ] return "\n".join(summaries) + + +def cli_keywords(d: dict, cls=None): + """ Convert a kwargs dictionary into a list of CLI keywords + + Parameters + ---------- + d: dict + The keywords to convert + cls: callable + The callable that consumes these terms to check them for validity + + Examples + -------- + >>> cli_keywords({"x": 123, "save_file": "foo.txt"}) + ['--x', '123', '--save-file', 'foo.txt'] + + >>> from dask.distributed import Worker + >>> cli_keywords({"x": 123}, Worker) + Traceback (most recent call last): + ... 
+ ValueError: Class distributed.worker.Worker does not support keyword x + """ + if cls: + for k in d: + if not has_keyword(cls, k): + raise ValueError( + "Class %s does not support keyword %s" % (typename(cls), k) + ) + + def convert_value(v): + out = str(v) + if " " in out and "'" not in out and '"' not in out: + out = '"' + out + '"' + return out + + return sum( + [["--" + k.replace("_", "-"), convert_value(v)] for k, v in d.items()], [] + ) From 3757cd497ad48cca737af291493eca5e660adcfe Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 15:19:55 -0700 Subject: [PATCH 0381/1550] Use kwargs in CLI (#2899) --- .travis.yml | 2 +- distributed/cli/dask_scheduler.py | 14 ++------------ distributed/cli/dask_worker.py | 29 +++-------------------------- distributed/comm/ucx.py | 4 +++- distributed/protocol/cupy.py | 2 +- distributed/protocol/numba.py | 2 +- 6 files changed, 11 insertions(+), 42 deletions(-) diff --git a/.travis.yml b/.travis.yml index 35f4383748e..1726cffd4f1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,7 +24,7 @@ install: script: - if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi - if [[ $LINT == true ]]; then pip install flake8 ; flake8 distributed ; fi - - if [[ $LINT == true ]]; then pip install black; black distributed --check; fi + - if [[ $LINT == true ]]; then pip install git+https://github.com/psf/black@cad4138050b86d1c8570b926883e32f7465c2880; black distributed --check; fi after_success: - if [[ $COVERAGE == true ]]; then coverage report; pip install -q coveralls ; coveralls ; fi diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 2e6220b7d81..54ecd69e595 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -136,17 +136,12 @@ def main( dashboard_prefix, use_xheaders, pid_file, - scheduler_file, - interface, - protocol, local_directory, - preload, - preload_argv, tls_ca_file, tls_cert, tls_key, dashboard_address, - idle_timeout, + **kwargs ): g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 gc.set_threshold(g0 * 3, g1 * 3, g2 * 3) @@ -217,17 +212,12 @@ def del_pid_file(): scheduler = Scheduler( loop=loop, - scheduler_file=scheduler_file, security=sec, host=host, port=port, - interface=interface, - protocol=protocol, dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, - idle_timeout=idle_timeout, - preload=preload, - preload_argv=preload_argv, + **kwargs, ) logger.info("Local Directory: %26s", local_directory) logger.info("-" * 47) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 08a1f47d1eb..084d7b59ccc 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -11,7 +11,6 @@ import click import dask from distributed import Nanny, Worker -from distributed.utils import parse_timedelta from distributed.security import Security from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.comm import get_address_host_port @@ -199,25 +198,18 @@ def main( nprocs, nanny, name, - memory_limit, pid_file, - reconnect, resources, dashboard, bokeh, bokeh_port, - local_directory, scheduler_file, - interface, - protocol, - death_timeout, - preload, - preload_argv, dashboard_prefix, tls_ca_file, tls_cert, tls_key, dashboard_address, + **kwargs ): g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 gc.set_threshold(g0 * 3, g1 * 3, g2 * 3) 
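The signature change above relies on click passing every declared option as a keyword argument, so anything main() does not name explicitly lands in **kwargs and can be forwarded verbatim to the Worker/Nanny constructor; a minimal standalone sketch (option names illustrative):

    import click

    @click.command()
    @click.option("--memory-limit", default="auto")
    @click.option("--death-timeout", default=None)
    def main(**kwargs):
        # Options without a matching named parameter are collected here,
        # ready to be passed straight through to a constructor.
        print(kwargs)  # e.g. {'memory_limit': 'auto', 'death_timeout': None}

    if __name__ == "__main__":
        main()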
@@ -314,8 +306,6 @@ def del_pid_file(): atexit.register(del_pid_file) - services = {} - if resources: resources = resources.replace(",", " ").split() resources = dict(pair.split("=") for pair in resources) @@ -326,10 +316,9 @@ def del_pid_file(): loop = IOLoop.current() if nanny: - kwargs = {"worker_port": worker_port, "listen_address": listen_address} + kwargs.update({"worker_port": worker_port, "listen_address": listen_address}) t = Nanny else: - kwargs = {} if nanny_port: kwargs["service_ports"] = {"nanny": nanny_port} t = Worker @@ -344,33 +333,21 @@ def del_pid_file(): "dask-worker SCHEDULER_ADDRESS:8786" ) - if death_timeout is not None: - death_timeout = parse_timedelta(death_timeout, "s") - nannies = [ t( scheduler, scheduler_file=scheduler_file, nthreads=nthreads, - services=services, loop=loop, resources=resources, - memory_limit=memory_limit, - reconnect=reconnect, - local_directory=local_directory, - death_timeout=death_timeout, - preload=preload, - preload_argv=preload_argv, security=sec, contact_address=contact_address, - interface=interface, - protocol=protocol, host=host, port=port, dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, name=name if nprocs == 1 or not name else name + "-" + str(i), - **kwargs + **kwargs, ) for i in range(nprocs) ] diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 3f3f0bfe943..eb1c7514133 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -121,7 +121,9 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): deserializers = ("cuda", "dask", "pickle", "error") resp = await self.ep.recv_future() obj = ucp.get_obj_from_msg(resp) - nframes, = struct.unpack("Q", obj[:8]) # first eight bytes for number of frames + (nframes,) = struct.unpack( + "Q", obj[:8] + ) # first eight bytes for number of frames gpu_frame_msg = obj[ 8 : 8 + nframes diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 13c0348a821..f8d08ee3a1e 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -31,7 +31,7 @@ def serialize_cupy_ndarray(x): @cuda_deserialize.register(cupy.ndarray) def deserialize_cupy_array(header, frames): - frame, = frames + (frame,) = frames # TODO: put this in ucx... as a kind of "fixup" try: frame.typestr = header["typestr"] diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py index 18405ffebe0..aa56a682b95 100644 --- a/distributed/protocol/numba.py +++ b/distributed/protocol/numba.py @@ -36,7 +36,7 @@ def serialize_numba_ndarray(x): @cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray) def deserialize_numba_ndarray(header, frames): - frame, = frames + (frame,) = frames # TODO: put this in ucx... 
as a kind of "fixup" if isinstance(frame, bytes): import numpy as np From ec51220d3ff7b3607bd0899e492e253705a6e94f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 15:20:25 -0700 Subject: [PATCH 0382/1550] Name SSHClusters by providing name= keyword to SpecCluster (#2903) --- distributed/deploy/spec.py | 6 +++++- distributed/deploy/ssh2.py | 2 +- distributed/deploy/tests/test_ssh2.py | 2 ++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 747ecc19a41..cfafa5e94c2 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -76,6 +76,8 @@ class does handle all of the logic around asynchronously cleanly setting up async/await silence_logs: bool Whether or not we should silence logging when setting up the cluster. + name: str, optional + A name to use when printing out the cluster, defaults to type name Examples -------- @@ -149,6 +151,7 @@ def __init__( loop=None, security=None, silence_logs=False, + name=None, ): self._created = weakref.WeakSet() @@ -170,6 +173,7 @@ def __init__( self.status = "created" self._instances.add(self) self._correct_state_waiting = None + self._name = name or type(self).__name__ if not self.asynchronous: self._loop_runner.start() @@ -366,7 +370,7 @@ async def scale_down(self, workers): def __repr__(self): return "%s(%r, workers=%d)" % ( - type(self).__name__, + self._name, self.scheduler_address, len(self.workers), ) diff --git a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py index 5f9bb4f9a64..3e87b6d301b 100644 --- a/distributed/deploy/ssh2.py +++ b/distributed/deploy/ssh2.py @@ -207,4 +207,4 @@ def SSHCluster( } for i, host in enumerate(hosts[1:]) } - return SpecCluster(workers, scheduler, **kwargs) + return SpecCluster(workers, scheduler, name="SSHCluster", **kwargs) diff --git a/distributed/deploy/tests/test_ssh2.py b/distributed/deploy/tests/test_ssh2.py index 07415ed47e6..df90d35cd6e 100644 --- a/distributed/deploy/tests/test_ssh2.py +++ b/distributed/deploy/tests/test_ssh2.py @@ -19,6 +19,8 @@ async def test_basic(): result = await client.submit(lambda x: x + 1, 10) assert result == 11 + assert "SSH" in repr(cluster) + @pytest.mark.asyncio async def test_keywords(): From 50e486b67d01084896b92fb90b9dcf8d5be8eb21 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 28 Jul 2019 18:23:56 -0700 Subject: [PATCH 0383/1550] Request feed of worker information from Scheduler to SpecCluster (#2902) * Don't explicitly provide loop= in SpecCluster This is called from within the event loop, so IOLoop.current should be fine * Ask scheduler for updates on adding and removing workers * implement dashboard_link * Add widgets to SpecCluster * Don't include scaling buttons in SSHCluster --- distributed/deploy/spec.py | 185 +++++++++++++++++- distributed/deploy/ssh2.py | 2 +- distributed/deploy/tests/test_spec_cluster.py | 59 +++++- distributed/deploy/tests/test_ssh2.py | 1 + distributed/scheduler.py | 46 ++++- 5 files changed, 281 insertions(+), 12 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index cfafa5e94c2..87228d5693e 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -1,12 +1,24 @@ import asyncio import atexit +import copy import weakref from tornado import gen +from dask.utils import format_bytes from .cluster import Cluster +from ..comm import connect from ..core import rpc, CommClosedError -from ..utils import LoopRunner, silence_logging, ignoring, Log, Logs +from ..utils import ( + 
log_errors, + LoopRunner, + silence_logging, + ignoring, + Log, + Logs, + PeriodicCallback, + format_dashboard_link, +) from ..scheduler import Scheduler from ..security import Security @@ -21,9 +33,8 @@ class ProcessInterface: """ - def __init__(self, loop=None): + def __init__(self): self.address = None - self.loop = loop self.lock = asyncio.Lock() self.status = "created" @@ -155,14 +166,16 @@ def __init__( ): self._created = weakref.WeakSet() - self.scheduler_spec = scheduler - self.worker_spec = workers or {} - self.new_spec = worker + self.scheduler_spec = copy.copy(scheduler) + self.worker_spec = copy.copy(workers) or {} + self.new_spec = copy.copy(worker) self.workers = {} self._i = 0 self._asynchronous = asynchronous self.security = security or Security() self.scheduler_comm = None + self.scheduler_info = {} + self.periodic_callbacks = {} if silence_logs: self._old_logging_level = silence_logging(level=silence_logs) @@ -189,6 +202,8 @@ async def _start(self): if self.status == "closed": raise ValueError("Cluster is closed") + self._lock = asyncio.Lock() + if self.scheduler_spec is None: try: from distributed.dashboard import BokehScheduler @@ -198,18 +213,47 @@ async def _start(self): services = {("dashboard", 8787): BokehScheduler} self.scheduler_spec = {"cls": Scheduler, "options": {"services": services}} self.scheduler = self.scheduler_spec["cls"]( - loop=self.loop, **self.scheduler_spec.get("options", {}) + **self.scheduler_spec.get("options", {}) ) - self._lock = asyncio.Lock() self.status = "starting" self.scheduler = await self.scheduler self.scheduler_comm = rpc( self.scheduler.address, connection_args=self.security.get_connection_args("client"), ) + comm = await connect( + self.scheduler_address, + connection_args=self.security.get_connection_args("client"), + ) + await comm.write({"op": "subscribe_worker_status"}) + self.scheduler_info = await comm.read() + self._watch_worker_status_comm = comm + self._watch_worker_status_task = asyncio.ensure_future( + self._watch_worker_status(comm) + ) self.status = "running" + async def _watch_worker_status(self, comm): + """ Listen to scheduler for updates on adding and removing workers """ + while True: + try: + msgs = await comm.read() + except OSError: + break + + for op, msg in msgs: + if op == "add": + workers = msg.pop("workers") + self.scheduler_info["workers"].update(workers) + self.scheduler_info.update(msg) + elif op == "remove": + del self.scheduler_info["workers"][msg] + else: + raise ValueError("Invalid op", op, msg) + + await comm.close() + def _correct_state(self): if self._correct_state_waiting: # If people call this frequently, we only want to run it once @@ -293,12 +337,17 @@ async def _close(self): return self.status = "closing" + for pc in self.periodic_callbacks.values(): + pc.stop() + self.scale(0) await self._correct_state() async with self._lock: with ignoring(CommClosedError): await self.scheduler_comm.close(close_workers=True) await self.scheduler.close() + await self._watch_worker_status_comm.close() + await self._watch_worker_status_task for w in self._created: assert w.status == "closed" self.scheduler_comm.close_rpc() @@ -357,6 +406,10 @@ def new_worker_spec(self): return self._i, self.new_spec + @property + def _supports_scaling(self): + return not not self.new_spec + async def scale_down(self, workers): workers = set(workers) @@ -408,6 +461,122 @@ def logs(self, scheduler=True, workers=True): """ return self.sync(self._logs, scheduler=scheduler, workers=workers) + @property + def dashboard_link(self): 
+ try: + port = self.scheduler_info["services"]["dashboard"] + except KeyError: + return "" + else: + host = self.scheduler_address.split("://")[1].split(":")[0] + return format_dashboard_link(host, port) + + def _widget_status(self): + workers = len(self.scheduler_info["workers"]) + cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values()) + memory = sum(v["memory_limit"] for v in self.scheduler_info["workers"].values()) + memory = format_bytes(memory) + text = """ +
+<div>
+  <table>
+    <tr> <th>Workers</th> <td>%d</td> </tr>
+    <tr> <th>Cores</th> <td>%d</td> </tr>
+    <tr> <th>Memory</th> <td>%s</td> </tr>
+  </table>
+</div>
        + +""" % ( + workers, + cores, + memory, + ) + return text + + def _widget(self): + """ Create IPython widget for display within a notebook """ + try: + return self._cached_widget + except AttributeError: + pass + + from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion + + layout = Layout(width="150px") + + if self.dashboard_link: + link = '

<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>
        \n' % ( + self.dashboard_link, + self.dashboard_link, + ) + else: + link = "" + + title = "

<h2>%s</h2>
        " % type(self).__name__ + title = HTML(title) + dashboard = HTML(link) + + status = HTML(self._widget_status(), layout=Layout(min_width="150px")) + + if self._supports_scaling: + request = IntText(0, description="Workers", layout=layout) + scale = Button(description="Scale", layout=layout) + + minimum = IntText(0, description="Minimum", layout=layout) + maximum = IntText(0, description="Maximum", layout=layout) + adapt = Button(description="Adapt", layout=layout) + + accordion = Accordion( + [HBox([request, scale]), HBox([minimum, maximum, adapt])], + layout=Layout(min_width="500px"), + ) + accordion.selected_index = None + accordion.set_title(0, "Manual Scaling") + accordion.set_title(1, "Adaptive Scaling") + + def adapt_cb(b): + self.adapt(minimum=minimum.value, maximum=maximum.value) + + adapt.on_click(adapt_cb) + + def scale_cb(b): + with log_errors(): + n = request.value + with ignoring(AttributeError): + self._adaptive.stop() + self.scale(n) + + scale.on_click(scale_cb) + else: + accordion = HTML("") + + box = VBox([title, HBox([status, accordion]), dashboard]) + + self._cached_widget = box + + def update(): + status.value = self._widget_status() + + pc = PeriodicCallback(update, 500, io_loop=self.loop) + self.periodic_callbacks["cluster-repr"] = pc + pc.start() + + return box + + def _ipython_display_(self, **kwargs): + return self._widget()._ipython_display_(**kwargs) + @atexit.register def close_clusters(): diff --git a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py index 3e87b6d301b..189a61df5f1 100644 --- a/distributed/deploy/ssh2.py +++ b/distributed/deploy/ssh2.py @@ -123,7 +123,7 @@ class Scheduler(Process): dask.distributed.Scheduler class """ - def __init__(self, address: str, connect_kwargs: dict, kwargs: dict, loop=None): + def __init__(self, address: str, connect_kwargs: dict, kwargs: dict): self.address = address self.kwargs = kwargs self.connect_kwargs = connect_kwargs diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 51094dca2d5..84c868b2585 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,3 +1,6 @@ +import asyncio +from time import time + from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny from distributed.deploy.spec import close_clusters, ProcessInterface from distributed.utils_test import loop, cleanup # noqa: F401 @@ -30,7 +33,7 @@ async def test_specification(cleanup): async with SpecCluster( workers=worker_spec, scheduler=scheduler, asynchronous=True ) as cluster: - assert cluster.worker_spec is worker_spec + assert cluster.worker_spec == worker_spec assert len(cluster.workers) == 3 assert set(cluster.workers) == set(worker_spec) @@ -57,7 +60,7 @@ def test_spec_sync(loop): "my-worker": {"cls": MyWorker, "options": {"nthreads": 3}}, } with SpecCluster(workers=worker_spec, scheduler=scheduler, loop=loop) as cluster: - assert cluster.worker_spec is worker_spec + assert cluster.worker_spec == worker_spec assert len(cluster.workers) == 3 assert set(cluster.workers) == set(worker_spec) @@ -205,3 +208,55 @@ async def test_logs(cleanup): w = toolz.first(cluster.scheduler.workers) logs = await cluster.logs(scheduler=False, workers=[w]) assert set(logs) == {w} + + +@pytest.mark.asyncio +async def test_scheduler_info(cleanup): + async with SpecCluster( + workers=worker_spec, scheduler=scheduler, asynchronous=True + ) as cluster: + assert ( + cluster.scheduler_info["id"] == cluster.scheduler.id + ) # 
present at startup + + start = time() # wait for all workers + while len(cluster.scheduler_info["workers"]) < len(cluster.workers): + await asyncio.sleep(0.01) + assert time() < start + 1 + + assert set(cluster.scheduler.identity()["workers"]) == set( + cluster.scheduler_info["workers"] + ) + assert ( + cluster.scheduler.identity()["services"] + == cluster.scheduler_info["services"] + ) + assert len(cluster.scheduler_info["workers"]) == len(cluster.workers) + + +@pytest.mark.asyncio +async def test_dashboard_link(cleanup): + async with SpecCluster( + workers=worker_spec, + scheduler={ + "cls": Scheduler, + "options": {"port": 0, "dashboard_address": ":12345"}, + }, + asynchronous=True, + ) as cluster: + assert "12345" in cluster.dashboard_link + + +@pytest.mark.asyncio +async def test_widget(cleanup): + async with SpecCluster( + workers=worker_spec, scheduler=scheduler, asynchronous=True + ) as cluster: + + start = time() # wait for all workers + while len(cluster.scheduler_info["workers"]) < len(cluster.worker_spec): + await asyncio.sleep(0.01) + assert time() < start + 1 + + assert "3" in cluster._widget_status() + assert "GB" in cluster._widget_status() diff --git a/distributed/deploy/tests/test_ssh2.py b/distributed/deploy/tests/test_ssh2.py index df90d35cd6e..b744d352b8b 100644 --- a/distributed/deploy/tests/test_ssh2.py +++ b/distributed/deploy/tests/test_ssh2.py @@ -18,6 +18,7 @@ async def test_basic(): async with Client(cluster, asynchronous=True) as client: result = await client.submit(lambda x: x + 1, 10) assert result == 11 + assert not cluster._supports_scaling assert "SSH" in repr(cluster) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 999d4802730..b0db6653d2b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -41,6 +41,7 @@ from .comm.addressing import address_from_user_args from .compatibility import finalize, unicode, Mapping, Set from .core import rpc, connect, send_recv, clean_exception, CommClosedError +from .diagnostics.plugin import SchedulerPlugin from . import profile from .metrics import time from .node import ServerNode @@ -1078,6 +1079,7 @@ def __init__( "register_worker_plugin": self.register_worker_plugin, "adaptive_target": self.adaptive_target, "workers_to_close": self.workers_to_close, + "subscribe_worker_status": self.subscribe_worker_status, } self._transitions = { @@ -1266,7 +1268,7 @@ async def close(self, comm=None, fast=False, close_workers=False): self.periodic_callbacks.clear() self.stop_services() - for ext in self.extensions: + for ext in self.extensions.values(): with ignoring(AttributeError): ext.teardown() logger.info("Scheduler closing all comms") @@ -3232,6 +3234,14 @@ async def feed( if teardown: teardown(self, state) + def subscribe_worker_status(self, comm=None): + WorkerStatusPlugin(self, comm) + ident = self.identity() + for v in ident["workers"].values(): + del v["metrics"] + del v["last_seen"] + return ident + def get_processing(self, comm=None, workers=None): if workers is not None: workers = set(map(self.coerce_address, workers)) @@ -4963,3 +4973,37 @@ def __init__(self, task, last_worker): super(KilledWorker, self).__init__(task, last_worker) self.task = task self.last_worker = last_worker + + +class WorkerStatusPlugin(SchedulerPlugin): + """ + An plugin to share worker status with a remote observer + + This is used in cluster managers to keep updated about the status of the + scheduler. 
+ """ + + def __init__(self, scheduler, comm): + self.bcomm = BatchedSend(interval="5ms") + self.bcomm.start(comm) + + self.scheduler = scheduler + self.scheduler.add_plugin(self) + + def add_worker(self, worker=None, **kwargs): + ident = self.scheduler.workers[worker].identity() + del ident["metrics"] + del ident["last_seen"] + try: + self.bcomm.send(["add", {"workers": {worker: ident}}]) + except CommClosedError: + self.scheduler.remove_plugin(self) + + def remove_worker(self, worker=None, **kwargs): + try: + self.bcomm.send(["remove", worker]) + except CommClosedError: + self.scheduler.remove_plugin(self) + + def teardown(self): + self.bcomm.close() From 157eada32c7c49c3f1b0fe9e020be74592ab27fd Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 29 Jul 2019 07:35:54 -0700 Subject: [PATCH 0384/1550] Clear out compatibillity file (#2896) It has been a suitable time since we've dropped Python 2. This clears out the file of unnecessary functions. --- distributed/actor.py | 5 +- distributed/client.py | 25 +-- distributed/comm/inproc.py | 3 +- distributed/comm/tcp.py | 15 +- distributed/comm/tests/test_comms.py | 7 +- distributed/comm/utils.py | 9 +- distributed/compatibility.py | 251 +---------------------- distributed/core.py | 4 +- distributed/dashboard/utils.py | 4 +- distributed/deploy/cluster.py | 4 +- distributed/diagnostics/progressbar.py | 4 +- distributed/diskutils.py | 6 +- distributed/node.py | 8 +- distributed/process.py | 4 +- distributed/profile.py | 3 +- distributed/protocol/numpy.py | 6 +- distributed/protocol/serialize.py | 3 - distributed/protocol/tests/test_h5py.py | 23 +-- distributed/protocol/tests/test_numpy.py | 3 +- distributed/publish.py | 5 +- distributed/pubsub.py | 8 +- distributed/scheduler.py | 9 +- distributed/submit.py | 10 +- distributed/tests/test_as_completed.py | 13 +- distributed/tests/test_client.py | 11 +- distributed/tests/test_compatibility.py | 38 ---- distributed/tests/test_core.py | 6 +- distributed/tests/test_diskutils.py | 7 +- distributed/tests/test_metrics.py | 35 ++-- distributed/tests/test_profile.py | 6 +- distributed/tests/test_utils.py | 10 +- distributed/tests/test_utils_perf.py | 3 - distributed/tests/test_worker.py | 5 +- distributed/threadpoolexecutor.py | 4 +- distributed/utils.py | 213 ++++++------------- distributed/utils_perf.py | 10 +- distributed/utils_test.py | 15 +- distributed/worker.py | 17 +- 38 files changed, 200 insertions(+), 612 deletions(-) delete mode 100644 distributed/tests/test_compatibility.py diff --git a/distributed/actor.py b/distributed/actor.py index e45f089effd..e7e4afaacf0 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -1,9 +1,10 @@ import asyncio from tornado import gen import functools +import threading +from queue import Queue from .client import Future, default_client -from .compatibility import get_thread_identity, Queue from .protocol import to_serialize from .utils import sync from .utils_comm import WrappedKey @@ -103,7 +104,7 @@ def _asynchronous(self): if self._client: return self._client.asynchronous else: - return get_thread_identity() == self._worker.thread_id + return threading.get_ident() == self._worker.thread_id def _sync(self, func, *args, **kwargs): if self._client: diff --git a/distributed/client.py b/distributed/client.py index c84a8160ad6..e1efda4e0d7 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1,7 +1,7 @@ from __future__ import print_function, division, absolute_import import atexit -from collections import defaultdict +from collections 
import defaultdict, Iterator from concurrent.futures import ThreadPoolExecutor, CancelledError from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import contextmanager @@ -9,6 +9,7 @@ from datetime import timedelta import errno from functools import partial +import html import itertools import json import logging @@ -19,6 +20,7 @@ import threading import six import socket +from queue import Queue as pyQueue import warnings import weakref @@ -26,7 +28,7 @@ from dask.base import tokenize, normalize_token, collections_to_dsk from dask.core import flatten, get_dependencies from dask.optimization import SubgraphCallable -from dask.compatibility import apply, unicode +from dask.compatibility import apply from dask.utils import ensure_dict, format_bytes try: @@ -55,13 +57,6 @@ gather_from_workers, ) from .cfexecutor import ClientExecutor -from .compatibility import ( - Queue as pyQueue, - isqueue, - html_escape, - StopAsyncIteration, - Iterator, -) from .core import connect, rpc, clean_exception, CommClosedError, PooledRPCCall from .metrics import time from .node import Node @@ -400,7 +395,7 @@ def __repr__(self): return "" % (self.status, self.key) def _repr_html_(self): - text = "Future: %s " % html_escape(key_split(self.key)) + text = "Future: %s " % html.escape(key_split(self.key)) text += ( 'status: ' '%(status)s, ' @@ -414,7 +409,7 @@ def _repr_html_(self): except AttributeError: typ = str(self.type) text += 'type: %s, ' % typ - text += 'key: %s' % html_escape(str(self.key)) + text += 'key: %s' % html.escape(str(self.key)) return text def __await__(self): @@ -1523,7 +1518,7 @@ def map( if not callable(func): raise TypeError("First input to map must be a callable function") - if all(map(isqueue, iterables)) or all( + if all(isinstance(it, pyQueue) for it in iterables) or all( isinstance(i, Iterator) for i in iterables ): raise TypeError( @@ -1792,7 +1787,7 @@ def gather(self, futures, errors="raise", direct=None, asynchronous=None): -------- Client.scatter: Send data out to cluster """ - if isqueue(futures): + if isinstance(futures, pyQueue): raise TypeError( "Dask no longer supports gathering over Iterators and Queues. " "Consider using a normal for loop and Client.submit/gather" @@ -1829,7 +1824,7 @@ async def _scatter( if isinstance(workers, six.string_types + (Number,)): workers = [workers] if isinstance(data, dict) and not all( - isinstance(k, (bytes, unicode)) for k in data + isinstance(k, (bytes, str)) for k in data ): d = await self._scatter(keymap(tokey, data), workers, broadcast) raise gen.Return({k: d[tokey(k)] for k in data}) @@ -1998,7 +1993,7 @@ def scatter( """ if timeout == no_default: timeout = self._timeout - if isqueue(data) or isinstance(data, Iterator): + if isinstance(data, pyQueue) or isinstance(data, Iterator): raise TypeError( "Dask no longer supports mapping over Iterators or Queues." 
"Consider using a normal for loop and Client.submit" diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index c9a6dc90281..3a781479bbc 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -11,7 +11,6 @@ from tornado.concurrent import Future from tornado.ioloop import IOLoop -from ..compatibility import finalize from ..protocol import nested_deserialize from ..utils import get_ip @@ -161,7 +160,7 @@ def __init__( self._write_loop = write_loop self._closed = False - self._finalizer = finalize(self, self._get_finalizer()) + self._finalizer = weakref.finalize(self, self._get_finalizer()) self._finalizer.atexit = False self._initialized = True diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 602c9a36253..d23f381857d 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -6,6 +6,7 @@ import struct import sys from tornado import gen +import weakref try: import ssl @@ -19,7 +20,6 @@ from tornado.tcpclient import TCPClient from tornado.tcpserver import TCPServer -from ..compatibility import finalize, PY3 from ..threadpoolexecutor import ThreadPoolExecutor from ..utils import ( ensure_bytes, @@ -158,7 +158,7 @@ def __init__(self, stream, local_addr, peer_addr, deserialize=True): self._peer_addr = peer_addr self.stream = stream self.deserialize = deserialize - self._finalizer = finalize(self, self._get_finalizer()) + self._finalizer = weakref.finalize(self, self._get_finalizer()) self._finalizer.atexit = False self._extra = {} @@ -199,7 +199,7 @@ async def read(self, deserializers=None): frames = [] for length in lengths: if length: - if PY3 and self._iostream_has_read_into: + if self._iostream_has_read_into: frame = bytearray(length) n = await stream.read_into(frame) assert n == length, (n, length) @@ -242,7 +242,7 @@ def write(self, msg, serializers=None, on_error="message"): length_bytes = [struct.pack("Q", len(frames))] + [ struct.pack("Q", x) for x in lengths ] - if PY3 and sum(lengths) < 2 ** 17: # 128kiB + if sum(lengths) < 2 ** 17: # 128kiB b = b"".join(length_bytes + frames) # small enough, send in one go stream.write(b) else: @@ -340,11 +340,8 @@ def _check_encryption(self, address, connection_args): class BaseTCPConnector(Connector, RequireEncryptionMixin): - if PY3: # see github PR #2403 discussion for more info - _executor = ThreadPoolExecutor(2, thread_name_prefix="TCP-Executor") - _resolver = netutil.ExecutorResolver(close_executor=False, executor=_executor) - else: - _resolver = None + _executor = ThreadPoolExecutor(2, thread_name_prefix="TCP-Executor") + _resolver = netutil.ExecutorResolver(close_executor=False, executor=_executor) client = TCPClient(resolver=_resolver) async def connect(self, address, deserialize=True, **connection_args): diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index f2bf7778221..7fac117027b 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -11,7 +11,6 @@ from tornado import gen, ioloop, locks, queues from tornado.concurrent import Future -from distributed.compatibility import PY3 from distributed.metrics import time from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import ( @@ -333,8 +332,7 @@ def sleep_for_60ms(): yield connect("tcp://localhost:28400", 0.052) max_thread_count = yield sleep_future # 2 is the number set by BaseTCPConnector.executor (ThreadPoolExecutor) - if PY3: - assert max_thread_count <= 2 + original_thread_count + assert max_thread_count <= 2 + 
original_thread_count # tcp.TLSConnector() sleep_future = sleep_for_60ms() @@ -345,8 +343,7 @@ def sleep_for_60ms(): connection_args={"ssl_context": get_client_ssl_context()}, ) max_thread_count = yield sleep_future - if PY3: - assert max_thread_count <= 2 + original_thread_count + assert max_thread_count <= 2 + original_thread_count @gen.coroutine diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index bb6621e2021..1e23b25c46b 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -3,11 +3,11 @@ from concurrent.futures import ThreadPoolExecutor import logging import socket +import weakref from tornado import gen from .. import protocol -from ..compatibility import finalize, PY3 from ..utils import get_ip, get_ipv6, nbytes @@ -25,7 +25,7 @@ ) except TypeError: _offload_executor = ThreadPoolExecutor(max_workers=1) -finalize(_offload_executor, _offload_executor.shutdown) +weakref.finalize(_offload_executor, _offload_executor.shutdown) def offload(fn, *args, **kwargs): @@ -50,10 +50,7 @@ def _to_frames(): logger.exception(e) raise - if PY3: - res = yield offload(_to_frames) - else: # distributed/deploy/tests/test_adaptive.py::test_get_scale_up_kwargs fails on Py27. Don't know why - res = _to_frames() + res = yield offload(_to_frames) raise gen.Return(res) diff --git a/distributed/compatibility.py b/distributed/compatibility.py index f3a85973802..fb79353d24b 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -1,256 +1,11 @@ from __future__ import print_function, division, absolute_import import logging +import platform import sys -# flake8: noqa - -if sys.version_info[0] == 2: - from Queue import Queue, Empty - from io import BytesIO - from thread import get_ident as get_thread_identity - from inspect import getargspec - from cgi import escape as html_escape - from collections import Iterator, Mapping, Set, MutableMapping - from fractions import gcd - - reload = reload - unicode = unicode - PY2 = True - PY3 = False - ConnectionRefusedError = OSError - FileExistsError = OSError - - class StopAsyncIteration(Exception): - pass - - import gzip - - def gzip_decompress(b): - f = gzip.GzipFile(fileobj=BytesIO(b)) - result = f.read() - f.close() - return result - - def gzip_compress(b): - bio = BytesIO() - f = gzip.GzipFile(fileobj=bio, mode="w") - f.write(b) - f.close() - bio.seek(0) - result = bio.read() - return result - - def isqueue(o): - return ( - hasattr(o, "queue") and hasattr(o, "__module__") and o.__module__ == "Queue" - ) - - def invalidate_caches(): - pass - - def cache_from_source(path): - import os - - name, ext = os.path.splitext(path) - return name + ".pyc" - - logging_names = logging._levelNames - - def iscoroutinefunction(func): - return False - - -if sys.version_info[0] == 3: - from asyncio import iscoroutinefunction - from collections.abc import Iterator, Mapping, Set, MutableMapping - from queue import Queue, Empty - from importlib import reload - from threading import get_ident as get_thread_identity - from importlib import invalidate_caches - from importlib.util import cache_from_source - from inspect import getfullargspec as getargspec - from html import escape as html_escape - from math import gcd - - PY2 = False - PY3 = True - unicode = str - from gzip import decompress as gzip_decompress - from gzip import compress as gzip_compress - - ConnectionRefusedError = ConnectionRefusedError - FileExistsError = FileExistsError - StopAsyncIteration = StopAsyncIteration - - def isqueue(o): - return isinstance(o, Queue) - 
- logging_names = logging._levelToName.copy() - logging_names.update(logging._nameToLevel) - - -import platform +logging_names = logging._levelToName.copy() +logging_names.update(logging._nameToLevel) PYPY = platform.python_implementation().lower() == "pypy" WINDOWS = sys.platform.startswith("win") - - -try: - from json.decoder import JSONDecodeError -except (ImportError, AttributeError): - JSONDecodeError = ValueError - -try: - from functools import singledispatch -except ImportError: - from singledispatch import singledispatch - -try: - from weakref import finalize -except ImportError: - # Backported from Python 3.6 - import itertools - from weakref import ref - - class finalize(object): - """Class for finalization of weakrefable objects - - finalize(obj, func, *args, **kwargs) returns a callable finalizer - object which will be called when obj is garbage collected. The - first time the finalizer is called it evaluates func(*arg, **kwargs) - and returns the result. After this the finalizer is dead, and - calling it just returns None. - - When the program exits any remaining finalizers for which the - atexit attribute is true will be run in reverse order of creation. - By default atexit is true. - """ - - # Finalizer objects don't have any state of their own. They are - # just used as keys to lookup _Info objects in the registry. This - # ensures that they cannot be part of a ref-cycle. - - __slots__ = () - _registry = {} - _shutdown = False - _index_iter = itertools.count() - _dirty = False - _registered_with_atexit = False - - class _Info: - __slots__ = ("weakref", "func", "args", "kwargs", "atexit", "index") - - def __init__(self, obj, func, *args, **kwargs): - if not self._registered_with_atexit: - # We may register the exit function more than once because - # of a thread race, but that is harmless - import atexit - - atexit.register(self._exitfunc) - finalize._registered_with_atexit = True - info = self._Info() - info.weakref = ref(obj, self) - info.func = func - info.args = args - info.kwargs = kwargs or None - info.atexit = True - info.index = next(self._index_iter) - self._registry[self] = info - finalize._dirty = True - - def __call__(self, _=None): - """If alive then mark as dead and return func(*args, **kwargs); - otherwise return None""" - info = self._registry.pop(self, None) - if info and not self._shutdown: - return info.func(*info.args, **(info.kwargs or {})) - - def detach(self): - """If alive then mark as dead and return (obj, func, args, kwargs); - otherwise return None""" - info = self._registry.get(self) - obj = info and info.weakref() - if obj is not None and self._registry.pop(self, None): - return (obj, info.func, info.args, info.kwargs or {}) - - def peek(self): - """If alive then return (obj, func, args, kwargs); - otherwise return None""" - info = self._registry.get(self) - obj = info and info.weakref() - if obj is not None: - return (obj, info.func, info.args, info.kwargs or {}) - - @property - def alive(self): - """Whether finalizer is alive""" - return self in self._registry - - @property - def atexit(self): - """Whether finalizer should be called at exit""" - info = self._registry.get(self) - return bool(info) and info.atexit - - @atexit.setter - def atexit(self, value): - info = self._registry.get(self) - if info: - info.atexit = bool(value) - - def __repr__(self): - info = self._registry.get(self) - obj = info and info.weakref() - if obj is None: - return "<%s object at %#x; dead>" % (type(self).__name__, id(self)) - else: - return "<%s object at %#x; for 
%r at %#x>" % ( - type(self).__name__, - id(self), - type(obj).__name__, - id(obj), - ) - - @classmethod - def _select_for_exit(cls): - # Return live finalizers marked for exit, oldest first - L = [(f, i) for (f, i) in cls._registry.items() if i.atexit] - L.sort(key=lambda item: item[1].index) - return [f for (f, i) in L] - - @classmethod - def _exitfunc(cls): - # At shutdown invoke finalizers for which atexit is true. - # This is called once all other non-daemonic threads have been - # joined. - reenable_gc = False - try: - if cls._registry: - import gc - - if gc.isenabled(): - reenable_gc = True - gc.disable() - pending = None - while True: - if pending is None or finalize._dirty: - pending = cls._select_for_exit() - finalize._dirty = False - if not pending: - break - f = pending.pop() - try: - # gc is disabled, so (assuming no daemonic - # threads) the following is the only line in - # this function which might trigger creation - # of a new finalizer - f() - except Exception: - sys.excepthook(*sys.exc_info()) - assert f not in cls._registry - finally: - # prevent any more finalizers from executing during shutdown - finalize._shutdown = True - if reenable_gc: - gc.enable() diff --git a/distributed/core.py b/distributed/core.py index d8a34859359..7db7b3e29e5 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -6,6 +6,7 @@ from functools import partial import logging import six +import threading import traceback import uuid import weakref @@ -17,7 +18,6 @@ from tornado.ioloop import IOLoop from tornado.locks import Event -from .compatibility import get_thread_identity from .comm import ( connect, listen, @@ -207,7 +207,7 @@ def stop(): self.thread_id = 0 def set_thread_ident(): - self.thread_id = get_thread_identity() + self.thread_id = threading.get_ident() self.io_loop.add_callback(set_thread_ident) diff --git a/distributed/dashboard/utils.py b/distributed/dashboard/utils.py index a9b31345ca9..8e6b5ff0b9c 100644 --- a/distributed/dashboard/utils.py +++ b/distributed/dashboard/utils.py @@ -7,13 +7,11 @@ from tornado import web from toolz import partition -from ..compatibility import PY2 - BOKEH_VERSION = LooseVersion(bokeh.__version__) dirname = os.path.dirname(__file__) -if BOKEH_VERSION >= "1.0.0" and not PY2: +if BOKEH_VERSION >= "1.0.0": # This decorator is only available in bokeh >= 1.0.0, and doesn't work for # callbacks in Python 2, since the signature introspection won't line up. 
from bokeh.core.properties import without_property_validation diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index f5d991cd737..58c6ce73644 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -1,5 +1,6 @@ from datetime import timedelta import logging +import threading from weakref import ref from dask.utils import format_bytes @@ -7,7 +8,6 @@ from .adaptive import Adaptive -from ..compatibility import get_thread_identity from ..utils import ( PeriodicCallback, log_errors, @@ -231,7 +231,7 @@ def asynchronous(self): self._asynchronous or getattr(thread_state, "asynchronous", False) or hasattr(self.loop, "_thread_identity") - and self.loop._thread_identity == get_thread_identity() + and self.loop._thread_identity == threading.get_ident() ) def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 8d57da779c6..4c9b781f61c 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -1,6 +1,7 @@ from __future__ import print_function, division, absolute_import import logging +import html from timeit import default_timer import sys import weakref @@ -10,7 +11,6 @@ from .progress import format_time, Progress, MultiProgress -from ..compatibility import html_escape from ..core import connect, coerce_to_address, CommClosedError from ..client import default_client, futures_of from ..protocol.pickle import dumps @@ -334,7 +334,7 @@ def make_widget(self, all): '
        ' - + html_escape(key.decode() if isinstance(key, bytes) else key) + + html.escape(key.decode() if isinstance(key, bytes) else key) + "
        " ) for key in all diff --git a/distributed/diskutils.py b/distributed/diskutils.py index 395f7828505..32e6be35adb 100644 --- a/distributed/diskutils.py +++ b/distributed/diskutils.py @@ -7,11 +7,11 @@ import shutil import stat import tempfile +import weakref import dask from . import locket -from .compatibility import finalize logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ def __init__(self, workspace, name=None, prefix=None): raise workspace._known_locks.add(self._lock_path) - self._finalizer = finalize( + self._finalizer = weakref.finalize( self, self._finalize, workspace, @@ -82,7 +82,7 @@ def __init__(self, workspace, name=None, prefix=None): self.dir_path, ) else: - self._finalizer = finalize( + self._finalizer = weakref.finalize( self, self._finalize, workspace, None, None, self.dir_path ) diff --git a/distributed/node.py b/distributed/node.py index 323e2c3e49d..cbf2c00d8f7 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -1,13 +1,13 @@ from __future__ import print_function, division, absolute_import -import warnings import logging +import warnings +import weakref from tornado.ioloop import IOLoop from tornado import gen import dask -from .compatibility import unicode, finalize from .core import Server, ConnectionPool from .versions import get_versions from .utils import DequeHandler @@ -97,7 +97,7 @@ def start_services(self, default_listen_ip): else: port = 0 - if isinstance(port, (str, unicode)): + if isinstance(port, str): port = port.split(":") if isinstance(port, (tuple, list)): @@ -143,7 +143,7 @@ def _setup_logging(self, logger): logging.Formatter(dask.config.get("distributed.admin.log-format")) ) logger.addHandler(self._deque_handler) - finalize(self, logger.removeHandler, self._deque_handler) + weakref.finalize(self, logger.removeHandler, self._deque_handler) def get_logs(self, comm=None, n=None): deque_handler = self._deque_handler diff --git a/distributed/process.py b/distributed/process.py index e716d754db1..b6e50122c36 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -4,11 +4,11 @@ from datetime import timedelta import logging import os +from queue import Queue as PyQueue import re import threading import weakref -from .compatibility import finalize, Queue as PyQueue from .utils import mp_context from tornado import gen @@ -112,7 +112,7 @@ def stop_thread(q): # We don't join the thread here as a finalizer can be called # asynchronously from anywhere - self._finalizer = finalize(self, stop_thread, q=self._watch_q) + self._finalizer = weakref.finalize(self, stop_thread, q=self._watch_q) self._finalizer.atexit = False def _on_exit(self, exitcode): diff --git a/distributed/profile.py b/distributed/profile.py index e240a872fb4..7f85f46312b 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -37,7 +37,6 @@ from .metrics import time from .utils import format_time, color_of, parse_timedelta -from .compatibility import get_thread_identity def identifier(frame): @@ -304,7 +303,7 @@ def watch( deque """ if thread_id is None: - thread_id = get_thread_identity() + thread_id = threading.get_ident() log = deque(maxlen=maxlen) diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index d8da4f204e4..b2375569ef6 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -1,12 +1,12 @@ from __future__ import print_function, division, absolute_import +import math import numpy as np from .utils import frame_split_size, merge_frames from .serialize import dask_serialize, 
dask_deserialize from . import pickle -from ..compatibility import gcd from ..utils import log_errors @@ -60,13 +60,13 @@ def serialize_numpy_ndarray(x): data = x.ravel() if data.dtype.fields or data.dtype.itemsize > 8: - data = data.view("u%d" % gcd(x.dtype.itemsize, 8)) + data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)) try: data = data.data except ValueError: # "ValueError: cannot include dtype 'M' in a buffer" - data = data.view("u%d" % gcd(x.dtype.itemsize, 8)).data + data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data header = {"dtype": dt, "shape": x.shape, "strides": strides} diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 26129f4e1c5..9e314703072 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -13,7 +13,6 @@ import msgpack from . import pickle -from ..compatibility import PY2 from ..utils import has_keyword, typename from .compression import maybe_compress, decompress from .utils import ( @@ -376,8 +375,6 @@ def serialize_bytelist(x, **kwargs): def serialize_bytes(x, **kwargs): L = serialize_bytelist(x, **kwargs) - if PY2: - L = [bytes(y) for y in L] return b"".join(L) diff --git a/distributed/protocol/tests/test_h5py.py b/distributed/protocol/tests/test_h5py.py index f2f9a6625cb..6bae5b3b8d5 100644 --- a/distributed/protocol/tests/test_h5py.py +++ b/distributed/protocol/tests/test_h5py.py @@ -7,25 +7,22 @@ from distributed.protocol import deserialize, serialize -from distributed.utils import PY3, tmpfile +from distributed.utils import tmpfile def silence_h5py_issue775(func): @functools.wraps(func) def wrapper(): - if PY3: - try: - func() - except RuntimeError as e: - # https://github.com/h5py/h5py/issues/775 - if str(e) != "dictionary changed size during iteration": - raise - tb = traceback.extract_tb(e.__traceback__) - filename, lineno, _, _ = tb[-1] - if not filename.endswith("h5py/_objects.pyx"): - raise - else: + try: func() + except RuntimeError as e: + # https://github.com/h5py/h5py/issues/775 + if str(e) != "dictionary changed size during iteration": + raise + tb = traceback.extract_tb(e.__traceback__) + filename, lineno, _, _ = tb[-1] + if not filename.endswith("h5py/_objects.pyx"): + raise return wrapper diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index eb39b57c351..ed4e32c1137 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -6,7 +6,6 @@ import numpy as np import pytest -from distributed.compatibility import PY2 from distributed.protocol import ( serialize, deserialize, @@ -79,7 +78,7 @@ def test_dumps_serialize_numpy(x): header, frames = serialize(x) if "compression" in header: frames = decompress(header, frames) - buffer_interface = buffer if PY2 else memoryview # noqa: F821 + buffer_interface = memoryview for frame in frames: assert isinstance(frame, (bytes, buffer_interface)) y = deserialize(header, frames) diff --git a/distributed/publish.py b/distributed/publish.py index a21f5ef37ed..ea65efb4e74 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -1,5 +1,6 @@ -from distributed.compatibility import MutableMapping -from distributed.utils import log_errors, tokey +from collections import MutableMapping + +from .utils import log_errors, tokey class PublishExtension(object): diff --git a/distributed/pubsub.py b/distributed/pubsub.py index f40c0b15b31..0a4053191eb 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -1,12 +1,12 @@ from 
collections import defaultdict, deque import datetime import logging +import threading import weakref import tornado.locks from tornado import gen -from .compatibility import finalize, get_thread_identity from .core import CommClosedError from .utils import sync from .protocol.serialize import to_serialize @@ -306,7 +306,7 @@ def __init__(self, name, worker=None, client=None): if self.worker: pubsub = self.worker.extensions["pubsub"] self.loop.add_callback(pubsub.publishers[name].add, self) - finalize(self, pubsub.trigger_cleanup) + weakref.finalize(self, pubsub.trigger_cleanup) async def _start(self): if self.worker: @@ -385,7 +385,7 @@ def __init__(self, name, worker=None, client=None): else: raise Exception() - finalize(self, pubsub.trigger_cleanup) + weakref.finalize(self, pubsub.trigger_cleanup) async def _get(self, timeout=None): if timeout is not None: @@ -408,7 +408,7 @@ def get(self, timeout=None): """ Get a single message """ if self.client: return self.client.sync(self._get, timeout=timeout) - elif self.worker.thread_id == get_thread_identity(): + elif self.worker.thread_id == threading.get_ident(): return self._get() else: if self.buffer: # fastpath diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b0db6653d2b..8eec8744849 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1,7 +1,7 @@ from __future__ import print_function, division, absolute_import import asyncio -from collections import defaultdict, deque, OrderedDict +from collections import defaultdict, deque, OrderedDict, Mapping, Set from datetime import timedelta from functools import partial import itertools @@ -39,7 +39,6 @@ unparse_host_port, ) from .comm.addressing import address_from_user_args -from .compatibility import finalize, unicode, Mapping, Set from .core import rpc, connect, send_recv, clean_exception, CommClosedError from .diagnostics.plugin import SchedulerPlugin from . 
import profile @@ -1229,7 +1228,7 @@ def del_scheduler_file(): if os.path.exists(fn): os.remove(fn) - finalize(self, del_scheduler_file) + weakref.finalize(self, del_scheduler_file) preload_modules(self.preload, parameter=self, argv=self.preload_argv) @@ -2124,7 +2123,7 @@ def validate_state(self, allow_overlap=False): raise ValueError("Workers not the same in all collections") for w, ws in self.workers.items(): - assert isinstance(w, (str, unicode)), (type(w), w) + assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) assert ws.address == w if not ws.processing: @@ -3706,7 +3705,7 @@ def transition_processing_memory( try: ts = self.tasks[key] assert worker - assert isinstance(worker, (str, unicode)) + assert isinstance(worker, str) if self.validate: assert ts.processing_on diff --git a/distributed/submit.py b/distributed/submit.py index bdbe3251a9d..f7e0a2f70aa 100644 --- a/distributed/submit.py +++ b/distributed/submit.py @@ -11,11 +11,9 @@ from tornado.ioloop import IOLoop -from distributed import rpc -from distributed.compatibility import unicode -from distributed.core import Server -from distributed.security import Security -from distributed.utils import get_ip +from .core import rpc, Server +from .security import Security +from .utils import get_ip logger = logging.getLogger("distributed.remote") @@ -62,7 +60,7 @@ def execute(self, stream=None, filename=None): def upload_file(self, stream, filename=None, file_payload=None): out_filename = os.path.join(self.local_dir, filename) - if isinstance(file_payload, unicode): + if isinstance(file_payload, str): file_payload = file_payload.encode() with open(out_filename, "wb") as f: f.write(file_payload) diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index aa53b9b993a..911ff388e06 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -1,5 +1,7 @@ from concurrent.futures import CancelledError +from collections import Iterator from operator import add +import queue import random from time import sleep @@ -7,7 +9,6 @@ from tornado import gen from distributed.client import _as_completed, as_completed, _first_completed -from distributed.compatibility import Empty, StopAsyncIteration, Queue, Iterator from distributed.utils_test import gen_cluster, inc, throws from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -18,11 +19,11 @@ def test__as_completed(c, s, a, b): y = c.submit(inc, 1) z = c.submit(inc, 2) - queue = Queue() - yield _as_completed([x, y, z], queue) + q = queue.Queue() + yield _as_completed([x, y, z], q) - assert queue.qsize() == 3 - assert {queue.get(), queue.get(), queue.get()} == {x, y, z} + assert q.qsize() == 3 + assert {q.get(), q.get(), q.get()} == {x, y, z} result = yield _first_completed([x, y, z]) assert result in [x, y, z] @@ -112,7 +113,7 @@ def test_as_completed_cancel(client): assert next(ac) is x or y assert next(ac) is y or x - with pytest.raises(Empty): + with pytest.raises(queue.Empty): ac.queue.get(timeout=0.1) res = list(as_completed([x, y, x])) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index bf05875fdee..5e2ade3f247 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -51,7 +51,7 @@ futures_of, temp_default_client, ) -from distributed.compatibility import PY3, WINDOWS +from distributed.compatibility import WINDOWS from distributed.metrics import time from distributed.scheduler import 
Scheduler, KilledWorker @@ -4833,7 +4833,6 @@ def test_bytes_keys(c, s, a, b): @gen_cluster(client=True) def test_unicode_ascii_keys(c, s, a, b): - # cross-version unicode type (py2: unicode, py3: str) uni_type = type(u"") key = u"inc-123" future = c.submit(inc, 1, key=key) @@ -4846,7 +4845,6 @@ def test_unicode_ascii_keys(c, s, a, b): @gen_cluster(client=True) def test_unicode_keys(c, s, a, b): - # cross-version unicode type (py2: unicode, py3: str) uni_type = type(u"") key = u"inc-123\u03bc" future = c.submit(inc, 1, key=key) @@ -5036,12 +5034,7 @@ def test_client_async_before_loop_starts(): @pytest.mark.slow -@gen_cluster( - client=True, - Worker=Nanny if PY3 else Worker, - timeout=60, - nthreads=[("127.0.0.1", 3)] * 2, -) +@gen_cluster(client=True, Worker=Nanny, timeout=60, nthreads=[("127.0.0.1", 3)] * 2) def test_nested_compute(c, s, a, b): def fib(x): assert get_worker().get_current_task() diff --git a/distributed/tests/test_compatibility.py b/distributed/tests/test_compatibility.py deleted file mode 100644 index 42eae448aa1..00000000000 --- a/distributed/tests/test_compatibility.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import print_function, division, absolute_import - -from distributed.compatibility import gzip_compress, gzip_decompress, finalize - - -def test_gzip(): - b = b"Hello, world!" - c = gzip_compress(b) - d = gzip_decompress(c) - assert b == d - - -def test_finalize(): - class C(object): - pass - - l = [] - - def cb(value): - l.append(value) - - o = C() - f = finalize(o, cb, 1) - assert f in f._select_for_exit() - f.atexit = False - assert f not in f._select_for_exit() - assert not l - del o - assert l.pop() == 1 - - o = C() - fin = finalize(o, cb, 2) - assert fin.alive - fin() - assert not fin.alive - assert l.pop() == 2 - del o - assert not l diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 38a43a1a5c8..82e4c709be5 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -3,13 +3,13 @@ from contextlib import contextmanager import os import socket +import threading import weakref from tornado import gen import pytest import dask -from distributed.compatibility import finalize, get_thread_identity from distributed.core import ( pingpong, Server, @@ -63,7 +63,7 @@ class CountedObject(object): def __new__(cls): cls.n_instances += 1 obj = object.__new__(cls) - finalize(obj, cls._finalize) + weakref.finalize(obj, cls._finalize) return obj @classmethod @@ -702,7 +702,7 @@ def f(): @gen_cluster() def test_thread_id(s, a, b): - assert s.thread_id == a.thread_id == b.thread_id == get_thread_identity() + assert s.thread_id == a.thread_id == b.thread_id == threading.get_ident() @gen_test() diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index a6dcf3497a3..561a4cd408b 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -3,6 +3,7 @@ import functools import gc import os +import queue import shutil import subprocess import sys @@ -12,7 +13,7 @@ import pytest import dask -from distributed.compatibility import Empty, WINDOWS +from distributed.compatibility import WINDOWS from distributed.diskutils import WorkSpace from distributed.metrics import time from distributed.utils import mp_context @@ -258,7 +259,7 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): # Any errors? 
try: err = err_q.get_nowait() - except Empty: + except queue.Empty: pass else: raise err @@ -266,7 +267,7 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): try: while True: n_purged += purged_q.get_nowait() - except Empty: + except queue.Empty: pass # We attempted to purge most directories at some point assert n_purged >= 0.5 * n_created > 0 diff --git a/distributed/tests/test_metrics.py b/distributed/tests/test_metrics.py index d1eb4a1dad0..cdb4b8ee478 100644 --- a/distributed/tests/test_metrics.py +++ b/distributed/tests/test_metrics.py @@ -5,7 +5,6 @@ import time from distributed import metrics -from distributed.compatibility import PY3 from distributed.utils_test import run_for @@ -37,12 +36,11 @@ def test_process_time(): dt = metrics.process_time() - start assert dt >= 0.05 - if PY3: - # Sleep time not counted - start = metrics.process_time() - time.sleep(0.1) - dt = metrics.process_time() - start - assert dt <= 0.05 + # Sleep time not counted + start = metrics.process_time() + time.sleep(0.1) + dt = metrics.process_time() - start + assert dt <= 0.05 def test_thread_time(): @@ -51,18 +49,17 @@ def test_thread_time(): dt = metrics.thread_time() - start assert 0.03 <= dt <= 0.2 - if PY3: - # Sleep time not counted + # Sleep time not counted + start = metrics.thread_time() + time.sleep(0.1) + dt = metrics.thread_time() - start + assert dt <= 0.05 + + if sys.platform == "linux": + # Always per-thread on Linux + t = threading.Thread(target=run_for, args=(0.1,)) start = metrics.thread_time() - time.sleep(0.1) + t.start() + t.join() dt = metrics.thread_time() - start assert dt <= 0.05 - - if sys.platform == "linux": - # Always per-thread on Linux - t = threading.Thread(target=run_for, args=(0.1,)) - start = metrics.thread_time() - t.start() - t.join() - dt = metrics.thread_time() - start - assert dt <= 0.05 diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index ee49f130027..a022600d819 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -4,7 +4,7 @@ from toolz import first import threading -from distributed.compatibility import get_thread_identity, WINDOWS +from distributed.compatibility import WINDOWS from distributed import metrics from distributed.profile import ( process, @@ -164,7 +164,7 @@ def test_merge_empty(): def test_call_stack(): - frame = sys._current_frames()[get_thread_identity()] + frame = sys._current_frames()[threading.get_ident()] L = call_stack(frame) assert isinstance(L, list) assert all(isinstance(s, str) for s in L) @@ -172,7 +172,7 @@ def test_call_stack(): def test_identifier(): - frame = sys._current_frames()[get_thread_identity()] + frame = sys._current_frames()[threading.get_ident()] assert identifier(frame) == identifier(frame) assert identifier(None) == identifier(None) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 9e3a1d90c4b..c547834626d 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -3,6 +3,7 @@ import datetime from functools import partial import io +import queue import socket import sys from time import sleep @@ -15,7 +16,6 @@ from tornado.ioloop import IOLoop import dask -from distributed.compatibility import Queue, Empty, PY2 from distributed.metrics import time from distributed.utils import ( All, @@ -278,8 +278,6 @@ def f(): def test_ensure_bytes(): data = [b"1", "1", memoryview(b"1"), bytearray(b"1")] - if PY2: - data.append(buffer(b"1")) # noqa: F821 for d in data: result = ensure_bytes(d) 
assert isinstance(result, bytes) @@ -319,7 +317,7 @@ def assert_running(loop): """ Raise if the given IOLoop is not running. """ - q = Queue() + q = queue.Queue() loop.add_callback(q.put, 42) assert q.get(timeout=1) == 42 @@ -328,14 +326,14 @@ def assert_not_running(loop): """ Raise if the given IOLoop is running. """ - q = Queue() + q = queue.Queue() try: loop.add_callback(q.put, 42) except RuntimeError: # On AsyncIOLoop, can't add_callback() after the loop is closed pass else: - with pytest.raises(Empty): + with pytest.raises(queue.Empty): q.get(timeout=0.02) diff --git a/distributed/tests/test_utils_perf.py b/distributed/tests/test_utils_perf.py index 55b250273c0..95fa816a75b 100644 --- a/distributed/tests/test_utils_perf.py +++ b/distributed/tests/test_utils_perf.py @@ -8,7 +8,6 @@ import pytest -from distributed.compatibility import PY2 from distributed.metrics import thread_time from distributed.utils_perf import FractionalTimer, GCDiagnosis, disable_gc_diagnosis from distributed.utils_test import captured_logger, run_for @@ -84,7 +83,6 @@ def enable_gc_diagnosis_and_log(diag, level="INFO"): gc.enable() -@pytest.mark.skipif(PY2, reason="requires Python 3") def test_gc_diagnosis_cpu_time(): diag = GCDiagnosis(warn_over_frac=0.75) diag.N_SAMPLES = 3 # shorten tests @@ -115,7 +113,6 @@ def test_gc_diagnosis_cpu_time(): @pytest.mark.xfail(reason="unknown") -@pytest.mark.skipif(PY2, reason="requires Python 3") def test_gc_diagnosis_rss_win(): diag = GCDiagnosis(info_over_rss_win=10e6) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 96c673bf69a..c7337d36424 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2,6 +2,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import timedelta +import importlib import logging import multiprocessing from numbers import Number @@ -32,7 +33,7 @@ Reschedule, wait, ) -from distributed.compatibility import WINDOWS, cache_from_source +from distributed.compatibility import WINDOWS from distributed.core import rpc from distributed.scheduler import Scheduler from distributed.metrics import time @@ -219,7 +220,7 @@ def test_upload_file_pyc(c, s, w): import foo assert foo.f() == 123 - pyc = cache_from_source(os.path.join(dirname, "foo.py")) + pyc = importlib.util.cache_from_source(os.path.join(dirname, "foo.py")) assert os.path.exists(pyc) yield c.upload_file(pyc) diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index d2d4e3b7921..f4cae3fd88e 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -23,9 +23,9 @@ from __future__ import print_function, division, absolute_import from . 
import _concurrent_futures_thread as thread -from .compatibility import Empty import os import logging +import queue import threading import itertools @@ -51,7 +51,7 @@ def _worker(executor, work_queue): break try: task = work_queue.get(timeout=1) - except Empty: + except queue.Empty: continue if task is not None: # sentinel task.run() diff --git a/distributed/utils.py b/distributed/utils.py index 71e4a09d2a1..ab45350bf61 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -12,13 +12,14 @@ import logging import multiprocessing from numbers import Number -import operator import os import re import shutil import socket from time import sleep -from importlib import import_module +import importlib +from importlib.util import cache_from_source +import inspect import sys import tempfile import threading @@ -28,8 +29,6 @@ import six import tblib.pickling_support -from .compatibility import cache_from_source, getargspec, invalidate_caches, reload - try: import resource except ImportError: @@ -50,7 +49,7 @@ except ImportError: PollIOLoop = None # dropped in tornado 6.0 -from .compatibility import PY3, PY2, get_thread_identity, unicode +from .compatibility import PYPY, WINDOWS from .metrics import time @@ -66,7 +65,9 @@ def _initialize_mp_context(): - if PY3 and not sys.platform.startswith("win") and "PyPy" not in sys.version: + if WINDOWS or PYPY: + return multiprocessing + else: method = dask.config.get("distributed.worker.multiprocessing-method") ctx = multiprocessing.get_context(method) # Makes the test suite much faster @@ -74,10 +75,7 @@ def _initialize_mp_context(): if "pkg_resources" in sys.modules: preload.append("pkg_resources") ctx.set_forkserver_preload(preload) - else: - ctx = multiprocessing - - return ctx + return ctx mp_context = _initialize_mp_context() @@ -99,7 +97,7 @@ def has_arg(func, argname): """ while True: try: - if argname in getargspec(func).args: + if argname in inspect.getfullargspec(func).args: return True except TypeError: break @@ -298,14 +296,14 @@ def sync(loop, func, *args, callback_timeout=None, **kwargs): pass e = threading.Event() - main_tid = get_thread_identity() + main_tid = threading.get_ident() result = [None] error = [False] @gen.coroutine def f(): try: - if main_tid == get_thread_identity(): + if main_tid == threading.get_ident(): raise RuntimeError("sync() called from thread of running loop") yield gen.moment thread_state.asynchronous = True @@ -552,6 +550,7 @@ def is_kernel(): hex_pattern = re.compile("[a-f]+") +@functools.lru_cache(100000) def key_split(s): """ >>> key_split('x') @@ -606,102 +605,48 @@ def key_split(s): return "Other" -try: - from functools import lru_cache -except ImportError: - lru_cache = False - pass -else: - key_split = lru_cache(100000)(key_split) - -if PY3: - - def key_split_group(x): - """A more fine-grained version of key_split - - >>> key_split_group('x') - 'x' - >>> key_split_group('x-1') - 'x-1' - >>> key_split_group('x-1-2-3') - 'x-1-2-3' - >>> key_split_group(('x-2', 1)) - 'x-2' - >>> key_split_group("('x-2', 1)") - 'x-2' - >>> key_split_group('hello-world-1') - 'hello-world-1' - >>> key_split_group(b'hello-world-1') - 'hello-world-1' - >>> key_split_group('ae05086432ca935f6eba409a8ecd4896') - 'data' - >>> key_split_group('>> key_split_group(None) - 'Other' - >>> key_split_group('x-abcdefab') # ignores hex - 'x-abcdefab' - """ - typ = type(x) - if typ is tuple: - return x[0] - elif typ is str: - if x[0] == "(": - return x.split(",", 1)[0].strip("()\"'") - elif len(x) == 32 and re.match(r"[a-f0-9]{32}", x): - 
return "data" - elif x[0] == "<": - return x.strip("<>").split()[0].split(".")[-1] - else: - return x - elif typ is bytes: - return key_split_group(x.decode()) - else: - return "Other" +def key_split_group(x): + """A more fine-grained version of key_split - -else: - - def key_split_group(x): - """A more fine-grained version of key_split - - >>> key_split_group('x') - 'x' - >>> key_split_group('x-1') - 'x-1' - >>> key_split_group('x-1-2-3') - 'x-1-2-3' - >>> key_split_group(('x-2', 1)) - 'x-2' - >>> key_split_group("('x-2', 1)") - 'x-2' - >>> key_split_group('hello-world-1') - 'hello-world-1' - >>> key_split_group(b'hello-world-1') - 'hello-world-1' - >>> key_split_group('ae05086432ca935f6eba409a8ecd4896') - 'data' - >>> key_split_group('>> key_split_group(None) - 'Other' - >>> key_split_group('x-abcdefab') # ignores hex - 'x-abcdefab' - """ - typ = type(x) - if typ is tuple: - return x[0] - elif typ is str or typ is unicode: - if x[0] == "(": - return x.split(",", 1)[0].strip("()\"'") - elif len(x) == 32 and re.match(r"[a-f0-9]{32}", x): - return "data" - elif x[0] == "<": - return x.strip("<>").split()[0].split(".")[-1] - else: - return x + >>> key_split_group('x') + 'x' + >>> key_split_group('x-1') + 'x-1' + >>> key_split_group('x-1-2-3') + 'x-1-2-3' + >>> key_split_group(('x-2', 1)) + 'x-2' + >>> key_split_group("('x-2', 1)") + 'x-2' + >>> key_split_group('hello-world-1') + 'hello-world-1' + >>> key_split_group(b'hello-world-1') + 'hello-world-1' + >>> key_split_group('ae05086432ca935f6eba409a8ecd4896') + 'data' + >>> key_split_group('>> key_split_group(None) + 'Other' + >>> key_split_group('x-abcdefab') # ignores hex + 'x-abcdefab' + """ + typ = type(x) + if typ is tuple: + return x[0] + elif typ is str: + if x[0] == "(": + return x.split(",", 1)[0].strip("()\"'") + elif len(x) == 32 and re.match(r"[a-f0-9]{32}", x): + return "data" + elif x[0] == "<": + return x.strip("<>").split()[0].split(".")[-1] else: - return "Other" + return x + elif typ is bytes: + return key_split_group(x.decode()) + else: + return "Other" @contextmanager @@ -810,7 +755,7 @@ def tokey(o): '1' """ typ = type(o) - if typ is unicode or typ is bytes: + if typ is str or typ is bytes: return o else: return str(o) @@ -820,7 +765,7 @@ def validate_key(k): """Validate a key as received on a stream. 
""" typ = type(k) - if typ is not unicode and typ is not bytes: + if typ is not str and typ is not bytes: raise TypeError("Unexpected key type %s (value: %r)" % (typ, k)) @@ -970,7 +915,7 @@ def ensure_bytes(s): return s if isinstance(s, memoryview): return s.tobytes() - if isinstance(s, bytearray) or PY2 and isinstance(s, buffer): # noqa: F821 + if isinstance(s, bytearray): # noqa: F821 return bytes(s) if hasattr(s, "encode"): return s.encode() @@ -1075,13 +1020,13 @@ def import_file(path): if not names_to_import: logger.warning("Found nothing to import from %s", filename) else: - invalidate_caches() + importlib.invalidate_caches() if tmp_python_path is not None: sys.path.insert(0, tmp_python_path) try: for name in names_to_import: logger.info("Reload module %s from %s file", name, ext) - loaded.append(reload(import_module(name))) + loaded.append(importlib.reload(importlib.import_module(name))) finally: if tmp_python_path is not None: sys.path.remove(tmp_python_path) @@ -1261,32 +1206,15 @@ def asciitable(columns, rows): return "\n".join([bar, header, bar, data, bar]) -if PY2: - - def nbytes(frame, _bytes_like=(bytes, bytearray, buffer)): # noqa: F821 - """ Number of bytes of a frame or memoryview """ - if isinstance(frame, _bytes_like): - return len(frame) - elif isinstance(frame, memoryview): - if frame.shape is None: - return frame.itemsize - else: - return functools.reduce(operator.mul, frame.shape, frame.itemsize) - else: +def nbytes(frame, _bytes_like=(bytes, bytearray)): + """ Number of bytes of a frame or memoryview """ + if isinstance(frame, _bytes_like): + return len(frame) + else: + try: return frame.nbytes - - -else: - - def nbytes(frame, _bytes_like=(bytes, bytearray)): - """ Number of bytes of a frame or memoryview """ - if isinstance(frame, _bytes_like): + except AttributeError: return len(frame) - else: - try: - return frame.nbytes - except AttributeError: - return len(frame) def PeriodicCallback(callback, callback_time, io_loop=None): @@ -1402,18 +1330,9 @@ def reset_logger_locks(): ) +@functools.lru_cache(1000) def has_keyword(func, keyword): - if PY3: - return keyword in inspect.signature(func).parameters - else: - # https://stackoverflow.com/questions/50100498/determine-keywords-of-a-tornado-coroutine - if gen.is_coroutine_function(func): - func = func.__wrapped__ - return keyword in inspect.getargspec(func).args - - -if lru_cache: - has_keyword = lru_cache(1000)(has_keyword) + return keyword in inspect.signature(func).parameters # from bokeh.palettes import viridis diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index b1f65256c1e..eb54ea0b381 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -7,7 +7,7 @@ from dask.utils import format_bytes -from .compatibility import PY2, PYPY +from .compatibility import PYPY from .metrics import thread_time @@ -147,7 +147,7 @@ def __init__(self, warn_over_frac=0.1, info_over_rss_win=10 * 1e6): self._enabled = False def enable(self): - if PY2 or PYPY: + if PYPY: return assert not self._enabled self._fractional_timer = FractionalTimer(n_samples=self.N_SAMPLES) @@ -165,7 +165,7 @@ def enable(self): self._enabled = True def disable(self): - if PY2 or PYPY: + if PYPY: return assert self._enabled gc.callbacks.remove(self._gc_callback) @@ -232,7 +232,7 @@ def enable_gc_diagnosis(): """ Ask to enable global GC diagnosis. 
""" - if PY2 or PYPY: + if PYPY: return global _gc_diagnosis_users with _gc_diagnosis_lock: @@ -247,7 +247,7 @@ def disable_gc_diagnosis(force=False): """ Ask to disable global GC diagnosis. """ - if PY2 or PYPY: + if PYPY: return global _gc_diagnosis_users with _gc_diagnosis_lock: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 52cc54b639d..0c7c8958a91 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -6,12 +6,12 @@ import copy from datetime import timedelta import functools -import gc from glob import glob import itertools import logging import logging.config import os +import queue import re import shutil import signal @@ -41,7 +41,7 @@ from tornado.ioloop import IOLoop from .client import default_client, _global_clients, Client -from .compatibility import PY3, Empty, WINDOWS +from .compatibility import WINDOWS from .comm import Comm from .comm.utils import offload from .config import initialize_logging @@ -225,11 +225,6 @@ def nodebug(func): A decorator to disable debug facilities during timing-sensitive tests. Warning: this doesn't affect already created IOLoops. """ - if not PY3: - # py.test's runner magic breaks horridly on Python 2 - # when a test function is wrapped, so avoid it - # (incidently, asyncio is irrelevant anyway) - return func @functools.wraps(func) def wrapped(*args, **kwargs): @@ -517,10 +512,6 @@ async def _(): @contextmanager def check_active_rpc(loop, active_rpc_timeout=1): active_before = set(rpc.active) - if active_before and not PY3: - # On Python 2, try to avoid dangling comms before forking workers - gc.collect() - active_before = set(rpc.active) yield # Some streams can take a bit of time to notice their peer # has closed, and keep a coroutine (*) waiting for a CommClosedError @@ -664,7 +655,7 @@ def cluster( try: for worker in workers: worker["address"] = worker["queue"].get(timeout=5) - except Empty: + except queue.Empty: raise pytest.xfail.Exception("Worker failed to start in test") saddr = scheduler_q.get() diff --git a/distributed/worker.py b/distributed/worker.py index 08927eb741d..ae7ee648ff9 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2,7 +2,7 @@ import asyncio import bisect -from collections import defaultdict, deque +from collections import defaultdict, deque, MutableMapping from datetime import timedelta import heapq import logging @@ -35,7 +35,6 @@ from .comm import get_address_host, connect from .comm.utils import offload from .comm.addressing import address_from_user_args -from .compatibility import unicode, get_thread_identity, MutableMapping from .core import error_message, CommClosedError, send_recv, pingpong, coerce_to_address from .diskutils import WorkSpace from .metrics import time @@ -856,7 +855,7 @@ async def upload_file(self, comm, filename=None, data=None, load=True): out_filename = os.path.join(self.local_directory, filename) def func(data): - if isinstance(data, unicode): + if isinstance(data, str): data = data.encode() with open(out_filename, "wb") as f: f.write(data) @@ -2303,8 +2302,8 @@ async def execute(self, key, report=False): from .actor import Actor # TODO: create local actor data[k] = Actor(type(self.actors[k]), self.address, k, self) - args2 = pack_data(args, data, key_types=(bytes, unicode)) - kwargs2 = pack_data(kwargs, data, key_types=(bytes, unicode)) + args2 = pack_data(args, data, key_types=(bytes, str)) + kwargs2 = pack_data(kwargs, data, key_types=(bytes, str)) stop = time() if stop - start > 0.005: self.startstops[key].append(("disk-read", 
start, stop)) @@ -2796,7 +2795,7 @@ def get_current_task(self): -------- get_worker """ - return self.active_threads[get_thread_identity()] + return self.active_threads[threading.get_ident()] def get_worker(): @@ -2946,7 +2945,7 @@ def parse_memory_limit(memory_limit, nthreads, total_cores=multiprocessing.cpu_c if isinstance(memory_limit, float) and memory_limit <= 1: memory_limit = int(memory_limit * TOTAL_MEMORY) - if isinstance(memory_limit, (unicode, str)): + if isinstance(memory_limit, str): memory_limit = parse_bytes(memory_limit) else: memory_limit = int(memory_limit) @@ -3146,7 +3145,7 @@ def apply_function( ------- msg: dictionary with status, result/error, timings, etc.. """ - ident = get_thread_identity() + ident = threading.get_ident() with active_threads_lock: active_threads[ident] = key thread_state.start_time = time() @@ -3186,7 +3185,7 @@ def apply_function_actor( ------- msg: dictionary with status, result/error, timings, etc.. """ - ident = get_thread_identity() + ident = threading.get_ident() with active_threads_lock: active_threads[ident] = key From cc4fc7d8fe42c41b328ade1db54dd88da904244a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 29 Jul 2019 13:12:20 -0700 Subject: [PATCH 0385/1550] Remove future imports (#2897) --- distributed/__init__.py | 2 -- distributed/_ipython_utils.py | 2 -- distributed/batched.py | 2 -- distributed/cfexecutor.py | 2 -- distributed/cli/dask_remote.py | 2 -- distributed/cli/dask_scheduler.py | 2 -- distributed/cli/dask_ssh.py | 2 -- distributed/cli/dask_worker.py | 2 -- distributed/cli/tests/test_dask_scheduler.py | 2 -- distributed/cli/tests/test_dask_worker.py | 2 -- distributed/cli/tests/test_tls_cli.py | 3 --- distributed/cli/utils.py | 2 -- distributed/client.py | 2 -- distributed/comm/__init__.py | 2 -- distributed/comm/addressing.py | 2 -- distributed/comm/core.py | 2 -- distributed/comm/inproc.py | 2 -- distributed/comm/registry.py | 2 -- distributed/comm/tcp.py | 2 -- distributed/comm/tests/test_comms.py | 2 -- distributed/comm/utils.py | 2 -- distributed/compatibility.py | 2 -- distributed/config.py | 2 -- distributed/core.py | 2 -- distributed/counter.py | 2 -- distributed/dashboard/components.py | 2 -- distributed/dashboard/core.py | 2 -- distributed/dashboard/export_tool.py | 2 -- distributed/dashboard/scheduler.py | 2 -- distributed/dashboard/tests/test_components.py | 2 -- distributed/dashboard/tests/test_scheduler_bokeh.py | 2 -- distributed/dashboard/tests/test_scheduler_bokeh_html.py | 2 -- distributed/dashboard/tests/test_worker_bokeh.py | 2 -- distributed/dashboard/utils.py | 2 -- distributed/dashboard/worker.py | 2 -- distributed/deploy/__init__.py | 2 -- distributed/deploy/adaptive.py | 2 -- distributed/deploy/local.py | 2 -- distributed/deploy/ssh.py | 2 -- distributed/deploy/tests/test_adaptive.py | 2 -- distributed/deploy/tests/test_local.py | 2 -- distributed/deploy/tests/test_ssh.py | 2 -- distributed/diagnostics/__init__.py | 2 -- distributed/diagnostics/eventstream.py | 2 -- distributed/diagnostics/graph_layout.py | 2 -- distributed/diagnostics/plugin.py | 2 -- distributed/diagnostics/progress.py | 2 -- distributed/diagnostics/progress_stream.py | 2 -- distributed/diagnostics/progressbar.py | 2 -- distributed/diagnostics/task_stream.py | 2 -- distributed/diagnostics/tests/test_eventstream.py | 2 -- distributed/diagnostics/tests/test_plugin.py | 2 -- distributed/diagnostics/tests/test_progress.py | 2 -- distributed/diagnostics/tests/test_progress_stream.py | 3 --- 
distributed/diagnostics/tests/test_progressbar.py | 2 -- distributed/diagnostics/tests/test_task_stream.py | 2 -- distributed/diagnostics/tests/test_widgets.py | 2 -- distributed/diskutils.py | 2 -- distributed/lock.py | 2 -- distributed/metrics.py | 2 -- distributed/nanny.py | 2 -- distributed/node.py | 2 -- distributed/process.py | 2 -- distributed/proctitle.py | 2 -- distributed/profile.py | 2 -- distributed/protocol/__init__.py | 2 -- distributed/protocol/arrow.py | 2 -- distributed/protocol/compression.py | 2 -- distributed/protocol/core.py | 2 -- distributed/protocol/h5py.py | 2 -- distributed/protocol/keras.py | 2 -- distributed/protocol/netcdf4.py | 2 -- distributed/protocol/numpy.py | 2 -- distributed/protocol/pickle.py | 2 -- distributed/protocol/serialize.py | 1 - distributed/protocol/sparse.py | 2 -- distributed/protocol/tests/test_numpy.py | 2 -- distributed/protocol/tests/test_pandas.py | 3 --- distributed/protocol/tests/test_protocol.py | 2 -- distributed/protocol/tests/test_protocol_utils.py | 2 -- distributed/protocol/tests/test_serialize.py | 2 -- distributed/protocol/utils.py | 2 -- distributed/pytest_resourceleaks.py | 2 -- distributed/queues.py | 2 -- distributed/recreate_exceptions.py | 2 -- distributed/scheduler.py | 2 -- distributed/sizeof.py | 2 -- distributed/stealing.py | 2 -- distributed/submit.py | 2 -- distributed/system_monitor.py | 2 -- distributed/tests/test_asyncprocess.py | 2 -- distributed/tests/test_client.py | 2 -- distributed/tests/test_client_executor.py | 2 -- distributed/tests/test_collections.py | 3 --- distributed/tests/test_config.py | 2 -- distributed/tests/test_core.py | 2 -- distributed/tests/test_counter.py | 2 -- distributed/tests/test_diskutils.py | 2 -- distributed/tests/test_failed_workers.py | 2 -- distributed/tests/test_ipython.py | 2 -- distributed/tests/test_locks.py | 2 -- distributed/tests/test_metrics.py | 2 -- distributed/tests/test_nanny.py | 2 -- distributed/tests/test_queues.py | 2 -- distributed/tests/test_resources.py | 2 -- distributed/tests/test_scheduler.py | 2 -- distributed/tests/test_security.py | 2 -- distributed/tests/test_steal.py | 6 ++---- distributed/tests/test_stress.py | 2 -- distributed/tests/test_submit_cli.py | 1 - distributed/tests/test_system_monitor.py | 2 -- distributed/tests/test_tls_functional.py | 4 ---- distributed/tests/test_utils.py | 2 -- distributed/tests/test_utils_comm.py | 2 -- distributed/tests/test_utils_perf.py | 2 -- distributed/tests/test_utils_test.py | 2 -- distributed/tests/test_variable.py | 2 -- distributed/tests/test_worker.py | 2 -- distributed/tests/test_worker_client.py | 2 -- distributed/threadpoolexecutor.py | 2 -- distributed/utils.py | 2 -- distributed/utils_comm.py | 2 -- distributed/utils_perf.py | 2 -- distributed/utils_test.py | 2 -- distributed/variable.py | 2 -- distributed/versions.py | 2 -- distributed/worker.py | 2 -- distributed/worker_client.py | 2 -- 128 files changed, 2 insertions(+), 262 deletions(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index 2a632607cf9..ca36613c815 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from . 
import config from dask.config import config from .actor import Actor, ActorFuture diff --git a/distributed/_ipython_utils.py b/distributed/_ipython_utils.py index 512f8911588..1a999833786 100644 --- a/distributed/_ipython_utils.py +++ b/distributed/_ipython_utils.py @@ -4,8 +4,6 @@ after which we can import them instead of having our own definitions. """ -from __future__ import print_function - import atexit import os diff --git a/distributed/batched.py b/distributed/batched.py index e17d7b1f1bd..a3207b333ef 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import deque import logging diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index eb7bbf05646..34350462f8b 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import concurrent.futures as cf import weakref diff --git a/distributed/cli/dask_remote.py b/distributed/cli/dask_remote.py index 29cc5c3c784..3118da84ae7 100644 --- a/distributed/cli/dask_remote.py +++ b/distributed/cli/dask_remote.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import click from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.submit import _remote diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 54ecd69e595..07e4c98e267 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import atexit import logging import gc diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index 1d264dc80e5..389e0327688 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from distributed.deploy.ssh import SSHCluster import click diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 084d7b59ccc..44931393522 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import atexit import logging import multiprocessing diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 24737474165..cb6cc306b6c 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pytest pytest.importorskip("requests") diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index e268229767d..2dd74737b16 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pytest from click.testing import CliRunner diff --git a/distributed/cli/tests/test_tls_cli.py b/distributed/cli/tests/test_tls_cli.py index 37fdc9bb00f..def31bc244d 100644 --- a/distributed/cli/tests/test_tls_cli.py +++ b/distributed/cli/tests/test_tls_cli.py @@ -1,8 +1,5 @@ -from __future__ import print_function, division, absolute_import - from time import sleep - from distributed import Client from distributed.utils_test import ( popen, diff --git 
a/distributed/cli/utils.py b/distributed/cli/utils.py index 2c2088a7556..4cfb41abe0f 100644 --- a/distributed/cli/utils.py +++ b/distributed/cli/utils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from tornado import gen from tornado.ioloop import IOLoop diff --git a/distributed/client.py b/distributed/client.py index e1efda4e0d7..d9f02ba80f8 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import atexit from collections import defaultdict, Iterator from concurrent.futures import ThreadPoolExecutor, CancelledError diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index e0615b38c7a..3537b301573 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from .addressing import ( parse_address, unparse_address, diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 54e37b77f6b..8480134997c 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import six import dask diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 869cb9b377f..602b3161657 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from abc import ABCMeta, abstractmethod, abstractproperty from datetime import timedelta import logging diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 3a781479bbc..5235b7535fd 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import deque, namedtuple import itertools import logging diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index a646b4d71b9..b7fcca912cd 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from abc import ABCMeta, abstractmethod from six import with_metaclass diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index d23f381857d..36783102b69 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import errno import logging import socket diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 7fac117027b..620d4b89c94 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from functools import partial import os import sys diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 1e23b25c46b..dcc9e9a8b1a 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from concurrent.futures import ThreadPoolExecutor import logging import socket diff --git a/distributed/compatibility.py b/distributed/compatibility.py index fb79353d24b..186e66e485c 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import 
platform import sys diff --git a/distributed/config.py b/distributed/config.py index 5c71cf570c8..7e6075125fd 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import logging.config import os diff --git a/distributed/core.py b/distributed/core.py index 7db7b3e29e5..f97d2df382a 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio from collections import defaultdict, deque from concurrent.futures import CancelledError diff --git a/distributed/counter.py b/distributed/counter.py index d5a3181b112..f41961e87ac 100644 --- a/distributed/counter.py +++ b/distributed/counter.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import defaultdict from tornado.ioloop import IOLoop diff --git a/distributed/dashboard/components.py b/distributed/dashboard/components.py index 242a617706e..7fb8a6cb022 100644 --- a/distributed/dashboard/components.py +++ b/distributed/dashboard/components.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio from bisect import bisect from operator import add diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index a85efb3233c..fd6ebef2834 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from distutils.version import LooseVersion import os import warnings diff --git a/distributed/dashboard/export_tool.py b/distributed/dashboard/export_tool.py index 5d8f1c067ae..d93d21b881b 100644 --- a/distributed/dashboard/export_tool.py +++ b/distributed/dashboard/export_tool.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import os from bokeh.core.properties import Int, String diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index e41862335cd..8396bbcb6ae 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from functools import partial import logging import math diff --git a/distributed/dashboard/tests/test_components.py b/distributed/dashboard/tests/test_components.py index b12780f199b..5e96d788e45 100644 --- a/distributed/dashboard/tests/test_components.py +++ b/distributed/dashboard/tests/test_components.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pytest pytest.importorskip("bokeh") diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 8544e72d9f4..e9ac62aad41 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import json import re import ssl diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py index b66aff02ddc..fc19efb0812 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import json import re import xml.etree.ElementTree diff --git 
a/distributed/dashboard/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py index d320ea24ee8..c490c825ab4 100644 --- a/distributed/dashboard/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from operator import add, sub import re from time import sleep diff --git a/distributed/dashboard/utils.py b/distributed/dashboard/utils.py index 8e6b5ff0b9c..285f6a5772a 100644 --- a/distributed/dashboard/utils.py +++ b/distributed/dashboard/utils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from distutils.version import LooseVersion import os diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index d8f8adc1c7d..402d3fd0a70 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from functools import partial import logging import math diff --git a/distributed/deploy/__init__.py b/distributed/deploy/__init__.py index 24a86e6d6d2..5a5a9106005 100644 --- a/distributed/deploy/__init__.py +++ b/distributed/deploy/__init__.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from ..utils import ignoring from .cluster import Cluster diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 761a7d300ee..9b1d8511045 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import deque import logging diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 5b0aec4e80c..877a74587e9 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import atexit import logging import math diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index ba8ed01d1c7..9390d00a2ab 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import socket import os diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index e0478a9cbdb..2d3d2235e21 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from time import sleep import pytest diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 91a792272e8..f434945c3af 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from functools import partial import gc import multiprocessing diff --git a/distributed/deploy/tests/test_ssh.py b/distributed/deploy/tests/test_ssh.py index a86a8ddd280..492ee2c792d 100644 --- a/distributed/deploy/tests/test_ssh.py +++ b/distributed/deploy/tests/test_ssh.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from time import sleep import pytest diff --git a/distributed/diagnostics/__init__.py b/distributed/diagnostics/__init__.py index 9469c3855d1..2ab9fac731f 100644 --- a/distributed/diagnostics/__init__.py +++ 
b/distributed/diagnostics/__init__.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from ..utils import ignoring from .graph_layout import GraphLayout diff --git a/distributed/diagnostics/eventstream.py b/distributed/diagnostics/eventstream.py index b9213144d4e..c0fde24470b 100644 --- a/distributed/diagnostics/eventstream.py +++ b/distributed/diagnostics/eventstream.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging from .plugin import SchedulerPlugin diff --git a/distributed/diagnostics/graph_layout.py b/distributed/diagnostics/graph_layout.py index 62e115a9ad4..c81c6edcafe 100644 --- a/distributed/diagnostics/graph_layout.py +++ b/distributed/diagnostics/graph_layout.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from .plugin import SchedulerPlugin diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index e1da4378fd4..cfe5fa42b49 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging logger = logging.getLogger(__name__) diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 50c4cd9fad1..4136fd17a5c 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import defaultdict import logging from timeit import default_timer diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index b1e3787bd5a..038237b89e2 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging from toolz import valmap, merge diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 4c9b781f61c..01dc9bbea39 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import html from timeit import default_timer diff --git a/distributed/diagnostics/task_stream.py b/distributed/diagnostics/task_stream.py index 17e62c3045e..2491c8a89c0 100644 --- a/distributed/diagnostics/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import deque import logging diff --git a/distributed/diagnostics/tests/test_eventstream.py b/distributed/diagnostics/tests/test_eventstream.py index 9139f75eab3..a111220b39e 100644 --- a/distributed/diagnostics/tests/test_eventstream.py +++ b/distributed/diagnostics/tests/test_eventstream.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import collections import pytest diff --git a/distributed/diagnostics/tests/test_plugin.py b/distributed/diagnostics/tests/test_plugin.py index 1c9ebd7a1a8..af29e81674d 100644 --- a/distributed/diagnostics/tests/test_plugin.py +++ b/distributed/diagnostics/tests/test_plugin.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from distributed import Worker from distributed.utils_test import inc, gen_cluster from distributed.diagnostics.plugin import SchedulerPlugin diff --git 
a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py index 097b2670247..8e3ba1688cc 100644 --- a/distributed/diagnostics/tests/test_progress.py +++ b/distributed/diagnostics/tests/test_progress.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pytest from tornado import gen diff --git a/distributed/diagnostics/tests/test_progress_stream.py b/distributed/diagnostics/tests/test_progress_stream.py index 9cf89817f34..56da9e974c1 100644 --- a/distributed/diagnostics/tests/test_progress_stream.py +++ b/distributed/diagnostics/tests/test_progress_stream.py @@ -1,6 +1,3 @@ -from __future__ import print_function, division, absolute_import - - import pytest pytest.importorskip("bokeh") diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index 4e6ffe8c7e9..535efd0e9e2 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from time import sleep import pytest diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index ad23ca5ae8c..58f1c4319f6 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, division, print_function - import os from time import sleep diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index 033d49251cb..03689c88c1d 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pytest pytest.importorskip("ipywidgets") diff --git a/distributed/diskutils.py b/distributed/diskutils.py index 32e6be35adb..64dcf1dfc12 100644 --- a/distributed/diskutils.py +++ b/distributed/diskutils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import errno import glob import logging diff --git a/distributed/lock.py b/distributed/lock.py index 6ad6ab607d3..ed3eb4313f2 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import defaultdict, deque from datetime import timedelta import logging diff --git a/distributed/metrics.py b/distributed/metrics.py index 6c0bdb4dc7e..fefdfeb2e4c 100755 --- a/distributed/metrics.py +++ b/distributed/metrics.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import collections from functools import wraps import sys diff --git a/distributed/nanny.py b/distributed/nanny.py index 228e37c2839..771b2d11d2d 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from datetime import timedelta import logging from multiprocessing.queues import Empty diff --git a/distributed/node.py b/distributed/node.py index cbf2c00d8f7..8ef610a8481 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import warnings import weakref diff --git a/distributed/process.py b/distributed/process.py index b6e50122c36..889787fe0bf 100644 --- a/distributed/process.py 
+++ b/distributed/process.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import atexit from datetime import timedelta import logging diff --git a/distributed/proctitle.py b/distributed/proctitle.py index 50c9859e17e..961c74b91ab 100644 --- a/distributed/proctitle.py +++ b/distributed/proctitle.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import os try: diff --git a/distributed/profile.py b/distributed/profile.py index 7f85f46312b..274dfcd1d20 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -24,8 +24,6 @@ 'children': {...}}} } """ -from __future__ import print_function, division, absolute_import - import bisect from collections import defaultdict, deque import linecache diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 3f98436f4b9..e30786ab4a5 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from functools import partial from .compression import compressions, default_compression diff --git a/distributed/protocol/arrow.py b/distributed/protocol/arrow.py index cac146a575c..1f2b4e83e9a 100644 --- a/distributed/protocol/arrow.py +++ b/distributed/protocol/arrow.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from .serialize import dask_serialize, dask_deserialize import pyarrow diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index f729748acc8..5035b465cee 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -3,8 +3,6 @@ Includes utilities for determining whether or not to compress """ -from __future__ import print_function, division, absolute_import - import logging import random diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index d54dd2e533e..3937c9c2fc8 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import operator diff --git a/distributed/protocol/h5py.py b/distributed/protocol/h5py.py index cf08719e259..e129c166683 100644 --- a/distributed/protocol/h5py.py +++ b/distributed/protocol/h5py.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from .serialize import dask_serialize, dask_deserialize import h5py diff --git a/distributed/protocol/keras.py b/distributed/protocol/keras.py index 4c6fc4b4d0a..7471a3dbc93 100644 --- a/distributed/protocol/keras.py +++ b/distributed/protocol/keras.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from .serialize import dask_serialize, dask_deserialize, serialize, deserialize import keras diff --git a/distributed/protocol/netcdf4.py b/distributed/protocol/netcdf4.py index e04864d2b73..eb83461eddc 100644 --- a/distributed/protocol/netcdf4.py +++ b/distributed/protocol/netcdf4.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from .serialize import dask_serialize, dask_deserialize, serialize, deserialize import netCDF4 diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index b2375569ef6..c7e48e63b1a 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import math import numpy as np diff --git 
a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 080bb9037db..629fb962fbf 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import sys diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 9e314703072..0069c6a264d 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -1,4 +1,3 @@ -from __future__ import print_function, division, absolute_import from functools import partial import traceback diff --git a/distributed/protocol/sparse.py b/distributed/protocol/sparse.py index b5a437a32a4..a22d661f849 100644 --- a/distributed/protocol/sparse.py +++ b/distributed/protocol/sparse.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from .serialize import dask_serialize, dask_deserialize, serialize, deserialize import sparse diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index ed4e32c1137..b334683b661 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import sys from zlib import crc32 diff --git a/distributed/protocol/tests/test_pandas.py b/distributed/protocol/tests/test_pandas.py index 8f5827f7896..104151fb55a 100644 --- a/distributed/protocol/tests/test_pandas.py +++ b/distributed/protocol/tests/test_pandas.py @@ -1,6 +1,3 @@ -from __future__ import print_function, division, absolute_import - - import pandas as pd import pandas.util.testing as tm import pytest diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 2415e01b5f1..395c1ca7b97 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import sys import dask diff --git a/distributed/protocol/tests/test_protocol_utils.py b/distributed/protocol/tests/test_protocol_utils.py index f4b98ab0e1d..d4250fb3c05 100644 --- a/distributed/protocol/tests/test_protocol_utils.py +++ b/distributed/protocol/tests/test_protocol_utils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from distributed.protocol.utils import merge_frames, pack_frames, unpack_frames from distributed.utils import ensure_bytes diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 4f72ec9a538..09297793fc3 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import copy import pickle diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index caf4bb8833b..68de0bebd32 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import struct import msgpack diff --git a/distributed/pytest_resourceleaks.py b/distributed/pytest_resourceleaks.py index bb62d3916d0..0119a425722 100644 --- a/distributed/pytest_resourceleaks.py +++ b/distributed/pytest_resourceleaks.py @@ -2,8 +2,6 @@ """ A pytest plugin to trace resource leaks. 
""" -from __future__ import print_function, division - import collections import gc import time diff --git a/distributed/queues.py b/distributed/queues.py index b97c317ac58..7174c48a63c 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import defaultdict import datetime import logging diff --git a/distributed/recreate_exceptions.py b/distributed/recreate_exceptions.py index d02dc4d94f4..9138c1fca5a 100644 --- a/distributed/recreate_exceptions.py +++ b/distributed/recreate_exceptions.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging from .client import futures_of, wait from .utils import sync, tokey diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8eec8744849..4e769ad420c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio from collections import defaultdict, deque, OrderedDict, Mapping, Set from datetime import timedelta diff --git a/distributed/sizeof.py b/distributed/sizeof.py index 0bc094e35a7..bc51b3603ae 100644 --- a/distributed/sizeof.py +++ b/distributed/sizeof.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging from dask.sizeof import sizeof diff --git a/distributed/stealing.py b/distributed/stealing.py index afcdf2a1cfa..e3537f647bf 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import defaultdict, deque import logging from math import log diff --git a/distributed/submit.py b/distributed/submit.py index f7e0a2f70aa..4cd7fb197a9 100644 --- a/distributed/submit.py +++ b/distributed/submit.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import os import socket diff --git a/distributed/system_monitor.py b/distributed/system_monitor.py index 30efc3ceb87..5b3bed3f98d 100644 --- a/distributed/system_monitor.py +++ b/distributed/system_monitor.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import deque import psutil diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 3cb3eee14d4..e496b35cb90 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from datetime import timedelta import gc import os diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5e2ade3f247..99e626de7fd 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio from collections import deque from concurrent.futures import CancelledError diff --git a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index a7f10491efb..7d08a63c5b2 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import random import time diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 7cb509f6ac7..7fe8467b14b 100644 --- 
a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -1,6 +1,3 @@ -from __future__ import print_function, division, absolute_import - - import pytest pytest.importorskip("numpy") diff --git a/distributed/tests/test_config.py b/distributed/tests/test_config.py index cdd4070f7bb..2017bb239f7 100644 --- a/distributed/tests/test_config.py +++ b/distributed/tests/test_config.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import logging import subprocess import sys diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 82e4c709be5..f91b8b64367 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from contextlib import contextmanager import os import socket diff --git a/distributed/tests/test_counter.py b/distributed/tests/test_counter.py index 956a682920c..bb38a2812e5 100644 --- a/distributed/tests/test_counter.py +++ b/distributed/tests/test_counter.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pytest from distributed.counter import Counter diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index 561a4cd408b..c5cca9d5824 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import functools import gc import os diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 1f27e067058..27bce439da4 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from concurrent.futures import CancelledError import os import random diff --git a/distributed/tests/test_ipython.py b/distributed/tests/test_ipython.py index a6f88ec5241..8f2a40e45eb 100644 --- a/distributed/tests/test_ipython.py +++ b/distributed/tests/test_ipython.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import mock import pytest diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index 9fa9a73787a..521a9b46114 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pickle from time import sleep diff --git a/distributed/tests/test_metrics.py b/distributed/tests/test_metrics.py index cdb4b8ee478..3a27e638ef3 100644 --- a/distributed/tests/test_metrics.py +++ b/distributed/tests/test_metrics.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import sys import threading import time diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 579af8dbc2c..bd8a284df54 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import gc import logging import os diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index a28d1e29082..817bfcbcea5 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from datetime import timedelta from time import sleep diff --git 
a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index b3f5db36a76..648a191224e 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from time import time from dask import delayed diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6401cdd4b94..80cc04c81b3 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import cloudpickle import pickle from collections import defaultdict diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 28438c6f359..bfc8358acf1 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from contextlib import contextmanager import sys diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 45a110bbecf..d7c396bb63f 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import itertools from operator import mul import random @@ -114,8 +112,8 @@ def test_worksteal_many_thieves(c, s, *workers): def test_dont_steal_unknown_functions(c, s, a, b): futures = c.map(inc, [1, 2], workers=a.address, allow_other_workers=True) yield wait(futures) - assert len(a.data) == 2 - assert len(b.data) == 0 + assert len(a.data) == 2, [len(a.data), len(b.data)] + assert len(b.data) == 0, [len(a.data), len(b.data)] @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index b5f51359239..db91ec0c004 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from concurrent.futures import CancelledError from datetime import timedelta from operator import add diff --git a/distributed/tests/test_submit_cli.py b/distributed/tests/test_submit_cli.py index 04267a28e2b..9273261dc94 100644 --- a/distributed/tests/test_submit_cli.py +++ b/distributed/tests/test_submit_cli.py @@ -1,4 +1,3 @@ -from __future__ import print_function, division, absolute_import from mock import Mock from tornado import gen diff --git a/distributed/tests/test_system_monitor.py b/distributed/tests/test_system_monitor.py index f42fb8e3e08..f615549a686 100644 --- a/distributed/tests/test_system_monitor.py +++ b/distributed/tests/test_system_monitor.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from time import sleep from distributed.system_monitor import SystemMonitor diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 43c8c667bf4..6d0e64b54e5 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -2,10 +2,6 @@ Various functional tests for TLS networking. Most are taken from other test files and adapted. 
""" - -from __future__ import print_function, division, absolute_import - - from tornado import gen from distributed import Nanny, worker_client, Queue diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index c547834626d..590f8c877b7 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import datetime from functools import partial import io diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index c9750891dd7..224b4b7f181 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import pytest from distributed.core import rpc diff --git a/distributed/tests/test_utils_perf.py b/distributed/tests/test_utils_perf.py index 95fa816a75b..4256548900c 100644 --- a/distributed/tests/test_utils_perf.py +++ b/distributed/tests/test_utils_perf.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import contextlib import gc import itertools diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 05b8066c707..1c6802b5637 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from contextlib import contextmanager import socket import sys diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index e734cc3094f..88f96a241b0 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import random from time import sleep import sys diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index c7337d36424..bacd169d35c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from concurrent.futures import ThreadPoolExecutor from datetime import timedelta import importlib diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index fe1d49def6d..14a2d30f7d5 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import random import threading from time import sleep diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index f4cae3fd88e..44770900028 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -20,8 +20,6 @@ Copyright 2001-2016 Python Software Foundation; All Rights Reserved """ -from __future__ import print_function, division, absolute_import - from . 
import _concurrent_futures_thread as thread import os import logging diff --git a/distributed/utils.py b/distributed/utils.py index ab45350bf61..5cc5d414d3a 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio import atexit from collections import deque diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index af393cbd79e..f6b4ea36e4f 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio from collections import defaultdict from itertools import cycle diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index eb54ea0b381..048d9092d49 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from collections import deque import gc import logging diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 0c7c8958a91..8293bb474e2 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio import collections from contextlib import contextmanager diff --git a/distributed/variable.py b/distributed/variable.py index 30ffc5bf72d..2169c287f61 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio from collections import defaultdict import logging diff --git a/distributed/versions.py b/distributed/versions.py index d6a44096796..a769c9ab032 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -1,7 +1,5 @@ """ utilities for package version introspection """ -from __future__ import print_function, division, absolute_import - import platform import struct import os diff --git a/distributed/worker.py b/distributed/worker.py index ae7ee648ff9..b32ab2e52a4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - import asyncio import bisect from collections import defaultdict, deque, MutableMapping diff --git a/distributed/worker_client.py b/distributed/worker_client.py index ff6294430b5..a45eb891f7d 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -1,5 +1,3 @@ -from __future__ import print_function, division, absolute_import - from contextlib import contextmanager import warnings From 4260790fc8c685cc306390f533fea8709b159830 Mon Sep 17 00:00:00 2001 From: Christian Hudon Date: Mon, 29 Jul 2019 18:52:30 -0400 Subject: [PATCH 0386/1550] Use click's show_default=True in relevant places (#2838) * Use click's show_default=True in relevant places * Make black happy... 
by running black * Tweak Click default help text for bool options --- distributed/cli/dask_remote.py | 4 +++- distributed/cli/dask_scheduler.py | 6 +++--- distributed/cli/dask_ssh.py | 16 ++++++++++++---- distributed/cli/dask_worker.py | 11 ++++++----- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/distributed/cli/dask_remote.py b/distributed/cli/dask_remote.py index 3118da84ae7..9fcfe7f3763 100644 --- a/distributed/cli/dask_remote.py +++ b/distributed/cli/dask_remote.py @@ -5,7 +5,9 @@ @click.command() @click.option("--host", type=str, default=None, help="IP or hostname of this server") -@click.option("--port", type=int, default=8788, help="Remote Client Port") +@click.option( + "--port", type=int, default=8788, show_default=True, help="Remote Client Port" +) @click.version_option() def main(host, port): _remote(host, port) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 07e4c98e267..a74f76102b9 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -65,15 +65,15 @@ "--dashboard-address", type=str, default=":8787", + show_default=True, help="Address on which to listen for diagnostics dashboard", ) @click.option( "--dashboard/--no-dashboard", "dashboard", default=True, - show_default=True, required=False, - help="Launch the Dashboard", + help="Launch the Dashboard [default: --dashboard]", ) @click.option( "--bokeh/--no-bokeh", @@ -82,7 +82,7 @@ required=False, help="Deprecated. See --dashboard/--no-dashboard.", ) -@click.option("--show/--no-show", default=False, help="Show web UI") +@click.option("--show/--no-show", default=False, help="Show web UI [default: --show]") @click.option( "--dashboard-prefix", type=str, default=None, help="Prefix for the dashboard app" ) diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index 389e0327688..97cf91f3519 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -20,8 +20,9 @@ @click.option( "--scheduler-port", default=8786, + show_default=True, type=int, - help="Specify scheduler port number. Defaults to port 8786.", + help="Specify scheduler port number.", ) @click.option( "--nthreads", @@ -36,8 +37,9 @@ @click.option( "--nprocs", default=1, + show_default=True, type=int, - help="Number of worker processes per host. Defaults to one.", + help="Number of worker processes per host.", ) @click.argument("hostnames", nargs=-1, type=str) @click.option( @@ -53,7 +55,11 @@ help="Username to use when establishing SSH connections.", ) @click.option( - "--ssh-port", default=22, type=int, help="Port to use for SSH connections." + "--ssh-port", + default=22, + type=int, + show_default=True, + help="Port to use for SSH connections.", ) @click.option( "--ssh-private-key", @@ -77,6 +83,7 @@ @click.option( "--memory-limit", default="auto", + show_default=True, help="Bytes of memory that the worker can use. " "This can be an integer (bytes), " "float (fraction of total system memory), " @@ -95,8 +102,9 @@ @click.option( "--remote-dask-worker", default="distributed.cli.dask_worker", + show_default=True, type=str, - help="Worker to run. 
Defaults to distributed.cli.dask_worker", + help="Worker to run.", ) @click.pass_context @click.version_option() diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 44931393522..953d2e26fab 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -70,9 +70,8 @@ "--dashboard/--no-dashboard", "dashboard", default=True, - show_default=True, required=False, - help="Launch the Dashboard", + help="Launch the Dashboard [default: --dashboard]", ) @click.option( "--bokeh/--no-bokeh", @@ -116,7 +115,8 @@ "--nprocs", type=int, default=1, - help="Number of worker processes to launch. Defaults to one.", + show_default=True, + help="Number of worker processes to launch.", ) @click.option( "--name", @@ -129,6 +129,7 @@ @click.option( "--memory-limit", default="auto", + show_default=True, help="Bytes of memory per process that the worker can use. " "This can be an integer (bytes), " "float (fraction of total system memory), " @@ -138,12 +139,12 @@ @click.option( "--reconnect/--no-reconnect", default=True, - help="Reconnect to scheduler if disconnected", + help="Reconnect to scheduler if disconnected [default: --reconnect]", ) @click.option( "--nanny/--no-nanny", default=True, - help="Start workers in nanny process for management", + help="Start workers in nanny process for management [default: --nanny]", ) @click.option("--pid-file", type=str, default="", help="File to write the process PID") @click.option( From 051a79e05c501dcea57a19715f59d0cd203f6be6 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 30 Jul 2019 11:27:23 -0700 Subject: [PATCH 0387/1550] Close workers more gracefully (#2905) This commit does two things: 1. We wait to shutdown the executor a little longer in case it is still in use 2. The worker no longer asks the Nanny to terminate it. Instead it asks the nanny to shutdown gracefully after it is gone, and then continues closing itself as normal. 
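For illustration only, here is a sketch of the flow this enables, modeled on the new test added below (it is not part of the changeset itself, and it assumes an in-process scheduler plus a nanny-managed worker):

```
import asyncio

from tornado.ioloop import IOLoop

from distributed import Client, Nanny, Scheduler


async def demo():
    # Start a scheduler and a nanny-managed worker in this process.
    async with Scheduler(port=0) as s:
        async with Nanny(s.address) as n:
            async with Client(s.address, asynchronous=True) as client:
                # Ask the worker itself to terminate.  With this change the
                # worker closes itself, and the nanny shuts down cleanly once
                # the worker process exits, rather than being asked to kill it.
                with client.rpc(n.worker_address) as w:
                    IOLoop.current().add_callback(w.terminate)
                while n.status != "closed":
                    await asyncio.sleep(0.01)


asyncio.get_event_loop().run_until_complete(demo())
```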
--- distributed/nanny.py | 3 +++ distributed/tests/test_nanny.py | 17 ++++++++++++++++- distributed/worker.py | 26 ++++++++++++-------------- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 771b2d11d2d..e90b4ff33f8 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -380,6 +380,9 @@ async def _on_exit(self, exitcode): if self.auto_restart: logger.warning("Restarting worker") await self.instantiate() + elif self.status == "closing-gracefully": + await self.close() + except Exception: logger.error( "Failed to restart worker after its process exited", exc_info=True diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index bd8a284df54..6187954e6be 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop import dask -from distributed import Nanny, rpc, Scheduler, Worker +from distributed import Nanny, rpc, Scheduler, Worker, Client from distributed.core import CommClosedError from distributed.metrics import time from distributed.protocol.pickle import dumps @@ -396,3 +396,18 @@ async def test_nanny_closes_cleanly(cleanup): assert not n.process assert not proc.is_alive() assert proc.exitcode == 0 + + +@pytest.mark.asyncio +async def test_nanny_closes_cleanly(cleanup): + async with Scheduler() as s: + async with Nanny(s.address) as n: + async with Client(s.address, asynchronous=True) as client: + with client.rpc(n.worker_address) as w: + IOLoop.current().add_callback(w.terminate) + start = time() + while n.status != "closed": + await gen.sleep(0.01) + assert time() < start + 5 + + assert n.status == "closed" diff --git a/distributed/worker.py b/distributed/worker.py index b32ab2e52a4..67f753ad978 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -575,7 +575,7 @@ def __init__( "get_data": self.get_data, "update_data": self.update_data, "delete_data": self.delete_data, - "terminate": self.terminate, + "terminate": self.close, "ping": pingpong, "upload_file": self.upload_file, "start_ipython": self.start_ipython, @@ -830,7 +830,7 @@ async def handle_scheduler(self, comm): logger.exception(e) raise finally: - if self.reconnect: + if self.reconnect and self.status == "running": logger.info("Connection to scheduler broken. 
Reconnecting...") self.loop.add_callback(self._register_with_scheduler) else: @@ -996,13 +996,6 @@ async def close(self, report=True, timeout=10, nanny=True, executor_wait=True): self.scheduler.unregister(address=self.contact_address), ) self.scheduler.close_rpc() - self.actor_executor._work_queue.queue.clear() - if isinstance(self.executor, ThreadPoolExecutor): - self.executor._work_queue.queue.clear() - self.executor.shutdown(wait=executor_wait, timeout=timeout) - else: - self.executor.shutdown(wait=False) - self.actor_executor.shutdown(wait=executor_wait, timeout=timeout) self._workdir.release() for k, v in self.services.items(): @@ -1014,9 +1007,13 @@ async def close(self, report=True, timeout=10, nanny=True, executor_wait=True): if self.batched_stream: self.batched_stream.close() - if nanny and self.nanny: - with self.rpc(self.nanny) as r: - await r.terminate() + self.actor_executor._work_queue.queue.clear() + if isinstance(self.executor, ThreadPoolExecutor): + self.executor._work_queue.queue.clear() + self.executor.shutdown(wait=executor_wait, timeout=timeout) + else: + self.executor.shutdown(wait=False) + self.actor_executor.shutdown(wait=executor_wait, timeout=timeout) self.stop() self.rpc.close() @@ -1026,9 +1023,10 @@ async def close(self, report=True, timeout=10, nanny=True, executor_wait=True): await ServerNode.close(self) setproctitle("dask-worker [closed]") + return "OK" - async def terminate(self, comm, report=True): - await self.close(report=report) + async def terminate(self, comm, report=True, **kwargs): + await self.close(report=report, **kwargs) return "OK" async def wait_until_closed(self): From 5f120437a9f6101597b00bd2d5dfba8c851692d3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 31 Jul 2019 07:37:48 -0700 Subject: [PATCH 0388/1550] Close workers gracefully with --lifetime keywords (#2892) This allows workers to optionally terminate themselves gracefully after a predetermined time. This can be helpful in a few contexts: 1. We receive a SIGINT, and know that we need to clean up quickly (though note that the signal handlers are not implemented as part of this commit 2. We know that we'll be kicked off at a certain time, such as in one hour from now, as is often specified by HPC job schedulers 3. We just want to refresh our workers every once in a while, because we know that our code leaks some memory . Fixes https://github.com/dask/distributed/issues/2861 This is configurable as keywords to the `Worker` or `Nanny` classes, in config values, or with CLI. Here is an example with CLI. ### Restart to clear state ``` dask-worker scheduler:8786 --lifetime 1hr --lifetime-restart --lifetime-stagger 5m ``` This will kill the worker roughly 1 hour from now +- a range of 5 minutes (to avoid killing all of our workers at the same time). It will also allow that worker to be restarted afterwards ### Restart to avoid walltime death ``` dask-worker scheduler:8786 --lifetime 58m ``` Here we don't try to restart the worker (no point) and we choose a time a bit before our 60m walltime. 
--- distributed/cli/dask_worker.py | 24 ++++++++++- distributed/client.py | 6 ++- distributed/diagnostics/__init__.py | 1 + distributed/distributed.yaml | 5 ++- distributed/nanny.py | 26 +++++++----- distributed/tests/test_client.py | 2 + distributed/tests/test_nanny.py | 25 +++++++++++ distributed/tests/test_worker.py | 39 ++++++++++++++++++ distributed/worker.py | 64 ++++++++++++++++++++++++++--- 9 files changed, 173 insertions(+), 19 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 953d2e26fab..eef1d648d40 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -174,6 +174,28 @@ @click.option( "--dashboard-prefix", type=str, default="", help="Prefix for the dashboard" ) +@click.option( + "--lifetime", + type=str, + default="", + help="If provided, shut down the worker after this duration.", +) +@click.option( + "--lifetime-stagger", + type=str, + default="0 seconds", + show_default=True, + help="Random amount by which to stagger lifetime values", +) +@click.option( + "--lifetime-restart/--no-lifetime-restart", + "lifetime_restart", + default=False, + show_default=True, + required=False, + help="Whether or not to restart the worker after the lifetime lapses. " + "This assumes that you are using the --lifetime and --nanny keywords", +) @click.option( "--preload", type=str, @@ -346,7 +368,7 @@ def del_pid_file(): dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, name=name if nprocs == 1 or not name else name + "-" + str(i), - **kwargs, + **kwargs ) for i in range(nprocs) ] diff --git a/distributed/client.py b/distributed/client.py index d9f02ba80f8..0f7fe3d046b 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -953,7 +953,11 @@ async def _start(self, timeout=no_default, **kwargs): self.scheduler = self.rpc(address) self.scheduler_comm = None - await self._ensure_connected(timeout=timeout) + try: + await self._ensure_connected(timeout=timeout) + except OSError: + await self._close() + raise for pc in self._periodic_callbacks.values(): pc.start() diff --git a/distributed/diagnostics/__init__.py b/distributed/diagnostics/__init__.py index 2ab9fac731f..337f41b7598 100644 --- a/distributed/diagnostics/__init__.py +++ b/distributed/diagnostics/__init__.py @@ -1,5 +1,6 @@ from ..utils import ignoring from .graph_layout import GraphLayout +from .plugin import SchedulerPlugin with ignoring(ImportError): from .progressbar import progress diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index c3f14f114f1..9ad3e365e78 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -33,7 +33,6 @@ distributed: key: null cert: null - worker: blocked-handlers: [] multiprocessing-method: forkserver @@ -44,6 +43,10 @@ distributed: preload: [] preload-argv: [] daemon: True + lifetime: + duration: null # Time after which to gracefully shutdown the worker + stagger: 0 seconds # Random amount by which to stagger lifetimes + restart: False # Do we ressurrect the worker after the lifetime deadline? 
profile: interval: 10ms # Time between statistical profiling queries diff --git a/distributed/nanny.py b/distributed/nanny.py index e90b4ff33f8..155fde98158 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -677,17 +677,21 @@ async def run(): init_result_q.put({"uid": uid, "exception": e}) init_result_q.close() else: - assert worker.address - init_result_q.put( - { - "address": worker.address, - "dir": worker.local_directory, - "uid": uid, - } - ) - init_result_q.close() - await worker.wait_until_closed() - logger.info("Worker closed") + try: + assert worker.address + except ValueError: + pass + else: + init_result_q.put( + { + "address": worker.address, + "dir": worker.local_directory, + "uid": uid, + } + ) + init_result_q.close() + await worker.wait_until_closed() + logger.info("Worker closed") try: loop.run_sync(run) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 99e626de7fd..5c86727c043 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3318,6 +3318,7 @@ def test_bad_tasks_fail(c, s, a, b): yield f assert info.value.last_worker.nanny in {a.address, b.address} + yield [a.close(), b.close()] def test_get_processing_sync(c, s, a, b): @@ -5233,6 +5234,7 @@ def test_client_timeout_2(): yield c stop = time() + assert c.status == "closed" yield c.close() assert stop - start < 1 diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 6187954e6be..dec6bd91b20 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -11,8 +11,10 @@ from toolz import valmap, first from tornado import gen from tornado.ioloop import IOLoop +from tornado.locks import Event import dask +from distributed.diagnostics import SchedulerPlugin from distributed import Nanny, rpc, Scheduler, Worker, Client from distributed.core import CommClosedError from distributed.metrics import time @@ -398,6 +400,29 @@ async def test_nanny_closes_cleanly(cleanup): assert proc.exitcode == 0 +@pytest.mark.slow +@pytest.mark.asyncio +async def test_lifetime(cleanup): + counter = 0 + event = Event() + + class Plugin(SchedulerPlugin): + def add_worker(self, **kwargs): + pass + + def remove_worker(self, **kwargs): + nonlocal counter + counter += 1 + if counter == 2: # wait twice, then trigger closing event + event.set() + + async with Scheduler() as s: + s.add_plugin(Plugin()) + async with Nanny(s.address) as a: + async with Nanny(s.address, lifetime="500 ms", lifetime_restart=True) as b: + await event.wait() + + @pytest.mark.asyncio async def test_nanny_closes_cleanly(cleanup): async with Scheduler() as s: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index bacd169d35c..52e92d474ce 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1497,3 +1497,42 @@ async def test_worker_listens_on_same_interface_by_default(Worker): assert s.ip in {"127.0.0.1", "localhost"} async with Worker(s.address) as w: assert s.ip == w.ip + + +@gen_cluster(client=True) +async def test_close_gracefully(c, s, a, b): + futures = c.map(slowinc, range(200), delay=0.1) + while not b.data: + await gen.sleep(0.1) + + mem = set(b.data) + proc = set(b.executing) + + await b.close_gracefully() + + assert b.status == "closed" + assert b.address not in s.workers + assert mem.issubset(set(a.data)) + for key in proc: + assert s.tasks[key].state in ("processing", "memory") + + +@pytest.mark.slow +@pytest.mark.asyncio +async def test_lifetime(cleanup): + async with 
Scheduler() as s: + async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b: + async with Client(s.address, asynchronous=True) as c: + futures = c.map(slowinc, range(200), delay=0.1) + await gen.sleep(1.5) + assert b.status != "running" + await b.finished() + + assert set(b.data).issubset(a.data) # successfully moved data over + + +@gen_cluster(client=True, worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"}) +async def test_lifetime_stagger(c, s, a, b): + assert a.lifetime != b.lifetime + assert 8 <= a.lifetime <= 12 + assert 8 <= b.lifetime <= 12 diff --git a/distributed/worker.py b/distributed/worker.py index 67f753ad978..5648e67e353 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -249,6 +249,16 @@ class Worker(ServerNode): Resources that this worker has like ``{'GPU': 2}`` nanny: str Address on which to contact nanny, if it exists + lifetime: str + Amount of time like "1 hour" after which we gracefully shut down the worker. + This defaults to None, meaning no explicit shutdown time. + lifetime_stagger: str + Amount of time like "5 minutes" to stagger the lifetime value + The actual lifetime will be selected uniformly at random between + lifetime +/- lifetime_stagger + lifetime_restart: bool + Whether or not to restart a worker after it has reached its lifetime + Default False Examples -------- @@ -308,6 +318,9 @@ def __init__( low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), validate=False, profile_cycle_interval=None, + lifetime=None, + lifetime_stagger=None, + lifetime_restart=None, **kwargs ): self.tasks = dict() @@ -653,6 +666,23 @@ def __init__( self.plugins = {} self._pending_plugins = plugins + self.lifetime = lifetime or dask.config.get( + "distributed.worker.lifetime.duration" + ) + lifetime_stagger = lifetime_stagger or dask.config.get( + "distributed.worker.lifetime.stagger" + ) + self.lifetime_restart = lifetime_restart or dask.config.get( + "distributed.worker.lifetime.restart" + ) + if isinstance(self.lifetime, str): + self.lifetime = parse_timedelta(self.lifetime) + if isinstance(lifetime_stagger, str): + lifetime_stagger = parse_timedelta(lifetime_stagger) + if self.lifetime: + self.lifetime += (random.random() * 2 - 1) * lifetime_stagger + self.io_loop.call_later(self.lifetime, self.close_gracefully) + Worker._instances.add(self) ################## @@ -903,6 +933,8 @@ async def gather(self, comm=None, who_has=None): ############# async def start(self): + if self.status and self.status.startswith("clos"): + return assert self.status is None, self.status enable_gc_diagnosis() @@ -957,19 +989,22 @@ def _close(self, *args, **kwargs): warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) return self.close(*args, **kwargs) - async def close(self, report=True, timeout=10, nanny=True, executor_wait=True): + async def close( + self, report=True, timeout=10, nanny=True, executor_wait=True, safe=False + ): with log_errors(): if self.status in ("closed", "closing"): await self.finished() return + self.reconnect = False disable_gc_diagnosis() try: logger.info("Stopping worker at %s", self.address) except ValueError: # address not available if already closed logger.info("Stopping worker") - if self.status != "running": + if self.status not in ("running", "closing-gracefully"): logger.info("Closed worker has not yet started: %s", self.status) self.status = "closing" @@ -993,7 +1028,9 @@ async def close(self, report=True, timeout=10, nanny=True, executor_wait=True): if report: await 
gen.with_timeout( timedelta(seconds=timeout), - self.scheduler.unregister(address=self.contact_address), + self.scheduler.unregister( + address=self.contact_address, safe=safe + ), ) self.scheduler.close_rpc() self._workdir.release() @@ -1025,6 +1062,23 @@ async def close(self, report=True, timeout=10, nanny=True, executor_wait=True): setproctitle("dask-worker [closed]") return "OK" + async def close_gracefully(self): + """ Gracefully shut down a worker + + This first informs the scheduler that we're shutting down, and asks it + to move our data elsewhere. Afterwards, we close as normal + """ + if self.status.startswith("closing"): + await self.finished() + + if self.status == "closed": + return + + logger.info("Closing worker gracefully: %s", self.address) + self.status = "closing-gracefully" + await self.scheduler.retire_workers(workers=[self.address], remove=False) + await self.close(safe=True, nanny=not self.lifetime_restart) + async def terminate(self, comm, report=True, **kwargs): await self.close(report=report, **kwargs) return "OK" @@ -1541,7 +1595,7 @@ def transition_executing_done(self, key, value=no_value, report=True): if key in self.dep_state: self.transition_dep(key, "memory") - if report and self.batched_stream: + if report and self.batched_stream and self.status == "running": self.send_task_state_to_scheduler(key) else: raise CommClosedError @@ -2278,7 +2332,7 @@ def ensure_computing(self): async def execute(self, key, report=False): executor_error = None - if self.status in ("closing", "closed"): + if self.status in ("closing", "closed", "closing-gracefully"): return try: if key not in self.executing or key not in self.task_state: From f17ae0123735038eb0de355c4ce87cfb24b5221a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 31 Jul 2019 12:25:32 -0700 Subject: [PATCH 0389/1550] Add closing
 </li> tags to Client._repr_html_ (#2911)

* Add is_valid_xml function

* Add closing </li> tags to Client._repr_html_

---
 distributed/client.py                            |  4 ++--
 .../dashboard/tests/test_scheduler_bokeh_html.py |  6 +++---
 distributed/deploy/spec.py                       |  6 +++---
 distributed/tests/test_client.py                 | 11 ++++++++++-
 distributed/tests/test_utils.py                  | 10 ++++++++--
 distributed/utils.py                             |  5 +++++
 6 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/distributed/client.py b/distributed/client.py
index 0f7fe3d046b..fab5ff0bf6e 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -798,9 +798,9 @@ def _repr_html_(self):
          \n' ) if scheduler is not None: - text += "
        • Scheduler: %s\n" % scheduler.address + text += "
        • Scheduler: %s
        • \n" % scheduler.address else: - text += "
        • Scheduler: not connected\n" + text += "
        • Scheduler: not connected
        • \n" if info and "dashboard" in info["services"]: protocol, rest = scheduler.address.split("://") diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py index fc19efb0812..660602df09a 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -1,6 +1,5 @@ import json import re -import xml.etree.ElementTree import pytest @@ -10,6 +9,7 @@ from tornado.httpclient import AsyncHTTPClient from dask.sizeof import sizeof +from distributed.utils import is_valid_xml from distributed.utils_test import gen_cluster, slowinc, inc from distributed.dashboard import BokehScheduler, BokehWorker @@ -45,7 +45,7 @@ def test_connect(c, s, a, b): if suffix.endswith(".json"): json.loads(body) else: - assert xml.etree.ElementTree.fromstring(body) is not None + assert is_valid_xml(body) assert not re.search("href=./", body) # no absolute links @@ -66,7 +66,7 @@ def test_prefix(c, s, a, b): if suffix.endswith(".json"): json.loads(body) else: - assert xml.etree.ElementTree.fromstring(body) is not None + assert is_valid_xml(body) @gen_cluster( diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 87228d5693e..c9e1aca1a87 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -492,9 +492,9 @@ def _widget_status(self): } - - - + + +
          Workers %d
          Cores %d
          Memory %s
          Workers %d
          Cores %d
          Memory %s
          """ % ( diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5c86727c043..3eb9f39a2a4 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -54,7 +54,15 @@ from distributed.metrics import time from distributed.scheduler import Scheduler, KilledWorker from distributed.sizeof import sizeof -from distributed.utils import ignoring, mp_context, sync, tmp_text, tokey, tmpfile +from distributed.utils import ( + ignoring, + mp_context, + sync, + tmp_text, + tokey, + tmpfile, + is_valid_xml, +) from distributed.utils_test import ( cluster, slowinc, @@ -1904,6 +1912,7 @@ def test_repr_localcluster(): try: text = client._repr_html_() assert cluster.scheduler.address in text + assert is_valid_xml(client._repr_html_()) finally: yield client.close() yield cluster.close() diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 590f8c877b7..737b015646b 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -6,7 +6,6 @@ import sys from time import sleep import traceback -import xml.etree.ElementTree import numpy as np import pytest @@ -21,6 +20,7 @@ Logs, sync, is_kernel, + is_valid_xml, ensure_ip, str_graph, truncate_exception, @@ -553,6 +553,12 @@ def test_logs(): d = Logs({"123": Log("Hello"), "456": Log("World!")}) text = d._repr_html_() for line in text.split("\n"): - assert xml.etree.ElementTree.fromstring(line) is not None + assert is_valid_xml(line) assert "Hello" in text assert "456" in text + + +def test_is_valid_xml(): + assert is_valid_xml("foo") + with pytest.raises(Exception): + assert is_valid_xml("foo") diff --git a/distributed/utils.py b/distributed/utils.py index 5cc5d414d3a..a8cddb81498 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -26,6 +26,7 @@ import pkgutil import six import tblib.pickling_support +import xml.etree.ElementTree try: import resource @@ -1466,3 +1467,7 @@ def convert_value(v): return sum( [["--" + k.replace("_", "-"), convert_value(v)] for k, v in d.items()], [] ) + + +def is_valid_xml(text): + return xml.etree.ElementTree.fromstring(text) is not None From a1e7b5212f8e96eaac6ebd4fbf45e24c4f50acc0 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 31 Jul 2019 15:51:14 -0700 Subject: [PATCH 0390/1550] Add endline spacing in Logs._repr_html_ (#2912) --- distributed/deploy/tests/test_spec_cluster.py | 2 ++ distributed/tests/test_utils.py | 3 +-- distributed/utils.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 84c868b2585..58bbbaef44d 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -4,6 +4,7 @@ from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny from distributed.deploy.spec import close_clusters, ProcessInterface from distributed.utils_test import loop, cleanup # noqa: F401 +from distributed.utils import is_valid_xml import toolz import pytest @@ -190,6 +191,7 @@ async def test_logs(cleanup): await cluster logs = await cluster.logs() + assert is_valid_xml("
          " + logs._repr_html_() + "
          ") assert "Scheduler" in logs for worker in cluster.scheduler.workers: assert worker in logs diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 737b015646b..e5e18eb393c 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -552,8 +552,7 @@ def test_format_bytes_compat(): def test_logs(): d = Logs({"123": Log("Hello"), "456": Log("World!")}) text = d._repr_html_() - for line in text.split("\n"): - assert is_valid_xml(line) + assert is_valid_xml("
          " + text + "
          ") assert "Hello" in text assert "456" in text diff --git a/distributed/utils.py b/distributed/utils.py index a8cddb81498..cdc5c4d1ae9 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1414,7 +1414,7 @@ class Log(str): """ A container for logs """ def _repr_html_(self): - return "
          {log}
          ".format(log=self) + return "
          \n{log}\n
          ".format(log=self.rstrip()) class Logs(dict): @@ -1422,12 +1422,12 @@ class Logs(dict): def _repr_html_(self): summaries = [ - "
          {title}{log}
          ".format( + "
          \n{title}\n{log}\n
          ".format( title=title, log=log._repr_html_() ) for title, log in self.items() ] - return "\n".join(summaries) + return "\n\n".join(summaries) def cli_keywords(d: dict, cls=None): From 58844d01259ca2455a15267ceb2573a827761745 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 31 Jul 2019 16:36:06 -0700 Subject: [PATCH 0391/1550] bump version to 2.2.0 --- docs/source/changelog.rst | 48 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index cb98ab79fc7..b7037d9d6c2 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,51 @@ Changelog ========= +2.2.0 - 2019-07-31 +------------------ + +- Respect security configuration in LocalCluster (:pr:`2822`) `Russ Bubley`_ +- Add Nanny to worker docs (:pr:`2826`) `Christian Hudon`_ +- Don't make False add-keys report to scheduler (:pr:`2421`) `tjb900`_ +- Include type name in SpecCluster repr (:pr:`2834`) `Jacob Tomlinson`_ +- Extend prometheus metrics endpoint (:pr:`2833`) `Gabriel Sailer`_ +- Add alternative SSHCluster implementation (:pr:`2827`) `Matthew Rocklin`_ +- Dont reuse closed worker in get_worker (:pr:`2841`) `Pierre Glaser`_ +- SpecCluster: move init logic into start (:pr:`2850`) `Jacob Tomlinson`_ +- Document distributed.Reschedule in API docs (:pr:`2860`) `James Bourbeau`_ +- Add fsspec to installation of test builds (:pr:`2859`) `Martin Durant`_ +- Make await/start more consistent across Scheduler/Worker/Nanny (:pr:`2831`) `Matthew Rocklin`_ +- Add cleanup fixture for asyncio tests (:pr:`2866`) `Matthew Rocklin`_ +- Use only remote connection to scheduler in Adaptive (:pr:`2865`) `Matthew Rocklin`_ +- Add Server.finished async function (:pr:`2864`) `Matthew Rocklin`_ +- Align text and remove bullets in Client HTML repr (:pr:`2867`) `Matthew Rocklin`_ +- Test dask-scheduler --idle-timeout flag (:pr:`2862`) `Matthew Rocklin`_ +- Remove ``Client.upload_environment`` (:pr:`2877`) `Jim Crist`_ +- Replace gen.coroutine with async/await in core (:pr:`2871`) `Matthew Rocklin`_ +- Forcefully kill all processes before each test (:pr:`2882`) `Matthew Rocklin`_ +- Cleanup Security class and configuration (:pr:`2873`) `Jim Crist`_ +- Remove unused variable in SpecCluster scale down (:pr:`2870`) `Jacob Tomlinson`_ +- Add SpecCluster ProcessInterface (:pr:`2874`) `Jacob Tomlinson`_ +- Add Log(str) and Logs(dict) classes for nice HTML reprs (:pr:`2875`) `Jacob Tomlinson`_ +- Pass Client._asynchronous to Cluster._asynchronous (:pr:`2890`) `Matthew Rocklin`_ +- Add default logs method to Spec Cluster (:pr:`2889`) `Matthew Rocklin`_ +- Add processes keyword back into clean (:pr:`2891`) `Matthew Rocklin`_ +- Update black (:pr:`2901`) `Matthew Rocklin`_ +- Move Worker.local_dir attribute to Worker.local_directory (:pr:`2900`) `Matthew Rocklin`_ +- Link from TapTools to worker info pages in dashboard (:pr:`2894`) `Matthew Rocklin`_ +- Avoid exception in Client._ensure_connected if closed (:pr:`2893`) `Matthew Rocklin`_ +- Convert Pythonic kwargs to CLI Keywords for SSHCluster (:pr:`2898`) `Matthew Rocklin`_ +- Use kwargs in CLI (:pr:`2899`) `Matthew Rocklin`_ +- Name SSHClusters by providing name= keyword to SpecCluster (:pr:`2903`) `Matthew Rocklin`_ +- Request feed of worker information from Scheduler to SpecCluster (:pr:`2902`) `Matthew Rocklin`_ +- Clear out compatibillity file (:pr:`2896`) `Matthew Rocklin`_ +- Remove future imports (:pr:`2897`) `Matthew Rocklin`_ +- Use click's show_default=True in relevant places 
(:pr:`2838`) `Christian Hudon`_
+- Close workers more gracefully (:pr:`2905`) `Matthew Rocklin`_
+- Close workers gracefully with --lifetime keywords (:pr:`2892`) `Matthew Rocklin`_
+- Add closing
        • tags to Client._repr_html_ (:pr:`2911`) `Matthew Rocklin`_ +- Add endline spacing in Logs._repr_html_ (:pr:`2912`) `Matthew Rocklin`_ + 2.1.0 - 2019-07-08 ------------------ @@ -1125,3 +1170,6 @@ significantly without many new features. .. _`Magnus Nord`: https://github.com/magnunor .. _`Sam Grayson`: https://github.com/charmoniumQ .. _`Mathieu Dugré`: https://github.com/mathdugre +.. _`Christian Hudon`: https://github.com/chrish42 +.. _`Gabriel Sailer`: https://github.com/sublinus +.. _`Pierre Glaser`: https://github.com/pierreglase From e1e36e437a53a937d14519f06fced6b645cff361 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 1 Aug 2019 12:27:26 -0700 Subject: [PATCH 0392/1550] Call heartbeat rather than reconnect on disconnection (#2906) Fixes https://github.com/dask/distributed/issues/2525 This avoids a subtle race condition. --- distributed/scheduler.py | 9 +++------ distributed/worker.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4e769ad420c..acabce22c63 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1325,11 +1325,10 @@ def heartbeat_worker( ): address = self.coerce_address(address, resolve_address) address = normalize_address(address) - host = get_address_host(address) if address not in self.workers: - logger.info("Received heartbeat from removed worker: %s", address) - return + return {"status": "missing"} + host = get_address_host(address) local_now = time() now = now or time() metrics = metrics or {} @@ -1342,9 +1341,7 @@ def heartbeat_worker( except KeyError: pass - ws = self.workers.get(address) - if not ws: - return {"status": "missing"} + ws = self.workers[address] ws.last_seen = time() diff --git a/distributed/worker.py b/distributed/worker.py index 5648e67e353..672358d4460 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -862,7 +862,7 @@ async def handle_scheduler(self, comm): finally: if self.reconnect and self.status == "running": logger.info("Connection to scheduler broken. Reconnecting...") - self.loop.add_callback(self._register_with_scheduler) + self.loop.add_callback(self.heartbeat) else: await self.close(report=False) From ff3437c0a71f151ce1050a4e6a7177ab4a8a3b22 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 1 Aug 2019 13:47:09 -0700 Subject: [PATCH 0393/1550] Rewrite Adaptive/SpecCluster to support slowly arriving workers (#2904) Previously SpecCluster waited until all workers had checked in with the scheduler. This made sense for LocalCluster or SSHCluster because there isn't really a significant delay in starting things that we can't control. However, for other systems like dask-jobqueue or dask-kubernetes workers might not ever start, so we need a different system. Now, SpecCluster still awaits the Worker object that it is passed, but doesn't require that the worker has started in the scheduler. We now expect awaiting to mean *"We have successfully handed control of starting the worker to some other robust system"* Our job at this point is done. We hope that the worker arrives, but from our perspective this local Worker object is awaited and "running". This commit also includes a minimal example of a Worker class, `SlowWorker`, that serves as a nice minimal example for what SpecCluster expects. 
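A rough sketch of the contract described above, not part of this patch: awaiting the worker object only means the request has been handed off to the resource manager. The `submit_job` callable is a hypothetical stand-in for whatever call the backend actually provides, and a real implementation would also expose things like an `address` attribute.

class SubmittedWorker:
    """A worker object that SpecCluster can await before the process exists."""

    def __init__(self, scheduler_address, submit_job):
        self.scheduler_address = scheduler_address
        self.submit_job = submit_job  # hypothetical resource-manager call
        self.status = None

    def __await__(self):
        async def _start():
            if self.status is None:
                # Hand the request off; do not wait for the dask-worker
                # process to actually connect to the scheduler
                self.submit_job(self.scheduler_address)
                self.status = "running"
            return self

        return _start().__await__()

    async def close(self):
        self.status = "closed"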
* Add AdaptiveCore class * Back Adaptive with AdaptiveCore * Include requested workers in widget * Use worker names throughout adaptive --- distributed/deploy/adaptive.py | 178 ++++------------ distributed/deploy/adaptive_core.py | 197 ++++++++++++++++++ distributed/deploy/spec.py | 26 ++- distributed/deploy/tests/test_adaptive.py | 120 +++++------ .../deploy/tests/test_adaptive_core.py | 90 ++++++++ .../deploy/tests/test_slow_adaptive.py | 98 +++++++++ distributed/deploy/tests/test_spec_cluster.py | 8 +- distributed/nanny.py | 2 +- distributed/scheduler.py | 61 ++++-- distributed/tests/test_scheduler.py | 4 +- distributed/worker.py | 6 +- 11 files changed, 555 insertions(+), 235 deletions(-) create mode 100644 distributed/deploy/adaptive_core.py create mode 100644 distributed/deploy/tests/test_adaptive_core.py create mode 100644 distributed/deploy/tests/test_slow_adaptive.py diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 9b1d8511045..2efc18dfe0c 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -1,16 +1,14 @@ -from collections import deque import logging +import math -from tornado import gen - -from ..metrics import time -from ..utils import log_errors, PeriodicCallback, parse_timedelta +from .adaptive_core import AdaptiveCore +from ..utils import log_errors, parse_timedelta from ..protocol import pickle logger = logging.getLogger(__name__) -class Adaptive(object): +class Adaptive(AdaptiveCore): ''' Adaptively allocate workers based on scheduler load. A superclass. @@ -23,19 +21,13 @@ class Adaptive(object): Parameters ---------- - scheduler: distributed.Scheduler cluster: object - Must have scale_up and scale_down methods/coroutines - startup_cost : timedelta or str, default "1s" - Estimate of the number of seconds for nnFactor representing how costly it is to start an additional worker. - Affects quickly to adapt to high tasks per worker loads + Must have scale and scale_down methods/coroutines interval : timedelta or str, default "1000 ms" Milliseconds between checks wait_count: int, default 3 Number of consecutive times that a worker should be suggested for removal before we remove it. - scale_factor : int, default 2 - Factor to scale by when it's determined additional workers are needed target_duration: timedelta or str, default "5s" Amount of time we want a computation to take. This affects how aggressively we scale up. 
@@ -84,45 +76,47 @@ def __init__( self, cluster=None, interval="1s", - startup_cost="1s", - scale_factor=2, minimum=0, - maximum=None, + maximum=math.inf, wait_count=3, target_duration="5s", worker_key=None, **kwargs ): - interval = parse_timedelta(interval, default="ms") - self.worker_key = worker_key self.cluster = cluster - self.startup_cost = parse_timedelta(startup_cost, default="s") - self.scale_factor = scale_factor - if self.cluster: - self._adapt_callback = PeriodicCallback( - self._adapt, interval * 1000, io_loop=self.loop - ) - self.loop.add_callback(self._adapt_callback.start) - self._adapting = False + self.worker_key = worker_key self._workers_to_close_kwargs = kwargs - self.minimum = minimum - self.maximum = maximum - self.log = deque(maxlen=1000) - self.close_counts = {} - self.wait_count = wait_count self.target_duration = parse_timedelta(target_duration) + super().__init__( + minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval + ) + @property def scheduler(self): return self.cluster.scheduler_comm - def stop(self): - if self.cluster: - self._adapt_callback.stop() - self._adapt_callback = None - del self._adapt_callback + @property + def plan(self): + try: + return set(self.cluster.worker_spec) + except AttributeError: + return set(self.cluster.workers) + + @property + def requested(self): + return set(self.cluster.workers) + + @property + def observed(self): + return {d["name"] for d in self.cluster.scheduler_info["workers"].values()} + + async def target(self): + return await self.scheduler.adaptive_target( + target_duration=self.target_duration + ) - async def workers_to_close(self, **kwargs): + async def workers_to_close(self, target: int): """ Determine which, if any, workers should potentially be removed from the cluster. 
@@ -140,114 +134,30 @@ async def workers_to_close(self, **kwargs): -------- Scheduler.workers_to_close """ - if len(self.cluster.workers) <= self.minimum: - return [] - - kw = dict(self._workers_to_close_kwargs) - kw.update(kwargs) - - if self.maximum is not None and len(self.cluster.workers) > self.maximum: - kw["n"] = len(self.cluster.workers) - self.maximum - - L = await self.scheduler.workers_to_close(**kw) - if len(self.cluster.workers) - len(L) < self.minimum: - L = L[: len(self.cluster.workers) - self.minimum] - - return L + return await self.scheduler.workers_to_close( + target=target, + key=pickle.dumps(self.worker_key) if self.worker_key else None, + attribute="name", + **self._workers_to_close_kwargs + ) - async def _retire_workers(self, workers=None): - if workers is None: - workers = await self.workers_to_close( - key=pickle.dumps(self.worker_key) if self.worker_key else None, - minimum=self.minimum, - ) + async def scale_down(self, workers): if not workers: - raise gen.Return(workers) + return with log_errors(): + # Ask scheduler to cleanly retire workers await self.scheduler.retire_workers( - workers=workers, remove=True, close_workers=True + names=workers, remove=True, close_workers=True ) + # close workers more forcefully logger.info("Retiring workers %s", workers) f = self.cluster.scale_down(workers) if hasattr(f, "__await__"): await f - return workers - - async def recommendations(self, comm=None): - n = await self.scheduler.adaptive_target(target_duration=self.target_duration) - if self.maximum is not None: - n = min(self.maximum, n) - if self.minimum is not None: - n = max(self.minimum, n) - workers = set( - await self.workers_to_close( - key=pickle.dumps(self.worker_key) if self.worker_key else None, - minimum=self.minimum, - ) - ) - try: - current = len(self.cluster.worker_spec) - except AttributeError: - current = len(self.cluster.workers) - if n > current and workers: - logger.info("Attempting to scale up and scale down simultaneously.") - self.close_counts.clear() - return { - "status": "error", - "msg": "Trying to scale up and down simultaneously", - } - - elif n > current: - self.close_counts.clear() - return {"status": "up", "n": n} - - elif workers: - d = {} - to_close = [] - for w, c in self.close_counts.items(): - if w in workers: - if c >= self.wait_count: - to_close.append(w) - else: - d[w] = c - - for w in workers: - d[w] = d.get(w, 0) + 1 - - self.close_counts = d - - if to_close: - return {"status": "down", "workers": to_close} - else: - self.close_counts.clear() - return None - - async def _adapt(self): - if self._adapting: # Semaphore to avoid overlapping adapt calls - return - - self._adapting = True - try: - recommendations = await self.recommendations() - if not recommendations: - return - status = recommendations.pop("status") - if status == "up": - f = self.cluster.scale(**recommendations) - self.log.append((time(), "up", recommendations)) - if hasattr(f, "__await__"): - await f - - elif status == "down": - self.log.append((time(), "down", recommendations["workers"])) - workers = await self._retire_workers(workers=recommendations["workers"]) - finally: - self._adapting = False - - def adapt(self): - self.loop.add_callback(self._adapt) + async def scale_up(self, n): + self.cluster.scale(n) @property def loop(self): diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py new file mode 100644 index 00000000000..6732bb20284 --- /dev/null +++ b/distributed/deploy/adaptive_core.py @@ -0,0 +1,197 @@ +import collections 
+import math + +from tornado.ioloop import IOLoop +import toolz + +from ..metrics import time +from ..utils import parse_timedelta, PeriodicCallback + + +class AdaptiveCore: + """ + The core logic for adaptive deployments, with none of the cluster details + + This class controls our adaptive scaling behavior. It is intended to be + sued as a super-class or mixin. It expects the following state and methods: + + **State** + + plan: set + A set of workers that we think should exist. + Here and below worker is just a token, often an address or name string + + requested: set + A set of workers that the cluster class has successfully requested from + the resource manager. We expect that resource manager to work to make + these exist. + + observed: set + A set of workers that have successfully checked in with the scheduler + + These sets are not necessarily equivalent. Often plan and requested will + be very similar (requesting is usually fast) but there may be a large delay + between requested and observed (often resource managers don't give us what + we want). + + **Functions** + + target : -> int + Returns the target number of workers that should exist. + This is often obtained by querying the scheduler + + workers_to_close : int -> Set[worker] + Given a target number of workers, + returns a set of workers that we should close when we're scaling down + + scale_up : int -> None + Scales the cluster up to a target number of workers, presumably + changing at least ``plan`` and hopefully eventually also ``requested`` + + scale_down : Set[worker] -> None + Closes the provided set of workers + + Parameters + ---------- + minimum: int + The minimum number of allowed workers + maximum: int + The maximum number of allowed workers + wait_count: int + The number of scale-down requests we should receive before actually + scaling down + interval: str + The amount of time, like ``"1s"`` between checks + """ + + def __init__( + self, + minimum: int = 0, + maximum: int = math.inf, + wait_count: int = 3, + interval: str = "1s", + ): + self.minimum = minimum + self.maximum = maximum + self.wait_count = wait_count + self.interval = parse_timedelta(interval, "seconds") if interval else interval + self.periodic_callback = None + + def f(): + self.periodic_callback = PeriodicCallback(self.adapt, self.interval * 1000) + self.periodic_callback.start() + + if self.interval: + try: + self.loop.add_callback(f) + except AttributeError: + IOLoop.current().add_callback(f) + + try: + self.plan = set() + self.requested = set() + self.observed = set() + except Exception: + pass + + # internal state + self.close_counts = collections.defaultdict(int) + self._adapting = False + self.log = collections.deque(maxlen=10000) + + def stop(self): + if self.periodic_callback: + self.periodic_callback.stop() + self.periodic_callback = None + + async def target(self) -> int: + """ The target number of workers that should exist """ + raise NotImplementedError() + + async def workers_to_close(self, target: int) -> list: + """ + Give a list of workers to close that brings us down to target workers + """ + # TODO, improve me with something that thinks about current load + return list(self.observed)[target:] + + async def safe_target(self) -> int: + """ Used internally, like target, but respects minimum/maximum """ + n = await self.target() + if n > self.maximum: + n = self.maximum + + if n < self.minimum: + n = self.minimum + + return n + + async def recommendations(self, target: int) -> dict: + """ + Make scale up/down recommendations based 
on current state and target + """ + plan = self.plan + requested = self.requested + observed = self.observed + + if target == len(plan): + self.close_counts.clear() + return {"status": "same"} + + elif target > len(plan): + self.close_counts.clear() + return {"status": "up", "n": target} + + elif target < len(plan): + not_yet_arrived = requested - observed + to_close = set() + if not_yet_arrived: + to_close.update((toolz.take(len(plan) - target, not_yet_arrived))) + + if target < len(plan) - len(to_close): + L = await self.workers_to_close(target=target) + to_close.update(L) + + firmly_close = set() + for w in to_close: + self.close_counts[w] += 1 + if self.close_counts[w] >= self.wait_count: + firmly_close.add(w) + + for k in list(self.close_counts): # clear out unseen keys + if k in firmly_close or k not in to_close: + del self.close_counts[k] + + if firmly_close: + return {"status": "down", "workers": list(firmly_close)} + else: + return {"status": "same"} + + async def adapt(self) -> None: + """ + Check the current state, make recommendations, call scale + + This is the main event of the system + """ + if self._adapting: # Semaphore to avoid overlapping adapt calls + return + self._adapting = True + + try: + target = await self.safe_target() + recommendations = await self.recommendations(target) + + if recommendations["status"] != "same": + self.log.append((time(), dict(recommendations))) + + status = recommendations.pop("status") + if status == "same": + return + if status == "up": + await self.scale_up(**recommendations) + if status == "down": + await self.scale_down(**recommendations) + except OSError: + self.stop() + finally: + self._adapting = False diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index c9e1aca1a87..feb0dfe63b5 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -192,7 +192,6 @@ def __init__( self._loop_runner.start() self.sync(self._start) self.sync(self._correct_state) - self.sync(self._wait_for_workers) async def _start(self): while self.status == "starting": @@ -306,7 +305,6 @@ async def _(): await self._correct_state() if self.workers: await asyncio.wait(list(self.workers.values())) # maybe there are more - await self._wait_for_workers() return self return _().__await__() @@ -367,7 +365,6 @@ def __del__(self): def __enter__(self): self.sync(self._correct_state) - self.sync(self._wait_for_workers) assert self.status == "running" return self @@ -376,6 +373,13 @@ def __exit__(self, typ, value, traceback): self._loop_runner.stop() def scale(self, n): + if len(self.worker_spec) > n: + not_yet_launched = set(self.worker_spec) - { + v["name"] for v in self.scheduler_info["workers"].values() + } + while len(self.worker_spec) > n and not_yet_launched: + del self.worker_spec[not_yet_launched.pop()] + while len(self.worker_spec) > n: self.worker_spec.popitem() @@ -411,12 +415,9 @@ def _supports_scaling(self): return not not self.new_spec async def scale_down(self, workers): - workers = set(workers) - - for k, v in self.workers.items(): - if getattr(v, "worker_address", v.address) in workers: - del self.worker_spec[k] - + for w in workers: + if w in self.worker_spec: + del self.worker_spec[w] await self scale_up = scale # backwards compatibility @@ -473,6 +474,7 @@ def dashboard_link(self): def _widget_status(self): workers = len(self.scheduler_info["workers"]) + requested = len(self.worker_spec) cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values()) memory = sum(v["memory_limit"] for v in 
self.scheduler_info["workers"].values()) memory = format_bytes(memory) @@ -492,13 +494,13 @@ def _widget_status(self): } - +
          Workers %d
          Workers %s
          Cores %d
          Memory %s
          """ % ( - workers, + workers if workers == requested else "%d / %d" % (workers, requested), cores, memory, ) @@ -547,6 +549,7 @@ def _widget(self): def adapt_cb(b): self.adapt(minimum=minimum.value, maximum=maximum.value) + update() adapt.on_click(adapt_cb) @@ -556,6 +559,7 @@ def scale_cb(b): with ignoring(AttributeError): self._adaptive.stop() self.scale(n) + update() scale.on_click(scale_cb) else: diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 2d3d2235e21..261b4355251 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -1,7 +1,6 @@ from time import sleep import pytest -from toolz import frequencies, pluck from tornado import gen from tornado.ioloop import IOLoop @@ -68,51 +67,46 @@ def test_adaptive_local_cluster(loop): assert not c.nthreads() -@nodebug -@gen_test(timeout=30) -def test_adaptive_local_cluster_multi_workers(): - cluster = yield LocalCluster( +@pytest.mark.asyncio +async def test_adaptive_local_cluster_multi_workers(cleanup): + async with LocalCluster( 0, scheduler_port=0, silence_logs=False, processes=False, dashboard_address=None, asynchronous=True, - ) - try: - cluster.scheduler.allowed_failures = 1000 - alc = cluster.adapt(interval=100) - c = yield Client(cluster, asynchronous=True) - - futures = c.map(slowinc, range(100), delay=0.01) + ) as cluster: - start = time() - while not cluster.scheduler.workers: - yield gen.sleep(0.01) - assert time() < start + 15, alc.log + cluster.scheduler.allowed_failures = 1000 + adapt = cluster.adapt(interval="100 ms") + async with Client(cluster, asynchronous=True) as c: + futures = c.map(slowinc, range(100), delay=0.01) - yield c.gather(futures) - del futures + start = time() + while not cluster.scheduler.workers: + await gen.sleep(0.01) + assert time() < start + 15, adapt.log - start = time() - # while cluster.workers: - while cluster.scheduler.workers: - yield gen.sleep(0.01) - assert time() < start + 15, alc.log + await c.gather(futures) + del futures - # no workers for a while - for i in range(10): - assert not cluster.scheduler.workers - yield gen.sleep(0.05) + start = time() + # while cluster.workers: + while cluster.scheduler.workers: + await gen.sleep(0.01) + assert time() < start + 15, adapt.log - futures = c.map(slowinc, range(100), delay=0.01) - yield c.gather(futures) + # no workers for a while + for i in range(10): + assert not cluster.scheduler.workers + await gen.sleep(0.05) - finally: - yield c.close() - yield cluster.close() + futures = c.map(slowinc, range(100), delay=0.01) + await c.gather(futures) +@pytest.mark.xfail(reason="changed API") @pytest.mark.asyncio async def test_adaptive_scale_down_override(cleanup): class TestAdaptive(Adaptive): @@ -164,7 +158,7 @@ def test_min_max(): yield gen.sleep(0.2) assert len(cluster.scheduler.workers) == 1 - assert frequencies(pluck(1, adapt.log)) == {"up": 1} + assert len(adapt.log) == 1 and adapt.log[-1][1] == {"status": "up", "n": 1} futures = c.map(slowinc, range(100), delay=0.1) @@ -177,7 +171,7 @@ def test_min_max(): yield gen.sleep(0.5) assert len(cluster.scheduler.workers) == 2 assert len(cluster.workers) == 2 - assert frequencies(pluck(1, adapt.log)) == {"up": 2} + assert len(adapt.log) == 2 and all(d["status"] == "up" for _, d in adapt.log) del futures @@ -185,41 +179,35 @@ def test_min_max(): while len(cluster.scheduler.workers) != 1: yield gen.sleep(0.01) assert time() < start + 2 - assert frequencies(pluck(1, adapt.log)) == {"up": 2, "down": 1} 
+ assert adapt.log[-1][1]["status"] == "down" finally: yield c.close() yield cluster.close() -@gen_test() -def test_avoid_churn(): +@pytest.mark.asyncio +async def test_avoid_churn(cleanup): """ We want to avoid creating and deleting workers frequently Instead we want to wait a few beats before removing a worker in case the user is taking a brief pause between work """ - cluster = yield LocalCluster( + async with LocalCluster( 0, asynchronous=True, processes=False, scheduler_port=0, silence_logs=False, dashboard_address=None, - ) - client = yield Client(cluster, asynchronous=True) - try: - adapt = cluster.adapt(interval="20 ms", wait_count=5) + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + adapt = cluster.adapt(interval="20 ms", wait_count=5) - for i in range(10): - yield client.submit(slowinc, i, delay=0.040) - yield gen.sleep(0.040) + for i in range(10): + await client.submit(slowinc, i, delay=0.040) + await gen.sleep(0.040) - from toolz.curried import pipe, unique, pluck, frequencies - - assert pipe(adapt.log, unique(key=str), pluck(1), frequencies) == {"up": 1} - finally: - yield client.close() - yield cluster.close() + assert len(adapt.log) == 1 @gen_test(timeout=None) @@ -238,7 +226,7 @@ def test_adapt_quickly(): dashboard_address=None, ) client = yield Client(cluster, asynchronous=True) - adapt = cluster.adapt(interval=20, wait_count=5, maximum=10) + adapt = cluster.adapt(interval="20 ms", wait_count=5, maximum=10) try: future = client.submit(slowinc, 1, delay=0.100) yield wait(future) @@ -246,10 +234,10 @@ def test_adapt_quickly(): # Scale up when there is plenty of available work futures = client.map(slowinc, range(1000), delay=0.100) - while frequencies(pluck(1, adapt.log)) == {"up": 1}: + while len(adapt.log) == 1: yield gen.sleep(0.01) assert len(adapt.log) == 2 - assert "up" in adapt.log[-1] + assert adapt.log[-1][1]["status"] == "up" d = [x for x in adapt.log[-1] if isinstance(x, dict)][0] assert 2 < d["n"] <= adapt.maximum @@ -362,7 +350,7 @@ def test_target_duration(): dashboard_address=None, ) client = yield Client(cluster, asynchronous=True) - adaptive = cluster.adapt(interval="20ms", minimum=2, target_duration="5s") + adapt = cluster.adapt(interval="20ms", minimum=2, target_duration="5s") cluster.scheduler.task_duration["slowinc"] = 1 @@ -372,21 +360,21 @@ def test_target_duration(): futures = client.map(slowinc, range(100), delay=0.3) - while len(adaptive.log) < 2: + while len(adapt.log) < 2: yield gen.sleep(0.01) - assert adaptive.log[0][1:] == ("up", {"n": 2}) - assert adaptive.log[1][1:] == ("up", {"n": 20}) + assert adapt.log[0][1] == {"status": "up", "n": 2} + assert adapt.log[1][1] == {"status": "up", "n": 20} finally: yield client.close() yield cluster.close() -@gen_test(timeout=None) -def test_worker_keys(): +@pytest.mark.asyncio +async def test_worker_keys(cleanup): """ Ensure that redefining adapt with a lower maximum removes workers """ - cluster = yield SpecCluster( + async with SpecCluster( workers={ "a-1": {"cls": Worker}, "a-2": {"cls": Worker}, @@ -394,9 +382,7 @@ def test_worker_keys(): "b-2": {"cls": Worker}, }, asynchronous=True, - ) - - try: + ) as cluster: def key(ws): return ws.name.split("-")[0] @@ -404,12 +390,10 @@ def key(ws): cluster._adaptive_options = {"worker_key": key} adaptive = cluster.adapt(minimum=1) - yield adaptive._adapt() + await adaptive.adapt() while len(cluster.scheduler.workers) == 4: - yield gen.sleep(0.01) + await gen.sleep(0.01) names = {ws.name for ws in cluster.scheduler.workers.values()} 
assert names == {"a-1", "a-2"} or names == {"b-1", "b-2"} - finally: - yield cluster.close() diff --git a/distributed/deploy/tests/test_adaptive_core.py b/distributed/deploy/tests/test_adaptive_core.py new file mode 100644 index 00000000000..a073314223d --- /dev/null +++ b/distributed/deploy/tests/test_adaptive_core.py @@ -0,0 +1,90 @@ +import asyncio +import pytest + +from distributed.deploy.adaptive_core import AdaptiveCore +from distributed.metrics import time + + +class MyAdaptive(AdaptiveCore): + def __init__(self, *args, interval=None, **kwargs): + super().__init__(*args, interval=interval, **kwargs) + self._target = 0 + self._log = [] + + async def target(self): + return self._target + + async def scale_up(self, n=0): + self.plan = self.requested = set(range(n)) + + async def scale_down(self, workers=()): + for collection in [self.plan, self.requested, self.observed]: + for w in workers: + collection.discard(w) + + +@pytest.mark.asyncio +async def test_safe_target(): + adapt = MyAdaptive(minimum=1, maximum=4) + assert await adapt.safe_target() == 1 + adapt._target = 10 + assert await adapt.safe_target() == 4 + + +@pytest.mark.asyncio +async def test_scale_up(): + adapt = MyAdaptive(minimum=1, maximum=4) + await adapt.adapt() + assert adapt.log[-1][1] == {"status": "up", "n": 1} + assert adapt.plan == {0} + + adapt._target = 10 + await adapt.adapt() + assert adapt.log[-1][1] == {"status": "up", "n": 4} + assert adapt.plan == {0, 1, 2, 3} + + +@pytest.mark.asyncio +async def test_scale_down(): + adapt = MyAdaptive(minimum=1, maximum=4, wait_count=2) + adapt._target = 10 + await adapt.adapt() + assert len(adapt.log) == 1 + + adapt.observed = {0, 1, 3} # all but 2 have arrived + + adapt._target = 2 + await adapt.adapt() + assert len(adapt.log) == 1 # no change after only one call + await adapt.adapt() + assert len(adapt.log) == 2 # no change after only one call + assert adapt.log[-1][1]["status"] == "down" + assert 2 in adapt.log[-1][1]["workers"] + assert len(adapt.log[-1][1]["workers"]) == 2 + + old = list(adapt.log) + await adapt.adapt() + await adapt.adapt() + await adapt.adapt() + await adapt.adapt() + assert list(adapt.log) == old + + +@pytest.mark.asyncio +async def test_interval(): + adapt = MyAdaptive(interval="5 ms") + assert not adapt.plan + + for i in [0, 3, 1]: + start = time() + adapt._target = i + while len(adapt.plan) != i: + await asyncio.sleep(0.001) + assert time() < start + 2 + + adapt.stop() + await asyncio.sleep(0.050) + + adapt._target = 10 + await asyncio.sleep(0.020) + assert len(adapt.plan) == 1 # last value from before, unchanged diff --git a/distributed/deploy/tests/test_slow_adaptive.py b/distributed/deploy/tests/test_slow_adaptive.py new file mode 100644 index 00000000000..4f565a78289 --- /dev/null +++ b/distributed/deploy/tests/test_slow_adaptive.py @@ -0,0 +1,98 @@ +import asyncio +import pytest + +from dask.distributed import Worker, Scheduler, SpecCluster, Client +from distributed.utils_test import slowinc, cleanup # noqa: F401 +from distributed.metrics import time + + +class SlowWorker(object): + def __init__(self, *args, delay=0, **kwargs): + self.worker = Worker(*args, **kwargs) + self.delay = delay + self.status = None + + @property + def address(self): + return self.worker.address + + def __await__(self): + async def now(): + if self.status != "running": + self.worker.loop.call_later(self.delay, self.worker.start) + self.status = "running" + return self + + return now().__await__() + + async def close(self): + await self.worker.close() + 
self.status = "closed" + + +scheduler = {"cls": Scheduler, "options": {"port": 0}} + + +@pytest.mark.asyncio +async def test_startup(cleanup): + start = time() + async with SpecCluster( + scheduler=scheduler, + workers={ + 0: {"cls": Worker, "options": {}}, + 1: {"cls": SlowWorker, "options": {"delay": 5}}, + 2: {"cls": SlowWorker, "options": {"delay": 0}}, + }, + asynchronous=True, + ) as cluster: + assert len(cluster.workers) == len(cluster.worker_spec) == 3 + assert time() < start + 5 + assert 1 <= len(cluster.scheduler_info["workers"]) <= 2 + + async with Client(cluster, asynchronous=True) as client: + await client.wait_for_workers(n_workers=2) + + +@pytest.mark.asyncio +async def test_scale_up_down(cleanup): + start = time() + async with SpecCluster( + scheduler=scheduler, + workers={ + "slow": {"cls": SlowWorker, "options": {"delay": 5}}, + "fast": {"cls": Worker, "options": {}}, + }, + asynchronous=True, + ) as cluster: + cluster.scale(1) # remove a worker, hopefully the one we don't have + await cluster + + assert list(cluster.worker_spec) == ["fast"] + + cluster.scale(0) + await cluster + assert not cluster.worker_spec + + +@pytest.mark.asyncio +async def test_adaptive(cleanup): + start = time() + async with SpecCluster( + scheduler=scheduler, + workers={"fast": {"cls": Worker, "options": {}}}, + worker={"cls": SlowWorker, "options": {"delay": 5}}, + asynchronous=True, + ) as cluster: + cluster.adapt(minimum=1, maximum=4, target_duration="1s", interval="20ms") + async with Client(cluster, asynchronous=True) as client: + futures = client.map(slowinc, range(200), delay=0.1) + + while len(cluster.worker_spec) <= 1: + await asyncio.sleep(0.05) + + del futures + + while len(cluster.worker_spec) > 1: + await asyncio.sleep(0.05) + + assert list(cluster.worker_spec) == ["fast"] diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 58bbbaef44d..64633428a38 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -252,7 +252,10 @@ async def test_dashboard_link(cleanup): @pytest.mark.asyncio async def test_widget(cleanup): async with SpecCluster( - workers=worker_spec, scheduler=scheduler, asynchronous=True + workers=worker_spec, + scheduler=scheduler, + asynchronous=True, + worker={"cls": Worker, "options": {"nthreads": 1}}, ) as cluster: start = time() # wait for all workers @@ -262,3 +265,6 @@ async def test_widget(cleanup): assert "3" in cluster._widget_status() assert "GB" in cluster._widget_status() + + cluster.scale(5) + assert "3 / 5" in cluster._widget_status() diff --git a/distributed/nanny.py b/distributed/nanny.py index 155fde98158..b6d8dadbf9a 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -690,7 +690,7 @@ async def run(): } ) init_result_q.close() - await worker.wait_until_closed() + await worker.finished() logger.info("Worker closed") try: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index acabce22c63..e9fdbe60b51 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2898,7 +2898,14 @@ def replicate( ) def workers_to_close( - self, comm=None, memory_ratio=None, n=None, key=None, minimum=None + self, + comm=None, + memory_ratio=None, + n=None, + key=None, + minimum=None, + target=None, + attribute="address", ): """ Find workers that we can close with low cost @@ -2925,6 +2932,11 @@ def workers_to_close( An optional callable mapping a WorkerState object to a group affiliation. Groups will be closed together. 
This is useful when closing workers must be done collectively, such as by hostname. + target: int + Target number of workers to have after we close + attribute : str + The attribute of the WorkerState object to return, like "address" + or "name". Defaults to "address". Examples -------- @@ -2952,6 +2964,13 @@ def workers_to_close( -------- Scheduler.retire_workers """ + if target is not None and n is None: + n = len(self.workers) - target + if n is not None: + if n < 0: + n = 0 + target = len(self.workers) - n + if n is None and memory_ratio is None: memory_ratio = 2 @@ -2976,12 +2995,12 @@ def workers_to_close( limit = sum(limit_bytes.values()) total = sum(group_bytes.values()) - def key(group): + def _key(group): is_idle = not any(ws.processing for ws in groups[group]) bytes = -group_bytes[group] return (is_idle, bytes) - idle = sorted(groups, key=key) + idle = sorted(groups, key=_key) to_close = [] n_remain = len(self.workers) @@ -2996,7 +3015,7 @@ def key(group): limit -= limit_bytes[group] - if (n is not None and len(to_close) < n) or ( + if (n is not None and n_remain - len(groups[group]) >= target) or ( memory_ratio is not None and limit >= memory_ratio * total ): to_close.append(group) @@ -3005,22 +3024,30 @@ def key(group): else: break - result = [ws.address for g in to_close for ws in groups[g]] + result = [getattr(ws, attribute) for g in to_close for ws in groups[g]] if result: logger.info("Suggest closing workers: %s", result) return result async def retire_workers( - self, comm=None, workers=None, remove=True, close_workers=False, **kwargs + self, + comm=None, + workers=None, + remove=True, + close_workers=False, + names=None, + **kwargs ): """ Gracefully retire workers from cluster Parameters ---------- workers: list (optional) - List of worker IDs to retire. + List of worker addresses to retire. If not provided we call ``workers_to_close`` which finds a good set + workers_names: list (optional) + List of worker names to retire. 
remove: bool (defaults to True) Whether or not to remove the worker metadata immediately or else wait for the worker to contact us @@ -3042,6 +3069,11 @@ async def retire_workers( Scheduler.workers_to_close """ with log_errors(): + if names is not None: + names = set(names) + workers = [ + ws.address for ws in self.workers.values() if ws.name in names + ] if workers is None: while True: try: @@ -3052,17 +3084,16 @@ async def retire_workers( remove=remove, close_workers=close_workers, ) - raise gen.Return(workers) + return workers except KeyError: # keys left during replicate pass - workers = {self.workers[w] for w in workers if w in self.workers} - if len(workers) > 0: - # Keys orphaned by retiring those workers - keys = set.union(*[w.has_what for w in workers]) - keys = {ts.key for ts in keys if ts.who_has.issubset(workers)} - else: - keys = set() + if not workers: + return [] + + # Keys orphaned by retiring those workers + keys = set.union(*[w.has_what for w in workers]) + keys = {ts.key for ts in keys if ts.who_has.issubset(workers)} other_workers = set(self.workers.values()) - workers if keys: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 80cc04c81b3..0331ac0a972 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1231,7 +1231,9 @@ def test_cancel_fire_and_forget(c, s, a, b): assert not s.tasks -@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False}) +@gen_cluster( + client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False} +) def test_log_tasks_during_restart(c, s, a, b): future = c.submit(sys.exit, 0) yield wait(future) diff --git a/distributed/worker.py b/distributed/worker.py index 672358d4460..290128c84e1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -26,7 +26,6 @@ from toolz import pluck, partial, merge, first from tornado import gen from tornado.ioloop import IOLoop -from tornado.locks import Event from . 
import profile, comm from .batched import BatchedSend @@ -546,7 +545,6 @@ def __init__( self.actors = {} self.loop = loop or IOLoop.current() self.status = None - self._closed = Event() self.reconnect = reconnect self.executor = executor or ThreadPoolExecutor( self.nthreads, thread_name_prefix="Dask-Worker-Threads'" @@ -1054,7 +1052,6 @@ async def close( self.stop() self.rpc.close() - self._closed.set() self.status = "closed" await ServerNode.close(self) @@ -1084,7 +1081,8 @@ async def terminate(self, comm, report=True, **kwargs): return "OK" async def wait_until_closed(self): - await self._closed.wait() + warnings.warn("wait_until_closed has moved to finished()") + await self.finished() assert self.status == "closed" ################ From 4dc3d196baafe5c3c704d894c9cbb7d80f61f6cc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 2 Aug 2019 09:51:45 -0700 Subject: [PATCH 0394/1550] Add keep-alive message between worker and scheduler (#2907) This is effectively a heartbeat, but much simpler and less frequent than our current heartbeats Fixes #2524 --- distributed/scheduler.py | 1 + distributed/tests/test_client.py | 2 +- distributed/tests/test_core.py | 13 ++++++++++--- distributed/worker.py | 7 +++++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e9fdbe60b51..8a2fe03c8d1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1025,6 +1025,7 @@ def __init__( "missing-data": self.handle_missing_data, "long-running": self.handle_long_running, "reschedule": self.reschedule, + "keep-alive": lambda *args, **kwargs: None, } client_handlers = { diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 3eb9f39a2a4..d89562dc1ff 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3319,7 +3319,7 @@ def test_get_foo_lost_keys(c, s, u, v, w): client=True, Worker=Nanny, worker_kwargs={"death_timeout": "500ms"}, - clean_kwargs={"threads": False}, + clean_kwargs={"threads": False, "processes": False}, ) def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 0) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index f91b8b64367..cad622980df 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -445,9 +445,16 @@ def test_identity_inproc(): def test_ports(loop): - port = 9877 - server = Server({}, io_loop=loop) - server.listen(port) + for port in range(9877, 9887): + server = Server({}, io_loop=loop) + try: + server.listen(port) + except OSError: # port already taken? 
+ pass + else: + break + else: + raise Exception() try: assert server.port == port diff --git a/distributed/worker.py b/distributed/worker.py index 290128c84e1..dfdd6df8e4a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -820,6 +820,13 @@ async def _register_with_scheduler(self): self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) self.batched_stream.start(comm) + pc = PeriodicCallback( + lambda: self.batched_stream.send({"op": "keep-alive"}), + 60000, + io_loop=self.io_loop, + ) + self.periodic_callbacks["keep-alive"] = pc + pc.start() self.periodic_callbacks["heartbeat"].start() self.loop.add_callback(self.handle_scheduler, comm) From 6caa30896e66501483416812d44c861da75ceab6 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 2 Aug 2019 15:11:17 -0400 Subject: [PATCH 0395/1550] Fix docstring [skip ci] (#2917) Fixes #2914 --- distributed/deploy/local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 877a74587e9..20476ad8065 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -80,7 +80,7 @@ class LocalCluster(SpecCluster): Pass extra keyword arguments to Bokeh - >>> LocalCluster(service_kwargs={'bokeh': {'prefix': '/foo'}}) # doctest: +SKIP + >>> LocalCluster(service_kwargs={'dashboard': {'prefix': '/foo'}}) # doctest: +SKIP """ def __init__( From 20ba1a7405a6a5eb59e14808abc5f6ff823ab48d Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Sat, 3 Aug 2019 12:25:05 -0400 Subject: [PATCH 0396/1550] Raise informative warning when rescheduling an unknown task (#2916) --- distributed/scheduler.py | 9 ++++++++- distributed/tests/test_scheduler.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8a2fe03c8d1..f84b3d1bce4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4364,7 +4364,14 @@ def reschedule(self, key=None, worker=None): Things may have shifted and this task may now be better suited to run elsewhere """ - ts = self.tasks[key] + try: + ts = self.tasks[key] + except KeyError: + logger.warning( + "Attempting to reschedule task {}, which was not " + "found on the scheduler. 
Aborting reschedule.".format(key) + ) + return if ts.state != "processing": return if worker and ts.processing_on.address != worker: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 0331ac0a972..9035fbd8667 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -6,6 +6,7 @@ import operator import sys from time import sleep +import logging import dask from dask import delayed @@ -23,6 +24,7 @@ from distributed.worker import dumps_function, dumps_task from distributed.utils import tmpfile from distributed.utils_test import ( # noqa: F401 + captured_logger, cleanup, inc, dec, @@ -1260,6 +1262,15 @@ def test_reschedule(c, s, a, b): assert sum(future.key in a.data for future in x) <= 1 +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) +def test_reschedule_warns(c, s, a, b): + with captured_logger(logging.getLogger("distributed.scheduler")) as sched: + s.reschedule(key="__this-key-does-not-exist__") + + assert "not found on the scheduler" in sched.getvalue() + assert "Aborting reschedule" in sched.getvalue() + + @gen_cluster(client=True) def test_get_task_status(c, s, a, b): future = c.submit(inc, 1) From fb4e48fa5cb83bd0fe32d4c1fe23645d30798c21 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Sat, 3 Aug 2019 17:33:00 -0400 Subject: [PATCH 0397/1550] Give 404 when requesting nonexistent tasks or workers (#2921) --- distributed/dashboard/scheduler_html.py | 6 ++++++ .../tests/test_scheduler_bokeh_html.py | 21 ++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 3087f323b5f..1377b037173 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -27,6 +27,9 @@ def get(self): class Worker(RequestHandler): def get(self, worker): worker = escape.url_unescape(worker) + if worker not in self.server.workers: + self.send_error(404) + return with log_errors(): self.render( "worker.html", @@ -40,6 +43,9 @@ def get(self, worker): class Task(RequestHandler): def get(self, task): task = escape.url_unescape(task) + if task not in self.server.tasks: + self.send_error(404) + return with log_errors(): self.render( "task.html", diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py index 660602df09a..f2a2c880a94 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -6,7 +6,7 @@ pytest.importorskip("bokeh") from tornado.escape import url_escape -from tornado.httpclient import AsyncHTTPClient +from tornado.httpclient import AsyncHTTPClient, HTTPClientError from dask.sizeof import sizeof from distributed.utils import is_valid_xml @@ -49,6 +49,25 @@ def test_connect(c, s, a, b): assert not re.search("href=./", body) # no absolute links +@gen_cluster( + client=True, + nthreads=[], + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, +) +def test_worker_404(c, s): + http_client = AsyncHTTPClient() + with pytest.raises(HTTPClientError) as err: + yield http_client.fetch( + "http://localhost:%d/info/worker/unknown" % s.services["dashboard"].port + ) + assert err.value.code == 404 + with pytest.raises(HTTPClientError) as err: + yield http_client.fetch( + "http://localhost:%d/info/task/unknown" % s.services["dashboard"].port + ) + assert err.value.code == 404 + + @gen_cluster( client=True, 
scheduler_kwargs={ From b68660cb873e8efb9d9b47a92e39ab78f2fd7573 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 3 Aug 2019 19:20:03 -0700 Subject: [PATCH 0398/1550] Cleanup async warnings in tests (#2920) --- distributed/client.py | 5 ++- distributed/tests/test_as_completed.py | 2 +- distributed/tests/test_client.py | 51 +++++++++++++++++++------- distributed/tests/test_core.py | 2 +- distributed/tests/test_variable.py | 17 +++++---- distributed/utils_test.py | 2 +- distributed/worker.py | 4 +- 7 files changed, 55 insertions(+), 28 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index fab5ff0bf6e..93501ae2077 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2860,7 +2860,7 @@ async def _upload_large_file(self, local_filename, remote_filename=None): def dump_to_file(dask_worker=None): if not os.path.isabs(remote_filename): - fn = os.path.join(dask_worker.local_dir, remote_filename) + fn = os.path.join(dask_worker.local_directory, remote_filename) else: fn = remote_filename with open(fn, "wb") as f: @@ -3267,7 +3267,8 @@ def scheduler_info(self, **kwargs): 'stored': 0, 'time-delay': 0.0061032772064208984}}} """ - self.sync(self._update_scheduler_info) + if not self.asynchronous: + self.sync(self._update_scheduler_info) return self._scheduler_identity def write_scheduler_file(self, scheduler_file): diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index 911ff388e06..d74d033c64a 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -204,7 +204,7 @@ def test_as_completed_with_results_async(c, s, a, b): z = c.submit(inc, 1) ac = as_completed([x, y, z], with_results=True) - y.cancel() + yield y.cancel() with pytest.raises(RuntimeError) as exc: first = yield ac.__anext__() second = yield ac.__anext__() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index d89562dc1ff..7f0036e2a7a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6,6 +6,7 @@ from operator import add import os import pickle +import psutil import random import subprocess import sys @@ -1529,9 +1530,9 @@ def g(): return package_1.a, package_2.b # c.upload_file tells each worker to - # - put this file in their local_dir + # - put this file in their local_directory # - modify their sys.path to include it - # we don't care about the local_dir + # we don't care about the local_directory # but we do care about restoring the path with save_sys_modules(): @@ -1581,19 +1582,19 @@ def g(): @gen_cluster(client=True) def test_upload_large_file(c, s, a, b): - assert a.local_dir - assert b.local_dir + assert a.local_directory + assert b.local_directory with tmp_text("myfile", "abc") as fn: with tmp_text("myfile2", "def") as fn2: yield c._upload_large_file(fn, remote_filename="x") yield c._upload_large_file(fn2) for w in [a, b]: - assert os.path.exists(os.path.join(w.local_dir, "x")) - assert os.path.exists(os.path.join(w.local_dir, "myfile2")) - with open(os.path.join(w.local_dir, "x")) as f: + assert os.path.exists(os.path.join(w.local_directory, "x")) + assert os.path.exists(os.path.join(w.local_directory, "myfile2")) + with open(os.path.join(w.local_directory, "x")) as f: assert f.read() == "abc" - with open(os.path.join(w.local_dir, "myfile2")) as f: + with open(os.path.join(w.local_directory, "myfile2")) as f: assert f.read() == "def" @@ -4568,7 +4569,7 @@ def test_quiet_client_close(loop): @pytest.mark.slow def 
test_quiet_client_close_when_cluster_is_closed_before_client(loop): with captured_logger(logging.getLogger("tornado.application")) as logger: - cluster = LocalCluster(loop=loop, n_workers=1) + cluster = LocalCluster(loop=loop, n_workers=1, dashboard_address=":0") client = Client(cluster, loop=loop) cluster.close() client.close() @@ -5179,7 +5180,7 @@ def test_scatter_direct(s, a, b): yield gen.sleep(0.10) assert time() < start + 5 - yield c._close() + yield c.close() @pytest.mark.skipif(sys.version_info[0] < 3, reason="cloudpickle Py27 issue") @@ -5196,7 +5197,7 @@ def test_client_name(s, a, b): c = yield Client(s.address, asynchronous=True) assert any("hello-world" in name for name in list(s.clients)) - yield c._close() + yield c.close() def test_client_doesnt_close_given_loop(loop, s, a, b): @@ -5301,7 +5302,7 @@ def test(s, a, b): with pytest.raises(TypeError): yield c.run_on_scheduler(lambda: inc) finally: - yield c._close() + yield c.close() test() @@ -5324,7 +5325,7 @@ def test_de_serialization(s, a, b): with pytest.raises(TypeError): result = yield future finally: - yield c._close() + yield c.close() @gen_cluster() @@ -5340,7 +5341,7 @@ def test_de_serialization_none(s, a, b): with pytest.raises(TypeError): result = yield future finally: - yield c._close() + yield c.close() @gen_cluster() @@ -5591,5 +5592,27 @@ def test_wait_for_workers(c, s, a, b): yield w.close() +@pytest.mark.skipif(WINDOWS, reason="num_fds not supported on windows") +@pytest.mark.asyncio +@pytest.mark.parametrize("Worker", [Worker, Nanny]) +async def test_file_descriptors_dont_leak(Worker): + pytest.importorskip("pandas") + df = dask.datasets.timeseries(freq="10s", dtypes={"x": int, "y": float}) + + proc = psutil.Process() + start = proc.num_fds() + async with Scheduler(port=0, dashboard_address=":0") as s: + async with Worker(s.address, nthreads=2) as a, Worker( + s.address, nthreads=2 + ) as b: + async with Client(s.address, asynchronous=True) as c: + await df.sum().persist() + + begin = time() + while proc.num_fds() > begin: + await asyncio.sleep(0.01) + assert time() < begin + 5, (start, proc.num_fds()) + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index cad622980df..e41866d6741 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -126,7 +126,7 @@ def f(): assert isinstance(msg["exception"], ValueError) assert "'ping' handler has been explicitly disallowed" in repr(msg["exception"]) - comm.close() + yield comm.close() server.stop() res = loop.run_sync(f) diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 88f96a241b0..6dcca9c9cf4 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -1,3 +1,4 @@ +import asyncio import random from time import sleep import sys @@ -99,7 +100,7 @@ def test_timeout_sync(client): @gen_cluster(client=True) -def test_cleanup(c, s, a, b): +async def test_cleanup(c, s, a, b): v = Variable("v") vv = Variable("v") @@ -107,17 +108,17 @@ def test_cleanup(c, s, a, b): y = c.submit(lambda x: x + 1, 20) x_key = x.key - yield v.set(x) + await v.set(x) del x - yield gen.sleep(0.1) + await gen.sleep(0.1) - t_future = xx = vv._get() - yield gen.moment - v._set(y) + t_future = xx = asyncio.ensure_future(vv._get()) + await gen.sleep(0) + asyncio.ensure_future(v.set(y)) - future = yield t_future + future = await t_future assert future.key == x_key - result = yield 
future + result = await future assert result == 11 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 8293bb474e2..13639f05fa5 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -636,7 +636,7 @@ def cluster( q = mp_context.Queue() fn = "_test_worker-%s" % uuid.uuid4() kwargs = merge( - {"nthreads": 1, "local_dir": fn, "memory_limit": TOTAL_MEMORY}, + {"nthreads": 1, "local_directory": fn, "memory_limit": TOTAL_MEMORY}, worker_kwargs, ) proc = mp_context.Process( diff --git a/distributed/worker.py b/distributed/worker.py index dfdd6df8e4a..795cb93e9f3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -711,7 +711,9 @@ def worker_address(self): @property def local_dir(self): """ For API compatibility with Nanny """ - warnings.warn("The local_dir attribute has moved to local_directory") + warnings.warn( + "The local_dir attribute has moved to local_directory", stacklevel=2 + ) return self.local_directory def get_metrics(self): From 2428cc822a51bb12832c2bcc2bea2fa001e40d30 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 4 Aug 2019 12:22:08 -0700 Subject: [PATCH 0399/1550] Add documentation around spec.ProcessInterface (#2923) --- distributed/deploy/spec.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index feb0dfe63b5..af4f4a3f23d 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -24,13 +24,12 @@ class ProcessInterface: - """ An interface for Scheduler and Worker processes for use in SpecCluster - - Parameters - ---------- - loop: - A pointer to the running loop. + """ + An interface for Scheduler and Worker processes for use in SpecCluster + This interface is responsible to submit a worker or scheduler process to a + resource manager like Kubernetes, Yarn, or SLURM/PBS/SGE/... + It should implement the methods below, like ``start`` and ``close`` """ def __init__(self): @@ -49,11 +48,25 @@ async def _(): return _().__await__() async def start(self): - """ Start the process. """ + """ Submit the process to the resource manager + + For workers this doesn't have to wait until the process actually starts, + but can return once the resource manager has the request, and will work + to make the job exist in the future + + For the scheduler we will expect the scheduler's ``.address`` attribute + to be avaialble after this completes. + """ self.status = "running" async def close(self): - """ Close the process. """ + """ Close the process + + This will be called by the Cluster object when we scale down a node, + but only after we ask the Scheduler to close the worker gracefully. + This method should kill the process a bit more forcefully and does not + need to worry about shutting down gracefully + """ self.status = "closed" def __repr__(self): From 29389702a88253f6504cca0b8304a86a94c2e677 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 4 Aug 2019 14:31:00 -0700 Subject: [PATCH 0400/1550] Add timeouts to processes in SSH tests (#2925) It turns out that OpenSSH doesn't pass through terminate/kill signals, so we had some zombie processes hanging around sending signals around where they shouldn't. Now we place idle and death timeouts on the launched processes to keep them in check. See https://github.com/ronf/asyncssh/issues/112 for more information on the underlying issue. 
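For reference, the timeouts act as a self-cleanup backstop at cluster construction time: the scheduler shuts itself down when idle and the workers give up if they cannot reach it, so a lost signal cannot leave a long-lived remote process behind. A minimal sketch of that usage, assuming the `SSHCluster` class defined in `distributed/deploy/ssh2.py` (timeout values are illustrative):

```
from distributed.deploy.ssh2 import SSHCluster

# Even if the kill signal never reaches the remote side, the scheduler
# exits after 5 seconds of inactivity and the workers exit if they cannot
# reach the scheduler within 5 seconds, so no zombie processes linger.
cluster = SSHCluster(
    ["127.0.0.1"] * 3,
    connect_kwargs=dict(known_hosts=None),
    scheduler_kwargs={"port": 0, "idle_timeout": "5s"},
    worker_kwargs={"death_timeout": "5s"},
)
```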
--- distributed/deploy/ssh2.py | 6 ++++-- distributed/deploy/tests/test_ssh2.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py index 189a61df5f1..064e580b111 100644 --- a/distributed/deploy/ssh2.py +++ b/distributed/deploy/ssh2.py @@ -34,11 +34,13 @@ def __init__(self, **kwargs): async def start(self): assert self.connection - weakref.finalize(self, self.proc.terminate) + weakref.finalize( + self, self.proc.kill + ) # https://github.com/ronf/asyncssh/issues/112 await super().start() async def close(self): - self.proc.terminate() + self.proc.kill() # https://github.com/ronf/asyncssh/issues/112 self.connection.close() await super().close() diff --git a/distributed/deploy/tests/test_ssh2.py b/distributed/deploy/tests/test_ssh2.py index b744d352b8b..076711bb841 100644 --- a/distributed/deploy/tests/test_ssh2.py +++ b/distributed/deploy/tests/test_ssh2.py @@ -12,7 +12,8 @@ async def test_basic(): ["127.0.0.1"] * 3, connect_kwargs=dict(known_hosts=None), asynchronous=True, - scheduler_kwargs={"port": 0}, + scheduler_kwargs={"port": 0, "idle_timeout": "5s"}, + worker_kwargs={"death_timeout": "5s"}, ) as cluster: assert len(cluster.workers) == 2 async with Client(cluster, asynchronous=True) as client: @@ -29,7 +30,7 @@ async def test_keywords(): ["127.0.0.1"] * 3, connect_kwargs=dict(known_hosts=None), asynchronous=True, - worker_kwargs={"nthreads": 2, "memory_limit": "2 GiB"}, + worker_kwargs={"nthreads": 2, "memory_limit": "2 GiB", "death_timeout": "5s"}, scheduler_kwargs={"idle_timeout": "5s", "port": 0}, ) as cluster: async with Client(cluster, asynchronous=True) as client: From e02cc4409e352e40e4128fc542b8aaed51b5a01f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 4 Aug 2019 14:31:32 -0700 Subject: [PATCH 0401/1550] Always kill processes in clean tests, even if we don't check (#2924) Also allow ValueErrors when collecting data from workers --- distributed/utils_comm.py | 8 ++++++++ distributed/utils_test.py | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index f6b4ea36e4f..53504d11939 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,6 +1,7 @@ import asyncio from collections import defaultdict from itertools import cycle +import logging import random from dask.optimization import SubgraphCallable @@ -9,6 +10,8 @@ from .core import rpc from .utils import All, tokey +logger = logging.getLogger(__name__) + async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): """ Gather data directly from peers @@ -72,6 +75,11 @@ async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=No r = await c except EnvironmentError: missing_workers.add(worker) + except ValueError as e: + logger.info( + "Got an unexpected error while collecting from workers: %s", e + ) + missing_workers.add(worker) else: response.update(r["data"]) finally: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 13639f05fa5..505d269cae9 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1447,21 +1447,24 @@ def check_thread_leak(): @contextmanager -def check_process_leak(): +def check_process_leak(check=True): for proc in mp_context.active_children(): proc.terminate() yield - for i in range(100): - if not set(mp_context.active_children()): - break + if check: + for i in range(100): + if not set(mp_context.active_children()): + break + 
else: + sleep(0.2) else: - sleep(0.2) - else: - assert not mp_context.active_children() + assert not mp_context.active_children() _cleanup_dangling() + for proc in mp_context.active_children(): + proc.terminate() @contextmanager @@ -1524,7 +1527,7 @@ def null(): with check_thread_leak() if threads else null(): with pristine_loop() as loop: - with check_process_leak() if processes else null(): + with check_process_leak(check=processes): with check_instances() if instances else null(): with check_active_rpc(loop, timeout): reset_config() From be88537c4c6040e171c0644f507dae2a3b1e1ead Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 5 Aug 2019 14:54:37 -0700 Subject: [PATCH 0402/1550] Add real-time CPU utilization plot to dashboard (#2922) This matches the styling of the nprocessing and memory use plots --- distributed/dashboard/scheduler.py | 50 ++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 8396bbcb6ae..3332f4fc27b 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -18,6 +18,7 @@ Range1d, Plot, Quad, + Span, value, LinearAxis, NumeralTickFormatter, @@ -309,6 +310,8 @@ def __init__(self, scheduler, width=600, **kwargs): "nbytes": [1, 2], "nbytes-half": [0.5, 1], "nbytes_text": ["1B", "2B"], + "cpu": [1, 2], + "cpu-half": [0.5, 1], "worker": ["a", "b"], "y": [1, 2], "nbytes-color": ["blue", "blue"], @@ -353,6 +356,32 @@ def __init__(self, scheduler, width=600, **kwargs): ) rect.nonselection_glyph = None + cpu = figure( + title="CPU Utilization", + tools="", + id="bk-cpu-worker-plot", + width=int(width / 2), + name="cpu_hist", + **kwargs + ) + rect = cpu.rect( + source=self.source, + x="cpu-half", + y="y", + width="cpu", + height=1, + color="blue", + ) + rect.nonselection_glyph = None + hundred_span = Span( + location=100, + dimension="height", + line_color="gray", + line_dash="dashed", + line_width=3, + ) + cpu.add_layout(hundred_span) + nbytes.axis[0].ticker = BasicTicker(mantissas=[1, 256, 512], base=1024) nbytes.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") nbytes.xaxis.major_label_orientation = -math.pi / 12 @@ -382,10 +411,17 @@ def __init__(self, scheduler, width=600, **kwargs): hover.point_policy = "follow_mouse" nbytes.add_tools(hover) + hover = HoverTool() + hover.tooltips = "@worker : @cpu %" + hover.point_policy = "follow_mouse" + cpu.add_tools(hover) + self.processing_figure = processing self.nbytes_figure = nbytes + self.cpu_figure = cpu processing.y_range = nbytes.y_range + cpu.y_range = nbytes.y_range @without_property_validation def update(self): @@ -393,6 +429,9 @@ def update(self): workers = list(self.scheduler.workers.values()) y = list(range(len(workers))) + + cpu = [int(ws.metrics["cpu"]) for ws in workers] + nprocessing = [len(ws.processing) for ws in workers] processing_color = [] for ws in workers: @@ -427,6 +466,8 @@ def update(self): if any(nprocessing) or self.last + 1 < now: self.last = now result = { + "cpu": cpu, + "cpu-half": [c / 2 for c in cpu], "nprocessing": nprocessing, "nprocessing-half": [np / 2 for np in nprocessing], "nprocessing-color": processing_color, @@ -1495,6 +1536,14 @@ def individual_nbytes_doc(scheduler, extra, doc): doc.theme = BOKEH_THEME +def individual_cpu_doc(scheduler, extra, doc): + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") + current_load.update() + add_periodic_callback(doc, current_load, 100) + doc.add_root(current_load.cpu_figure) + doc.theme = 
BOKEH_THEME + + def individual_nprocessing_doc(scheduler, extra, doc): current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") current_load.update() @@ -1619,6 +1668,7 @@ def __init__(self, scheduler, io_loop=None, prefix="", **kwargs): "/individual-profile": individual_profile_doc, "/individual-profile-server": individual_profile_server_doc, "/individual-nbytes": individual_nbytes_doc, + "/individual-cpu": individual_cpu_doc, "/individual-nprocessing": individual_nprocessing_doc, "/individual-workers": individual_workers_doc, } From 17889a976df6c9891de5ffcca4715b6b8adfb76b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 5 Aug 2019 17:17:32 -0700 Subject: [PATCH 0403/1550] Add aenter/aexit protocols to ProcessInterface (#2927) --- distributed/deploy/spec.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index af4f4a3f23d..c2de8b9e2f8 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -72,6 +72,13 @@ async def close(self): def __repr__(self): return "<%s: status=%s>" % (type(self).__name__, self.status) + async def __aenter__(self): + await self + return self + + async def __aexit__(self, *args, **kwargs): + await self.close() + class SpecCluster(Cluster): """ Cluster that requires a full specification of workers From f6c8818a39d8163de8152b8cd7bef8f034404d8d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 6 Aug 2019 12:26:09 -0700 Subject: [PATCH 0404/1550] Move core functionality from SpecCluster to Cluster (#2913) This moves standard functionality from SpecClsuter to the Cluster superclass. It also removes the assumption that the Scheduler will be local to the Cluster class. --- distributed/deploy/cluster.py | 316 ++++++++++++++++++++++------------ distributed/deploy/spec.py | 227 +----------------------- 2 files changed, 208 insertions(+), 335 deletions(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 58c6ce73644..e85ea2bc3dd 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -1,7 +1,7 @@ +import asyncio from datetime import timedelta import logging import threading -from weakref import ref from dask.utils import format_bytes from tornado import gen @@ -13,6 +13,8 @@ log_errors, ignoring, sync, + Log, + Logs, thread_state, format_dashboard_link, ) @@ -24,45 +26,88 @@ class Cluster(object): """ Superclass for cluster objects - This expects a local Scheduler defined on the object. It provides - common methods and an IPython widget display. + This class contains common functionality for Dask Cluster manager classes. - Clusters inheriting from this class should provide the following: + To implement this class, you must provide - 1. A local ``Scheduler`` object at ``.scheduler`` - 2. scale_up and scale_down methods as defined below:: + 1. A ``scheduler_comm`` attribute, which is a connection to the scheduler + following the ``distributed.core.rpc`` API. + 2. Implement ``scale``, which takes an integer and scales the cluster to + that many workers, or else set ``_supports_scaling`` to False - def scale_up(self, n: int): - ''' Brings total worker count up to ``n`` ''' + For that, should should get the following: - def scale_down(self, workers: List[str]): - ''' Close the workers with the given addresses ''' - - This will provide a general ``scale`` method as well as an IPython widget - for display. + 1. A standard ``__repr__`` + 2. A live IPython widget + 3. Adaptive scaling + 4. 
Integration with dask-labextension + 5. A ``scheduler_info`` attribute which contains an up-to-date copy of + ``Scheduler.identity()``, which is used for much of the above + 6. Methods to gather logs + """ - Examples - -------- + _supports_scaling = True - >>> from distributed.deploy import Cluster - >>> class MyCluster(cluster): - ... def scale_up(self, n): - ... ''' Bring the total worker count up to n ''' - ... pass - ... def scale_down(self, workers): - ... ''' Close the workers with the given addresses ''' - ... pass + def __init__(self, asynchronous): + self.scheduler_info = {} + self.periodic_callbacks = {} + self._asynchronous = asynchronous - >>> cluster = MyCluster() - >>> cluster.scale(5) # scale manually - >>> cluster.adapt(minimum=1, maximum=100) # scale automatically + self.status = "created" - See Also - -------- - LocalCluster: a simple implementation with local workers - """ - - def adapt(self, Adaptive=Adaptive, **kwargs): + async def _start(self): + comm = await self.scheduler_comm.live_comm() + await comm.write({"op": "subscribe_worker_status"}) + self.scheduler_info = await comm.read() + self._watch_worker_status_comm = comm + self._watch_worker_status_task = asyncio.ensure_future( + self._watch_worker_status(comm) + ) + self.status = "running" + + async def _close(self): + if self.status == "closed": + return + + await self._watch_worker_status_comm.close() + await self._watch_worker_status_task + + for pc in self.periodic_callbacks.values(): + pc.stop() + self.scheduler_comm.close_rpc() + + self.status = "closed" + + def close(self, timeout=None): + with ignoring(RuntimeError): # loop closed during process shutdown + return self.sync(self._close, callback_timeout=timeout) + + def __del__(self): + if self.status != "closed": + with ignoring(AttributeError, RuntimeError): # during closing + self.loop.add_callback(self.close) + + async def _watch_worker_status(self, comm): + """ Listen to scheduler for updates on adding and removing workers """ + while True: + try: + msgs = await comm.read() + except OSError: + break + + for op, msg in msgs: + if op == "add": + workers = msg.pop("workers") + self.scheduler_info["workers"].update(workers) + self.scheduler_info.update(msg) + elif op == "remove": + del self.scheduler_info["workers"][msg] + else: + raise ValueError("Invalid op", op, msg) + + await comm.close() + + def adapt(self, Adaptive=Adaptive, **kwargs) -> Adaptive: """ Turn on adaptivity For keyword arguments see dask.distributed.Adaptive @@ -79,17 +124,7 @@ def adapt(self, Adaptive=Adaptive, **kwargs): self._adaptive = Adaptive(self, **self._adaptive_options) return self._adaptive - @property - def scheduler_address(self): - return self.scheduler.address - - @property - def dashboard_link(self): - host = self.scheduler.address.split("://")[1].split(":")[0] - port = self.scheduler.services["dashboard"].port - return format_dashboard_link(host, port) - - def scale(self, n): + def scale(self, n: int) -> None: """ Scale cluster to n workers Parameters @@ -100,29 +135,81 @@ def scale(self, n): Example ------- >>> cluster.scale(10) # scale cluster to ten workers + """ + raise NotImplementedError() - See Also - -------- - Cluster.scale_up - Cluster.scale_down + @property + def asynchronous(self): + return ( + self._asynchronous + or getattr(thread_state, "asynchronous", False) + or hasattr(self.loop, "_thread_identity") + and self.loop._thread_identity == threading.get_ident() + ) + + def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): + 
asynchronous = asynchronous or self.asynchronous + if asynchronous: + future = func(*args, **kwargs) + if callback_timeout is not None: + future = gen.with_timeout(timedelta(seconds=callback_timeout), future) + return future + else: + return sync(self.loop, func, *args, **kwargs) + + async def _logs(self, scheduler=True, workers=True): + logs = Logs() + + if scheduler: + L = await self.scheduler_comm.logs() + logs["Scheduler"] = Log("\n".join(line for level, line in L)) + + if workers: + d = await self.scheduler_comm.worker_logs(workers=workers) + for k, v in d.items(): + logs[k] = Log("\n".join(line for level, line in v)) + + return logs + + def logs(self, scheduler=True, workers=True): + """ Return logs for the scheduler and workers + + Parameters + ---------- + scheduler : boolean + Whether or not to collect logs for the scheduler + workers : boolean or Iterable[str], optional + A list of worker addresses to select. + Defaults to all workers if `True` or no workers if `False` + + Returns + ------- + logs: Dict[str] + A dictionary of logs, with one item for the scheduler and one for + each worker """ - with log_errors(): - if n >= len(self.scheduler.workers): - self.scheduler.loop.add_callback(self.scale_up, n) - else: - to_close = self.scheduler.workers_to_close( - n=len(self.scheduler.workers) - n - ) - logger.debug("Closing workers: %s", to_close) - self.scheduler.loop.add_callback( - self.scheduler.retire_workers, workers=to_close - ) - self.scheduler.loop.add_callback(self.scale_down, to_close) + return self.sync(self._logs, scheduler=scheduler, workers=workers) + + @property + def dashboard_link(self): + try: + port = self.scheduler_info["services"]["dashboard"] + except KeyError: + return "" + else: + host = self.scheduler_address.split("://")[1].split(":")[0] + return format_dashboard_link(host, port) def _widget_status(self): - workers = len(self.scheduler.workers) - cores = sum(ws.nthreads for ws in self.scheduler.workers.values()) - memory = sum(ws.memory_limit for ws in self.scheduler.workers.values()) + workers = len(self.scheduler_info["workers"]) + if hasattr(self, "worker_spec"): + requested = len(self.worker_spec) + elif hasattr(self, "workers"): + requested = len(self.workers) + else: + requested = workers + cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values()) + memory = sum(v["memory_limit"] for v in self.scheduler_info["workers"].values()) memory = format_bytes(memory) text = """
@@ -140,13 +227,13 @@ def _widget_status(self):
     }
   </style>
   <table style="text-align: right;">
-    <tr> <th>Workers</th> <td>%d</td></tr>
-    <tr> <th>Cores</th> <td>%d</td></tr>
-    <tr> <th>Memory</th> <td>%s</td></tr>
+    <tr> <th>Workers</th> <td>%s</td></tr>
+    <tr> <th>Cores</th> <td>%d</td></tr>
+    <tr> <th>Memory</th> <td>%s</td></tr>
   </table>
 </div>
          """ % ( - workers, + workers if workers == requested else "%d / %d" % (workers, requested), cores, memory, ) @@ -163,11 +250,10 @@ def _widget(self): layout = Layout(width="150px") - if "dashboard" in self.scheduler.services: - link = self.dashboard_link + if self.dashboard_link: link = '
<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>
          \n' % ( - link, - link, + self.dashboard_link, + self.dashboard_link, ) else: link = "" @@ -178,46 +264,49 @@ def _widget(self): status = HTML(self._widget_status(), layout=Layout(min_width="150px")) - request = IntText(0, description="Workers", layout=layout) - scale = Button(description="Scale", layout=layout) + if self._supports_scaling: + request = IntText(0, description="Workers", layout=layout) + scale = Button(description="Scale", layout=layout) - minimum = IntText(0, description="Minimum", layout=layout) - maximum = IntText(0, description="Maximum", layout=layout) - adapt = Button(description="Adapt", layout=layout) + minimum = IntText(0, description="Minimum", layout=layout) + maximum = IntText(0, description="Maximum", layout=layout) + adapt = Button(description="Adapt", layout=layout) - accordion = Accordion( - [HBox([request, scale]), HBox([minimum, maximum, adapt])], - layout=Layout(min_width="500px"), - ) - accordion.selected_index = None - accordion.set_title(0, "Manual Scaling") - accordion.set_title(1, "Adaptive Scaling") + accordion = Accordion( + [HBox([request, scale]), HBox([minimum, maximum, adapt])], + layout=Layout(min_width="500px"), + ) + accordion.selected_index = None + accordion.set_title(0, "Manual Scaling") + accordion.set_title(1, "Adaptive Scaling") - box = VBox([title, HBox([status, accordion]), dashboard]) + def adapt_cb(b): + self.adapt(minimum=minimum.value, maximum=maximum.value) + update() - self._cached_widget = box + adapt.on_click(adapt_cb) - def adapt_cb(b): - self.adapt(minimum=minimum.value, maximum=maximum.value) + def scale_cb(b): + with log_errors(): + n = request.value + with ignoring(AttributeError): + self._adaptive.stop() + self.scale(n) + update() - adapt.on_click(adapt_cb) - - def scale_cb(b): - with log_errors(): - n = request.value - with ignoring(AttributeError): - self._adaptive.stop() - self.scale(n) + scale.on_click(scale_cb) + else: + accordion = HTML("") - scale.on_click(scale_cb) + box = VBox([title, HBox([status, accordion]), dashboard]) - scheduler_ref = ref(self.scheduler) + self._cached_widget = box def update(): status.value = self._widget_status() - pc = PeriodicCallback(update, 500, io_loop=self.scheduler.loop) - self.scheduler.periodic_callbacks["cluster-repr"] = pc + pc = PeriodicCallback(update, 500, io_loop=self.loop) + self.periodic_callbacks["cluster-repr"] = pc pc.start() return box @@ -225,21 +314,20 @@ def update(): def _ipython_display_(self, **kwargs): return self._widget()._ipython_display_(**kwargs) - @property - def asynchronous(self): - return ( - self._asynchronous - or getattr(thread_state, "asynchronous", False) - or hasattr(self.loop, "_thread_identity") - and self.loop._thread_identity == threading.get_ident() + def __repr__(self): + return "%s(%r, workers=%d)" % ( + type(self).__name__, + self.scheduler_address, + len(self.scheduler_info["workers"]), ) - def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): - asynchronous = asynchronous or self.asynchronous - if asynchronous: - future = func(*args, **kwargs) - if callback_timeout is not None: - future = gen.with_timeout(timedelta(seconds=callback_timeout), future) - return future - else: - return sync(self.loop, func, *args, **kwargs) + async def __aenter__(self): + await self + return self + + async def __aexit__(self, typ, value, traceback): + await self.close() + + @property + def scheduler_address(self): + return self.scheduler_comm.address diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py 
index c2de8b9e2f8..70a413fe1c0 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -4,21 +4,10 @@ import weakref from tornado import gen -from dask.utils import format_bytes from .cluster import Cluster -from ..comm import connect from ..core import rpc, CommClosedError -from ..utils import ( - log_errors, - LoopRunner, - silence_logging, - ignoring, - Log, - Logs, - PeriodicCallback, - format_dashboard_link, -) +from ..utils import LoopRunner, silence_logging, ignoring from ..scheduler import Scheduler from ..security import Security @@ -191,11 +180,8 @@ def __init__( self.new_spec = copy.copy(worker) self.workers = {} self._i = 0 - self._asynchronous = asynchronous self.security = security or Security() self.scheduler_comm = None - self.scheduler_info = {} - self.periodic_callbacks = {} if silence_logs: self._old_logging_level = silence_logging(level=silence_logs) @@ -203,11 +189,12 @@ def __init__( self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop - self.status = "created" self._instances.add(self) self._correct_state_waiting = None self._name = name or type(self).__name__ + super().__init__(asynchronous=asynchronous) + if not self.asynchronous: self._loop_runner.start() self.sync(self._start) @@ -241,37 +228,7 @@ async def _start(self): self.scheduler.address, connection_args=self.security.get_connection_args("client"), ) - comm = await connect( - self.scheduler_address, - connection_args=self.security.get_connection_args("client"), - ) - await comm.write({"op": "subscribe_worker_status"}) - self.scheduler_info = await comm.read() - self._watch_worker_status_comm = comm - self._watch_worker_status_task = asyncio.ensure_future( - self._watch_worker_status(comm) - ) - self.status = "running" - - async def _watch_worker_status(self, comm): - """ Listen to scheduler for updates on adding and removing workers """ - while True: - try: - msgs = await comm.read() - except OSError: - break - - for op, msg in msgs: - if op == "add": - workers = msg.pop("workers") - self.scheduler_info["workers"].update(workers) - self.scheduler_info.update(msg) - elif op == "remove": - del self.scheduler_info["workers"][msg] - else: - raise ValueError("Invalid op", op, msg) - - await comm.close() + await super()._start() def _correct_state(self): if self._correct_state_waiting: @@ -341,13 +298,6 @@ async def _wait_for_workers(self): raise gen.TimeoutError("Worker unexpectedly closed") await asyncio.sleep(0.1) - async def __aenter__(self): - await self - return self - - async def __aexit__(self, typ, value, traceback): - await self.close() - async def _close(self): while self.status == "closing": await asyncio.sleep(0.1) @@ -355,33 +305,20 @@ async def _close(self): return self.status = "closing" - for pc in self.periodic_callbacks.values(): - pc.stop() - self.scale(0) await self._correct_state() async with self._lock: with ignoring(CommClosedError): await self.scheduler_comm.close(close_workers=True) + await self.scheduler.close() - await self._watch_worker_status_comm.close() - await self._watch_worker_status_task for w in self._created: assert w.status == "closed" - self.scheduler_comm.close_rpc() if hasattr(self, "_old_logging_level"): silence_logging(self._old_logging_level) - self.status = "closed" - - def close(self, timeout=None): - with ignoring(RuntimeError): # loop closed during process shutdown - return self.sync(self._close, callback_timeout=timeout) - - def __del__(self): - if self.status != "closed": - 
self.loop.add_callback(self.close) + await super()._close() def __enter__(self): self.sync(self._correct_state) @@ -449,158 +386,6 @@ def __repr__(self): len(self.workers), ) - async def _logs(self, scheduler=True, workers=True): - logs = Logs() - - if scheduler: - L = await self.scheduler_comm.logs() - logs["Scheduler"] = Log("\n".join(line for level, line in L)) - - if workers: - d = await self.scheduler_comm.worker_logs(workers=workers) - for k, v in d.items(): - logs[k] = Log("\n".join(line for level, line in v)) - - return logs - - def logs(self, scheduler=True, workers=True): - """ Return logs for the scheduler and workers - - Parameters - ---------- - scheduler : boolean - Whether or not to collect logs for the scheduler - workers : boolean or Iterable[str], optional - A list of worker addresses to select. - Defaults to all workers if `True` or no workers if `False` - - Returns - ------- - logs: Dict[str] - A dictionary of logs, with one item for the scheduler and one for - each worker - """ - return self.sync(self._logs, scheduler=scheduler, workers=workers) - - @property - def dashboard_link(self): - try: - port = self.scheduler_info["services"]["dashboard"] - except KeyError: - return "" - else: - host = self.scheduler_address.split("://")[1].split(":")[0] - return format_dashboard_link(host, port) - - def _widget_status(self): - workers = len(self.scheduler_info["workers"]) - requested = len(self.worker_spec) - cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values()) - memory = sum(v["memory_limit"] for v in self.scheduler_info["workers"].values()) - memory = format_bytes(memory) - text = """ -
-  <table style="text-align: right;">
-    <tr> <th>Workers</th> <td>%s</td></tr>
-    <tr> <th>Cores</th> <td>%d</td></tr>
-    <tr> <th>Memory</th> <td>%s</td></tr>
-  </table>
-</div>
          -""" % ( - workers if workers == requested else "%d / %d" % (workers, requested), - cores, - memory, - ) - return text - - def _widget(self): - """ Create IPython widget for display within a notebook """ - try: - return self._cached_widget - except AttributeError: - pass - - from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion - - layout = Layout(width="150px") - - if self.dashboard_link: - link = '
<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>
          \n' % ( - self.dashboard_link, - self.dashboard_link, - ) - else: - link = "" - - title = "
<h2>%s</h2>
          " % type(self).__name__ - title = HTML(title) - dashboard = HTML(link) - - status = HTML(self._widget_status(), layout=Layout(min_width="150px")) - - if self._supports_scaling: - request = IntText(0, description="Workers", layout=layout) - scale = Button(description="Scale", layout=layout) - - minimum = IntText(0, description="Minimum", layout=layout) - maximum = IntText(0, description="Maximum", layout=layout) - adapt = Button(description="Adapt", layout=layout) - - accordion = Accordion( - [HBox([request, scale]), HBox([minimum, maximum, adapt])], - layout=Layout(min_width="500px"), - ) - accordion.selected_index = None - accordion.set_title(0, "Manual Scaling") - accordion.set_title(1, "Adaptive Scaling") - - def adapt_cb(b): - self.adapt(minimum=minimum.value, maximum=maximum.value) - update() - - adapt.on_click(adapt_cb) - - def scale_cb(b): - with log_errors(): - n = request.value - with ignoring(AttributeError): - self._adaptive.stop() - self.scale(n) - update() - - scale.on_click(scale_cb) - else: - accordion = HTML("") - - box = VBox([title, HBox([status, accordion]), dashboard]) - - self._cached_widget = box - - def update(): - status.value = self._widget_status() - - pc = PeriodicCallback(update, 500, io_loop=self.loop) - self.periodic_callbacks["cluster-repr"] = pc - pc.start() - - return box - - def _ipython_display_(self, **kwargs): - return self._widget()._ipython_display_(**kwargs) - @atexit.register def close_clusters(): From cf10db7b6a4fd091c2e1385162e3d36ab59c8f6e Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 6 Aug 2019 15:02:16 -0500 Subject: [PATCH 0405/1550] Fixup black string normalization (#2929) After running black, several places in our codebase were rewritten from something like ``` raise ValueError("part one of the message " "part two") ``` to ``` raise ValueError("part one of the message " "part two") ``` This fixes those cases, removing the unnecessary two-part string. --- distributed/cli/dask_scheduler.py | 2 +- distributed/cli/dask_worker.py | 2 +- distributed/client.py | 10 ++++------ distributed/comm/addressing.py | 2 +- distributed/comm/tests/test_ucx.py | 2 +- distributed/diskutils.py | 2 +- distributed/protocol/keras.py | 2 +- distributed/scheduler.py | 6 ++---- distributed/tests/test_worker.py | 2 +- distributed/utils_perf.py | 2 +- distributed/worker.py | 2 +- 11 files changed, 15 insertions(+), 19 deletions(-) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index a74f76102b9..29de26d7b4d 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -215,7 +215,7 @@ def del_pid_file(): port=port, dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, - **kwargs, + **kwargs ) logger.info("Local Directory: %26s", local_directory) logger.info("-" * 47) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index eef1d648d40..790b8b3a9ab 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -84,7 +84,7 @@ "--listen-address", type=str, default=None, - help="The address to which the worker binds. " "Example: tcp://0.0.0.0:9000", + help="The address to which the worker binds. 
Example: tcp://0.0.0.0:9000", ) @click.option( "--contact-address", diff --git a/distributed/client.py b/distributed/client.py index 93501ae2077..9f59582b019 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -905,7 +905,7 @@ async def _start(self, timeout=no_default, **kwargs): pass except Exception: logger.info( - "Tried to start cluster and received an error. " "Proceeding.", + "Tried to start cluster and received an error. Proceeding.", exc_info=True, ) address = self.cluster.scheduler_address @@ -2383,9 +2383,7 @@ def _graph_to_futures( dsk3 = {k: v for k, v in dsk2.items() if k is not v} for future in extra_futures: if future.client is not self: - msg = ( - "Inputs contain futures that were created by " "another client." - ) + msg = "Inputs contain futures that were created by another client." raise ValueError(msg) if restrictions: @@ -3485,7 +3483,7 @@ def to_packages(d): errs.append("%s\n%s" % (pkg, asciitable(["", "version"], rows))) raise ValueError( - "Mismatched versions found\n" "\n" "%s" % ("\n\n".join(errs)) + "Mismatched versions found\n\n%s" % ("\n\n".join(errs)) ) return result @@ -3967,7 +3965,7 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): wait_for = Any else: raise NotImplementedError( - "Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are " "supported" + "Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are supported" ) future = wait_for({f._state.wait() for f in fs}) diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 8480134997c..21a23e1ef6e 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -123,7 +123,7 @@ def get_address_host_port(addr, strict=False): return backend.get_address_host_port(loc) except NotImplementedError: raise ValueError( - "don't know how to extract host and port " "for address %r" % (addr,) + "don't know how to extract host and port for address %r" % (addr,) ) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 8a0e8927cf6..4bb4a341552 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -261,7 +261,7 @@ def test_ucx_localcluster(loop, processes): threads_per_worker=1, processes=processes, loop=loop, - **kwargs, + **kwargs ) as cluster: with Client(cluster) as client: x = client.submit(inc, 1) diff --git a/distributed/diskutils.py b/distributed/diskutils.py index 64dcf1dfc12..075ec7750c8 100644 --- a/distributed/diskutils.py +++ b/distributed/diskutils.py @@ -237,7 +237,7 @@ def new_work_dir(self, **kwargs): self._purge_leftovers() except OSError: logger.error( - "Failed to clean up lingering worker directories " "in path: %s ", + "Failed to clean up lingering worker directories in path: %s ", exc_info=True, ) return WorkDir(self, **kwargs) diff --git a/distributed/protocol/keras.py b/distributed/protocol/keras.py index 7471a3dbc93..020ce1cae3b 100644 --- a/distributed/protocol/keras.py +++ b/distributed/protocol/keras.py @@ -9,7 +9,7 @@ def serialize_keras_model(model): if keras.__version__ < "1.2.0": raise ImportError( - "Need Keras >= 1.2.0. " "Try pip install keras --upgrade --no-deps" + "Need Keras >= 1.2.0. 
Try pip install keras --upgrade --no-deps" ) header = model._updated_config() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f84b3d1bce4..65e93d3e59c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2366,9 +2366,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws = ts.processing_on if ws is None: - logger.debug( - "Received long-running signal from duplicate task. " "Ignoring." - ) + logger.debug("Received long-running signal from duplicate task. Ignoring.") return if compute_duration: @@ -4730,7 +4728,7 @@ def check_worker_ttl(self): for ws in self.workers.values(): if ws.last_seen < now - self.worker_ttl: logger.warning( - "Worker failed to heartbeat within %s seconds. " "Closing: %s", + "Worker failed to heartbeat within %s seconds. Closing: %s", self.worker_ttl, ws, ) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 52e92d474ce..64e774f582a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -464,7 +464,7 @@ def test_Executor(c, s): @pytest.mark.skip( - reason="Other tests leak memory, so process-level checks" "trigger immediately" + reason="Other tests leak memory, so process-level checks trigger immediately" ) @gen_cluster( client=True, diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index 048d9092d49..c2257f38fb0 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -42,7 +42,7 @@ def collect(self): elapsed = max(collect_start - self.last_collect, MIN_RUNTIME) if self.last_gc_duration / elapsed < self.max_in_gc_frac: self.logger.debug( - "Calling gc.collect(). %0.3fs elapsed since " "previous call.", elapsed + "Calling gc.collect(). %0.3fs elapsed since previous call.", elapsed ) gc.collect() self.last_collect = collect_start diff --git a/distributed/worker.py b/distributed/worker.py index 795cb93e9f3..63d3ed92a21 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3354,7 +3354,7 @@ async def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=Tru except Exception as e: logger.warning( - " Run Failed\n" "Function: %s\n" "args: %s\n" "kwargs: %s\n", + "Run Failed\nFunction: %s\nargs: %s\nkwargs: %s\n", str(funcname(function))[:1000], convert_args_to_str(args, max_len=1000), convert_kwargs_to_str(kwargs, max_len=1000), From b1ba71a28aae83e4b9668d471b87a4aee762e2bf Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 6 Aug 2019 14:52:27 -0700 Subject: [PATCH 0406/1550] Change TCP.close to a coroutine to avoid task pending warning (#2930) Previously this triggered an intermittent error --- distributed/comm/tcp.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 36783102b69..bd76d0e6946 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -268,13 +268,17 @@ def write(self, msg, serializers=None, on_error="message"): return sum(map(nbytes, frames)) - async def close(self): + @gen.coroutine + def close(self): + # We use gen.coroutine here rather than async def to avoid errors like + # Task was destroyed but it is pending! + # Triggered by distributed.deploy.tests.test_local::test_silent_startup stream, self.stream = self.stream, None if stream is not None and not stream.closed(): try: # Flush the stream's write buffer by waiting for a last write. 
if stream.writing(): - await stream.write(b"") + yield stream.write(b"") stream.socket.shutdown(socket.SHUT_RDWR) except EnvironmentError: pass From 8d7e1664127e3967f12b8fbbd9348d118a7b4b8b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 6 Aug 2019 17:45:29 -0700 Subject: [PATCH 0407/1550] Wrap offload in gen.coroutine (#2934) Previously we would return the bare concurrent.future.Future which was not awaitable. Now we rely on Tornado's gen.coroutine logic to handle this. Fixes https://github.com/dask/distributed/issues/2928 --- distributed/comm/utils.py | 16 +--------------- distributed/utils.py | 16 ++++++++++++++++ distributed/utils_test.py | 4 ++-- distributed/worker.py | 2 +- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index dcc9e9a8b1a..70cd2b4cd27 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -1,12 +1,10 @@ -from concurrent.futures import ThreadPoolExecutor import logging import socket -import weakref from tornado import gen from .. import protocol -from ..utils import get_ip, get_ipv6, nbytes +from ..utils import get_ip, get_ipv6, nbytes, offload logger = logging.getLogger(__name__) @@ -17,18 +15,6 @@ FRAME_OFFLOAD_THRESHOLD = 10 * 1024 ** 2 # 10 MB -try: - _offload_executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="Dask-Offload" - ) -except TypeError: - _offload_executor = ThreadPoolExecutor(max_workers=1) -weakref.finalize(_offload_executor, _offload_executor.shutdown) - - -def offload(fn, *args, **kwargs): - return _offload_executor.submit(fn, *args, **kwargs) - @gen.coroutine def to_frames(msg, serializers=None, on_error="message", context=None): diff --git a/distributed/utils.py b/distributed/utils.py index cdc5c4d1ae9..c8ea8d648eb 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,6 +1,7 @@ import asyncio import atexit from collections import deque +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from datetime import timedelta import functools @@ -1471,3 +1472,18 @@ def convert_value(v): def is_valid_xml(text): return xml.etree.ElementTree.fromstring(text) is not None + + +try: + _offload_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="Dask-Offload" + ) +except TypeError: + _offload_executor = ThreadPoolExecutor(max_workers=1) + +weakref.finalize(_offload_executor, _offload_executor.shutdown) + + +@gen.coroutine +def offload(fn, *args, **kwargs): + return (yield _offload_executor.submit(fn, *args, **kwargs)) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 505d269cae9..e6b11ce2898 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -41,7 +41,6 @@ from .client import default_client, _global_clients, Client from .compatibility import WINDOWS from .comm import Comm -from .comm.utils import offload from .config import initialize_logging from .core import connect, rpc, CommClosedError from .deploy import SpecCluster @@ -60,6 +59,7 @@ sync, iscoroutinefunction, thread_state, + _offload_executor, ) from .worker import Worker, TOTAL_MEMORY from .nanny import Nanny @@ -80,7 +80,7 @@ } -offload(lambda: None).result() # create thread during import +_offload_executor.submit(lambda: None).result() # create thread during import @pytest.fixture(scope="session") diff --git a/distributed/worker.py b/distributed/worker.py index 63d3ed92a21..9f686cfa508 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -30,7 +30,6 @@ from . 
import profile, comm from .batched import BatchedSend from .comm import get_address_host, connect -from .comm.utils import offload from .comm.addressing import address_from_user_args from .core import error_message, CommClosedError, send_recv, pingpong, coerce_to_address from .diskutils import WorkSpace @@ -56,6 +55,7 @@ thread_state, json_load_robust, key_split, + offload, PeriodicCallback, parse_bytes, parse_timedelta, From bafc34fdc82949b404f10dcd4dcdbbb32d74c47d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 7 Aug 2019 10:54:07 -0700 Subject: [PATCH 0408/1550] Redirect setup docs to docs.dask.org [skip ci] (#2936) --- docs/source/conf.py | 3 +- docs/source/index.rst | 4 +- docs/source/quickstart.rst | 2 +- docs/source/related-work.rst | 2 +- docs/source/setup.rst | 332 ----------------------------------- 5 files changed, 6 insertions(+), 337 deletions(-) delete mode 100644 docs/source/setup.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index c8ffc0ae50d..afa33400fdc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -389,7 +389,8 @@ # https://tech.signavio.com/2017/managing-sphinx-redirects redirect_files = [ # old html, new html - ("joblib.html", "https://ml.dask.org/joblib.html") + ("joblib.html", "https://ml.dask.org/joblib.html"), + ("setup.html", "https://docs.dask.org/en/latest/setup.html"), ] diff --git a/docs/source/index.rst b/docs/source/index.rst index cd27e9b4123..09257be3f58 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,7 +28,7 @@ In particular it meets the following needs: Python standard library. Compatible with `dask`_ API for parallel algorithms * **Easy Setup:** As a Pure Python package distributed is ``pip`` installable - and easy to :doc:`set up ` on your own cluster. + and easy to :doc:`set up `_ on your own cluster. .. _`concurrent.futures`: https://www.python.org/dev/peps/pep-3148/ .. _`dask`: https://dask.org @@ -77,7 +77,7 @@ Contents install quickstart - setup + Setup client api faq diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 3d1e326f528..4437f77a1ea 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -48,7 +48,7 @@ Launch a Client and point it to the IP/port of the scheduler. >>> from dask.distributed import Client >>> client = Client('127.0.0.1:8786') -See :doc:`setup ` for advanced use. +See `setup documentation `_ for advanced use. Map and Submit Functions diff --git a/docs/source/related-work.rst b/docs/source/related-work.rst index 4de458bad66..07b08c29787 100644 --- a/docs/source/related-work.rst +++ b/docs/source/related-work.rst @@ -185,7 +185,7 @@ IPython Parallel has the following advantages over ``distributed`` serve in system administration tasks. 3. Deployment help: IPython Parallel has mechanisms built-in to aid deployment on SGE, MPI, etc.. Distributed does not have any such sugar, - though is fairly simple to :doc:`set up ` by hand. + though is fairly simple to `set up `_ by hand. 4. Various other advantages: Over the years IPython parallel has accrued a variety of helpful features like IPython interaction magics, ``@parallel`` decorators, etc.. diff --git a/docs/source/setup.rst b/docs/source/setup.rst deleted file mode 100644 index f1158712901..00000000000 --- a/docs/source/setup.rst +++ /dev/null @@ -1,332 +0,0 @@ -Setup Network -============= - -A ``dask.distributed`` network consists of one ``Scheduler`` node and several -``Worker`` nodes. 
One can set these up in a variety of ways - - -Using the Command Line ----------------------- - -We launch the ``dask-scheduler`` executable in one process and the -``dask-worker`` executable in several processes, possibly on different -machines. - -Launch ``dask-scheduler`` on one node:: - - $ dask-scheduler - Start scheduler at 192.168.0.1:8786 - -Then launch ``dask-worker`` on the rest of the nodes, providing the address to the -node that hosts ``dask-scheduler``:: - - $ dask-worker 192.168.0.1:8786 - Start worker at: 192.168.0.2:12345 - Registered with Scheduler at: 192.168.0.1:8786 - - $ dask-worker 192.168.0.1:8786 - Start worker at: 192.168.0.3:12346 - Registered with Scheduler at: 192.168.0.1:8786 - - $ dask-worker 192.168.0.1:8786 - Start worker at: 192.168.0.4:12347 - Registered with Scheduler at: 192.168.0.1:8786 - -There are various mechanisms to deploy these executables on a cluster, ranging -from manualy SSH-ing into all of the nodes to more automated systems like -SGE/SLURM/Torque or Yarn/Mesos. Additionally, cluster SSH tools exist to -send the same commands to many machines. One example is `tmux-cssh`__. - -.. note:: - - - The scheduler and worker both need to accept TCP connections. By default - the scheduler uses port 8786 and the worker binds to a random open port. - If you are behind a firewall then you may have to open particular ports or - tell Dask to use particular ports with the ``--port`` and ``-worker-port`` - keywords. Other ports like 8787, 8788, and 8789 are also useful to keep - open for the diagnostic web interfaces. - - More information about relevant ports is available by looking at the help - pages with ``dask-scheduler --help`` and ``dask-worker --help`` - -__ https://github.com/dennishafemann/tmux-cssh - - -Using SSH ---------- - -The convenience script ``dask-ssh`` opens several SSH connections to your -target computers and initializes the network accordingly. You can -give it a list of hostnames or IP addresses:: - - $ dask-ssh 192.168.0.1 192.168.0.2 192.168.0.3 192.168.0.4 - -Or you can use normal UNIX grouping:: - - $ dask-ssh 192.168.0.{1,2,3,4} - -Or you can specify a hostfile that includes a list of hosts:: - - $ cat hostfile.txt - 192.168.0.1 - 192.168.0.2 - 192.168.0.3 - 192.168.0.4 - - $ dask-ssh --hostfile hostfile.txt - -The ``dask-ssh`` utility depends on the ``paramiko``:: - - pip install paramiko - - -Using a Shared Network File System and a Job Scheduler ------------------------------------------------------- - -Some clusters benefit from a shared network file system (NFS) and can use this -to communicate the scheduler location to the workers:: - - dask-scheduler --scheduler-file /path/to/scheduler.json - - dask-worker --scheduler-file /path/to/scheduler.json - dask-worker --scheduler-file /path/to/scheduler.json - -.. 
code-block:: python - - >>> client = Client(scheduler_file='/path/to/scheduler.json') - -This can be particularly useful when deploying ``dask-scheduler`` and -``dask-worker`` processes using a job scheduler like -``SGE/SLURM/Torque/etc..`` Here is an example using SGE's ``qsub`` command:: - - # Start a dask-scheduler somewhere and write connection information to file - qsub -b y /path/to/dask-scheduler --scheduler-file /path/to/scheduler.json - - # Start 100 dask-worker processes in an array job pointing to the same file - qsub -b y -t 1-100 /path/to/dask-worker --scheduler-file /path/to/scheduler.json - -Note, the ``--scheduler-file`` option is *only* valuable if your scheduler and -workers share a standard POSIX file system. - - -Using MPI ---------- - -You can launch a Dask network using ``mpirun`` or ``mpiexec`` and the -``dask-mpi`` command line executable. - -.. code-block:: bash - - mpirun --np 4 dask-mpi --scheduler-file /path/to/scheduler.json - -.. code-block:: python - - from dask.distributed import Client - client = Client(scheduler_file='/path/to/scheduler.json') - -This depends on the `mpi4py `_ library. It only -uses MPI to start the Dask cluster, and not for inter-node communication. You -may want to specify a high-bandwidth network interface like infiniband using -the ``--interface`` keyword - -.. code-block:: bash - - mpirun --np 4 dask-mpi --nthreads 1 \ - --interface ib0 \ - --scheduler-file /path/to/scheduler.json - -Using the Python API --------------------- - -Alternatively you can start up the ``distributed.scheduler.Scheduler`` and -``distributed.worker.Worker`` objects within a Python session manually. - -Start the Scheduler, provide the listening port (defaults to 8786) and Tornado -IOLoop (defaults to ``IOLoop.current()``) - -.. code-block:: python - - from distributed import Scheduler - from tornado.ioloop import IOLoop - from threading import Thread - - loop = IOLoop.current() - t = Thread(target=loop.start, daemon=True) - t.start() - - s = Scheduler(loop=loop) - s.start('tcp://:8786') # Listen on TCP port 8786 - -On other nodes start worker processes that point to the URL of the scheduler. - -.. code-block:: python - - from distributed import Worker - from tornado.ioloop import IOLoop - from threading import Thread - - loop = IOLoop.current() - t = Thread(target=loop.start, daemon=True) - t.start() - - w = Worker('tcp://127.0.0.1:8786', loop=loop) - w.start() # choose randomly assigned port - -Alternatively, replace ``Worker`` with ``Nanny`` if you want your workers to be -managed in a separate process by a local nanny process. This allows workers to -restart themselves in case of failure, provides some additional monitoring, and -is useful when coordinating many workers that should live in different -processes to avoid the GIL_. - -.. _GIL: https://docs.python.org/3/glossary.html#term-gil - - -Using LocalCluster ------------------- - -You can do the work above easily using :doc:`LocalCluster`. - -.. code-block:: python - - from distributed import LocalCluster - c = LocalCluster(processes=False) - -A scheduler will be available under ``c.scheduler`` and a list of workers under -``c.workers``. There is an IOLoop running in a background thread. - - -Using AWS ---------- - -See `Cloud Deployments`_ for the latest information on deploying to Amazon -cloud. - -.. 
_`Cloud Deployments`: https://docs.dask.org/en/latest/setup/cloud.html - - -Using Google Cloud ------------------- - -See the dask-kubernetes_ project to easily launch clusters on `Google Kubernetes -Engine`_. - -.. _dask-kubernetes: https://github.com/dask/dask-kubernetes -.. _`Google Kubernetes Engine`: https://cloud.google.com/kubernetes-engine/ - -Cluster Resource Managers -------------------------- - -Dask.distributed has been deployed on dozens of different cluster resource -managers. This section contains links to some external projects, scripts, and -instructions that may serve as useful starting points. - -Kubernetes -~~~~~~~~~~ - -* https://github.com/martindurant/dask-kubernetes -* https://github.com/ogrisel/docker-distributed -* https://github.com/hammerlab/dask-distributed-on-kubernetes/ - -Marathon -~~~~~~~~ - -* https://github.com/mrocklin/dask-marathon - -DRMAA (SGE, SLURM, Torque, etc..) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -* https://github.com/dask/dask-drmaa -* https://github.com/mfouesneau/dasksge - -YARN -~~~~ - -* https://github.com/dask/dask-yarn -* https://knit.readthedocs.io/en/latest/ - - -Software Environment --------------------- - -The workers and clients should all share the same software environment. That -means that they should all have access to the same libraries and that those -libraries should be the same version. Dask generally assumes that it can call -a function on any worker with the same outcome (unless explicitly told -otherwise.) - -This is typically enforced through external means, such as by having a network -file system (NFS) mount for libraries, by starting the ``dask-worker`` -processes in equivalent Docker_ containers, using Conda_ environments, or -through any of the other means typically employed by cluster administrators. - -.. _Docker: https://www.docker.com/ -.. _Conda: http://conda.pydata.org/docs/ - - -Windows -~~~~~~~ - -.. note:: - - - Running a ``dask-scheduler`` on Windows architectures is supported for only a - limited number of workers (roughly 100). This is a detail of the underlying tcp server - implementation and is discussed `here`__. - - - Running ``dask-worker`` processes on Windows is well supported, performant, and without limit. - -If you wish to run in a primarily Windows environment, it is recommended -to run a ``dask-scheduler`` on a linux or MacOSX environment, with ``dask-worker`` workers -on the Windows boxes. This works because the scheduler environment is de-coupled from that of -the workers. - -__ https://github.com/jfisteus/ztreamy/issues/26 - - -Customizing initialization --------------------------- - -Both ``dask-scheduler`` and ``dask-worker`` support a ``--preload`` option that -allows custom initialization of each scheduler/worker respectively. A module or -python file passed as a ``--preload`` value is guaranteed to be imported before -establishing any connection. A ``dask_setup(service)`` function is called if -found, with a ``Scheduler`` or ``Worker`` instance as the argument. As the -service stops, ``dask_teardown(service)`` is called if present. - -To support additional configuration a single ``--preload`` module may register -additional command-line arguments by exposing ``dask_setup`` as a Click_ -command. This command will be used to parse additional arguments provided to -``dask-worker`` or ``dask-scheduler`` and will be called before service -initialization. - -.. 
_Click: http://click.pocoo.org/ - - -As an example, consider the following file that creates a -:doc:`scheduler plugin ` and registers it with the scheduler - -.. code-block:: python - - # scheduler-setup.py - import click - - from distributed.diagnostics.plugin import SchedulerPlugin - - class MyPlugin(SchedulerPlugin): - def __init__(self, print_count): - self.print_count = print_count - SchedulerPlugin.__init__(self) - - def add_worker(self, scheduler=None, worker=None, **kwargs): - print("Added a new worker at:", worker) - if self.print_count and scheduler is not None: - print("Total workers:", len(scheduler.workers)) - - @click.command() - @click.option("--print-count/--no-print-count", default=False) - def dask_setup(scheduler, print_count): - plugin = MyPlugin(print_count) - scheduler.add_plugin(plugin) - -We can then run this preload script by referring to its filename (or module name -if it is on the path) when we start the scheduler:: - - dask-scheduler --preload scheduler-setup.py --print-count From 8ee867be3b0de454467f736963bf7e5501aa3815 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 7 Aug 2019 14:29:21 -0400 Subject: [PATCH 0409/1550] Fixes Worker docstring formatting [skip ci] (#2939) --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 9f686cfa508..f9b6348ebac 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -97,7 +97,7 @@ class Worker(ServerNode): $ dask-worker scheduler-ip:port - Use the ``--help`` flag to see more options + Use the ``--help`` flag to see more options:: $ dask-worker --help From 4d98bb5d6e8b45270c81e45a5662f1b1d4edee24 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Thu, 8 Aug 2019 10:29:21 -0500 Subject: [PATCH 0410/1550] Import from collections.abc (#2938) Silences deprecation warnings about importing from collections instead of collections.abc. 
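
For reference, the import pattern being updated looks roughly like this (an
illustrative sketch, not part of the diff below):

    # Before (deprecated): the ABC aliases re-exported from collections
    # have emitted a DeprecationWarning since Python 3.3 and are removed in
    # later Python releases.
    #   from collections import Iterator, MutableMapping

    # After (preferred): the abstract base classes live in collections.abc.
    from collections.abc import Iterator, MutableMapping

    # Concrete containers are still imported from collections itself.
    from collections import defaultdict, deque, OrderedDict
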
--- distributed/client.py | 3 ++- distributed/publish.py | 2 +- distributed/scheduler.py | 3 ++- distributed/tests/test_as_completed.py | 2 +- distributed/worker.py | 3 ++- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 9f59582b019..678c3b4dbbe 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1,5 +1,6 @@ import atexit -from collections import defaultdict, Iterator +from collections import defaultdict +from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor, CancelledError from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import contextmanager diff --git a/distributed/publish.py b/distributed/publish.py index ea65efb4e74..c899b9fbaaa 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -1,4 +1,4 @@ -from collections import MutableMapping +from collections.abc import MutableMapping from .utils import log_errors, tokey diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 65e93d3e59c..809a5bd303e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1,5 +1,6 @@ import asyncio -from collections import defaultdict, deque, OrderedDict, Mapping, Set +from collections import defaultdict, deque, OrderedDict +from collections.abc import Mapping, Set from datetime import timedelta from functools import partial import itertools diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index d74d033c64a..45833b302e1 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -1,5 +1,5 @@ from concurrent.futures import CancelledError -from collections import Iterator +from collections.abc import Iterator from operator import add import queue import random diff --git a/distributed/worker.py b/distributed/worker.py index f9b6348ebac..dce00f706f7 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1,6 +1,7 @@ import asyncio import bisect -from collections import defaultdict, deque, MutableMapping +from collections import defaultdict, deque +from collections.abc import MutableMapping from datetime import timedelta import heapq import logging From a55515569d4c5da734e5b14ae414cd342c37ed7b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 8 Aug 2019 21:24:01 -0400 Subject: [PATCH 0411/1550] Pass GPU diagnostics from worker to scheduler (#2932) This does a few things: 1. Use `pynvml` to collect information about any CUDA GPUs present 2. Optionally add those metrics to the worker's initial handshake and heartbeats 3. Collect that information in the scheduler in the WorkerState object For now these just hang out in the scheduler information, but in the future they might be used for dashboards, or possibly scheduling decisions in the future. I believe that everything gpu-specific here is fairly well separated and generalized (others should be able to follow this pattern to add more diagnostics relatively easily) but it would be good to hear from others on if this is out of scope. 
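
As a rough, hypothetical sketch of the hook this patch builds on (the
`open_files` metric below is made up for illustration and is not part of the
change): each entry in the worker's `metrics` dict maps a name to a callable
that receives the worker and returns a small, msgpack-serializable value
(returning an awaitable is also allowed); the result rides along with every
heartbeat and is stored on the scheduler's `WorkerState`.

    import psutil

    def open_files(worker):
        # hypothetical custom metric: any cheap, serializable value works
        return len(psutil.Process().open_files())

    # passed at worker creation time, e.g.
    #   Worker(scheduler_address, metrics={"open_files": open_files})
    # and later visible on the scheduler as
    #   scheduler.workers[worker_address].metrics["open_files"]
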
--- distributed/diagnostics/nvml.py | 20 ++++++++++++ distributed/scheduler.py | 8 +++++ distributed/tests/test_worker.py | 12 ++++++++ distributed/worker.py | 53 +++++++++++++++++++++++++++++--- 4 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 distributed/diagnostics/nvml.py diff --git a/distributed/diagnostics/nvml.py b/distributed/diagnostics/nvml.py new file mode 100644 index 00000000000..25a11cde6b0 --- /dev/null +++ b/distributed/diagnostics/nvml.py @@ -0,0 +1,20 @@ +import pynvml + +pynvml.nvmlInit() +count = pynvml.nvmlDeviceGetCount() + +handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(count)] + + +def real_time(): + return { + "utilization": [pynvml.nvmlDeviceGetUtilizationRates(h).gpu for h in handles], + "memory-used": [pynvml.nvmlDeviceGetMemoryInfo(h).used for h in handles], + } + + +def one_time(): + return { + "memory-total": [pynvml.nvmlDeviceGetMemoryInfo(h).total for h in handles], + "name": [pynvml.nvmlDeviceGetName(h).decode() for h in handles], + } diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 809a5bd303e..8fc4a828bdf 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -211,6 +211,7 @@ class WorkerState(object): __slots__ = ( "actors", "address", + "extra", "has_what", "last_seen", "local_directory", @@ -240,6 +241,7 @@ def __init__( local_directory=None, services=None, nanny=None, + extra=None, ): self.address = address self.pid = pid @@ -263,6 +265,8 @@ def __init__( self.resources = {} self.used_resources = {} + self.extra = extra or {} + @property def host(self): return get_address_host(self.address) @@ -278,6 +282,7 @@ def clean(self): local_directory=self.local_directory, services=self.services, nanny=self.nanny, + extra=self.extra, ) ws.processing = {ts.key for ts in self.processing} return ws @@ -306,6 +311,7 @@ def identity(self): "services": self.services, "metrics": self.metrics, "nanny": self.nanny, + **self.extra, } @property @@ -1386,6 +1392,7 @@ async def add_worker( services=None, local_directory=None, nanny=None, + extra=None, ): """ Add a new worker to the cluster """ with log_errors(): @@ -1406,6 +1413,7 @@ async def add_worker( local_directory=local_directory, services=services, nanny=nanny, + extra=extra, ) if name in self.aliases: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 64e774f582a..4dab232487f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1536,3 +1536,15 @@ async def test_lifetime_stagger(c, s, a, b): assert a.lifetime != b.lifetime assert 8 <= a.lifetime <= 12 assert 8 <= b.lifetime <= 12 + + +@gen_cluster() +async def test_gpu_metrics(s, a, b): + pytest.importorskip("pynvml") + from distributed.diagnostics.nvml import count + + assert "gpu" in a.metrics + assert len(s.workers[a.address].metrics["gpu"]["memory-used"]) == count + + assert "gpu" in a.startup_information + assert len(s.workers[a.address].extra["gpu"]["name"]) == count diff --git a/distributed/worker.py b/distributed/worker.py index dce00f706f7..17d56aec79b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -82,6 +82,10 @@ DEFAULT_EXTENSIONS = [PubSubWorkerExtension] +DEFAULT_METRICS = {} + +DEFAULT_STARTUP_INFORMATION = {} + class Worker(ServerNode): """ Worker node in a Dask distributed cluster @@ -306,7 +310,8 @@ def __init__( contact_address=None, memory_monitor_interval="200ms", extensions=None, - metrics=None, + metrics=DEFAULT_METRICS, + startup_information=DEFAULT_STARTUP_INFORMATION, 
data=None, interface=None, host=None, @@ -577,6 +582,9 @@ def __init__( ) self.metrics = dict(metrics) if metrics else {} + self.startup_information = ( + dict(startup_information) if startup_information else {} + ) self.low_level_profiler = low_level_profiler @@ -717,7 +725,7 @@ def local_dir(self): ) return self.local_directory - def get_metrics(self): + async def get_metrics(self): core = dict( executing=len(self.executing), in_memory=len(self.data), @@ -725,10 +733,24 @@ def get_metrics(self): in_flight=len(self.in_flight_tasks), bandwidth=self.bandwidth, ) - custom = {k: metric(self) for k, metric in self.metrics.items()} + custom = {} + for k, metric in self.metrics.items(): + result = metric(self) + if hasattr(result, "__await__"): + result = await result + custom[k] = result return merge(custom, self.monitor.recent(), core) + async def get_startup_information(self): + result = {} + for k, f in self.startup_information.items(): + v = f(self) + if hasattr(v, "__await__"): + v = await v + result[k] = v + return result + def identity(self, comm=None): return { "type": type(self).__name__, @@ -786,7 +808,8 @@ async def _register_with_scheduler(self): services=self.service_ports, nanny=self.nanny, pid=os.getpid(), - metrics=self.get_metrics(), + metrics=await self.get_metrics(), + extra=await self.get_startup_information(), ), serializers=["msgpack"], ) @@ -840,7 +863,9 @@ async def heartbeat(self): try: start = time() response = await self.scheduler.heartbeat_worker( - address=self.contact_address, now=time(), metrics=self.get_metrics() + address=self.contact_address, + now=time(), + metrics=await self.get_metrics(), ) end = time() middle = (start + end) / 2 @@ -3369,3 +3394,21 @@ async def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=Tru _global_workers = Worker._instances + +try: + from .diagnostics import nvml +except ImportError: + pass +else: + + @gen.coroutine + def gpu_metric(worker): + result = yield offload(nvml.real_time) + return result + + DEFAULT_METRICS["gpu"] = gpu_metric + + def gpu_startup(worker): + return nvml.one_time() + + DEFAULT_STARTUP_INFORMATION["gpu"] = gpu_startup From b27f7b7ff4215c92b2e8f9e1dad5c3b2165c61ef Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 12 Aug 2019 08:31:58 -0400 Subject: [PATCH 0412/1550] Add GPUCurrentLoad dashboard plots (#2944) --- distributed/dashboard/nvml.py | 187 ++++++++++++++++++ distributed/dashboard/scheduler.py | 53 ++--- .../dashboard/tests/test_scheduler_bokeh.py | 19 +- 3 files changed, 221 insertions(+), 38 deletions(-) create mode 100644 distributed/dashboard/nvml.py diff --git a/distributed/dashboard/nvml.py b/distributed/dashboard/nvml.py new file mode 100644 index 00000000000..7fd628dd469 --- /dev/null +++ b/distributed/dashboard/nvml.py @@ -0,0 +1,187 @@ +import math + +from .components import DashboardComponent, add_periodic_callback + +from bokeh.plotting import figure +from bokeh.models import ( + ColumnDataSource, + BasicTicker, + NumeralTickFormatter, + TapTool, + OpenURL, + HoverTool, +) +from tornado import escape +from dask.utils import format_bytes +from ..utils import log_errors +from .scheduler import update, applications, BOKEH_THEME +from .utils import without_property_validation + + +class GPUCurrentLoad(DashboardComponent): + """ How many tasks are on each worker """ + + def __init__(self, scheduler, width=600, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "memory": [1, 2], + "memory-half": [0.5, 1], + 
"memory_text": ["1B", "2B"], + "utilization": [1, 2], + "utilization-half": [0.5, 1], + "worker": ["a", "b"], + "gpu-index": [0, 0], + "y": [1, 2], + "escaped_worker": ["a", "b"], + } + ) + + memory = figure( + title="GPU Memory", + tools="", + id="bk-gpu-memory-worker-plot", + width=int(width / 2), + name="gpu_memory_histogram", + **kwargs + ) + rect = memory.rect( + source=self.source, + x="memory-half", + y="y", + width="memory", + height=1, + color="#76B900", + ) + rect.nonselection_glyph = None + + utilization = figure( + title="GPU Utilization", + tools="", + id="bk-gpu-utilization-worker-plot", + width=int(width / 2), + name="gpu_utilization_histogram", + **kwargs + ) + rect = utilization.rect( + source=self.source, + x="utilization-half", + y="y", + width="utilization", + height=1, + color="#76B900", + ) + rect.nonselection_glyph = None + + memory.axis[0].ticker = BasicTicker(mantissas=[1, 256, 512], base=1024) + memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + memory.xaxis.major_label_orientation = -math.pi / 12 + memory.x_range.start = 0 + + for fig in [memory, utilization]: + fig.xaxis.minor_tick_line_alpha = 0 + fig.yaxis.visible = False + fig.ygrid.visible = False + + tap = TapTool( + callback=OpenURL(url="./info/worker/@escaped_worker.html") + ) + fig.add_tools(tap) + + fig.toolbar.logo = None + fig.toolbar_location = None + fig.yaxis.visible = False + + hover = HoverTool() + hover.tooltips = "@worker : @utilization %" + hover.point_policy = "follow_mouse" + utilization.add_tools(hover) + + hover = HoverTool() + hover.tooltips = "@worker : @memory_text" + hover.point_policy = "follow_mouse" + memory.add_tools(hover) + + self.memory_figure = memory + self.utilization_figure = utilization + + self.utilization_figure.y_range = memory.y_range + self.utilization_figure.x_range.start = 0 + self.utilization_figure.x_range.end = 100 + + @without_property_validation + def update(self): + with log_errors(): + workers = list(self.scheduler.workers.values()) + + utilization = [] + memory = [] + gpu_index = [] + y = [] + memory_total = 0 + memory_max = 0 + worker = [] + i = 0 + + for ws in workers: + info = ws.extra["gpu"] + metrics = ws.metrics["gpu"] + for j, (u, mem_used, mem_total) in enumerate( + zip( + metrics["utilization"], + metrics["memory-used"], + info["memory-total"], + ) + ): + memory_max = max(memory_max, mem_total) + memory_total += mem_total + utilization.append(int(u)) + memory.append(mem_used) + worker.append(ws.address) + gpu_index.append(j) + y.append(i) + i += 1 + + memory_text = [format_bytes(m) for m in memory] + + result = { + "memory": memory, + "memory-half": [m / 2 for m in memory], + "memory_text": memory_text, + "utilization": utilization, + "utilization-half": [u / 2 for u in utilization], + "worker": worker, + "gpu-index": gpu_index, + "y": y, + "escaped_worker": [escape.url_escape(w) for w in worker], + } + + self.memory_figure.title.text = "GPU Memory: %s / %s" % ( + format_bytes(sum(memory)), + format_bytes(memory_total), + ) + self.memory_figure.x_range.end = memory_max + + update(self.source, result) + + +def gpu_memory_doc(scheduler, extra, doc): + gpu_load = GPUCurrentLoad(scheduler, sizing_mode="stretch_both") + gpu_load.update() + add_periodic_callback(doc, gpu_load, 100) + doc.add_root(gpu_load.memory_figure) + doc.theme = BOKEH_THEME + + +def gpu_utilization_doc(scheduler, extra, doc): + gpu_load = GPUCurrentLoad(scheduler, sizing_mode="stretch_both") + gpu_load.update() + add_periodic_callback(doc, gpu_load, 100) + 
doc.add_root(gpu_load.utilization_figure) + doc.theme = BOKEH_THEME + + +applications["/individual-gpu-memory"] = gpu_memory_doc +applications["/individual-gpu-utilization"] = gpu_utilization_doc diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 3332f4fc27b..3dd108f5775 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1651,28 +1651,7 @@ def __init__(self, scheduler, io_loop=None, prefix="", **kwargs): self.server_kwargs["prefix"] = prefix or None - self.apps = { - "/system": systemmonitor_doc, - "/stealing": stealing_doc, - "/workers": workers_doc, - "/events": events_doc, - "/counters": counters_doc, - "/tasks": tasks_doc, - "/status": status_doc, - "/profile": profile_doc, - "/profile-server": profile_server_doc, - "/graph": graph_doc, - "/individual-task-stream": individual_task_stream_doc, - "/individual-progress": individual_progress_doc, - "/individual-graph": individual_graph_doc, - "/individual-profile": individual_profile_doc, - "/individual-profile-server": individual_profile_server_doc, - "/individual-nbytes": individual_nbytes_doc, - "/individual-cpu": individual_cpu_doc, - "/individual-nprocessing": individual_nprocessing_doc, - "/individual-workers": individual_workers_doc, - } - + self.apps = applications self.apps = {k: partial(v, scheduler, self.extra) for k, v in self.apps.items()} self.loop = io_loop or scheduler.loop @@ -1701,3 +1680,33 @@ def listen(self, *args, **kwargs): ] self.server._tornado.add_handlers(r".*", handlers) + + +applications = { + "/system": systemmonitor_doc, + "/stealing": stealing_doc, + "/workers": workers_doc, + "/events": events_doc, + "/counters": counters_doc, + "/tasks": tasks_doc, + "/status": status_doc, + "/profile": profile_doc, + "/profile-server": profile_server_doc, + "/graph": graph_doc, + "/individual-task-stream": individual_task_stream_doc, + "/individual-progress": individual_progress_doc, + "/individual-graph": individual_graph_doc, + "/individual-profile": individual_profile_doc, + "/individual-profile-server": individual_profile_server_doc, + "/individual-nbytes": individual_nbytes_doc, + "/individual-cpu": individual_cpu_doc, + "/individual-nprocessing": individual_nprocessing_doc, + "/individual-workers": individual_workers_doc, +} + +try: + import pynvml # noqa: 1708 +except ImportError: + pass +else: + from . 
import nvml # noqa: 1708 diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index e9ac62aad41..2dc29572ea8 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -18,6 +18,7 @@ from distributed.utils_test import gen_cluster, inc, dec, slowinc, div, get_cert from distributed.dashboard.worker import Counters, BokehWorker from distributed.dashboard.scheduler import ( + applications, BokehScheduler, SystemMonitor, Occupancy, @@ -54,22 +55,8 @@ def test_simple(c, s, a, b): yield gen.sleep(0.1) http_client = AsyncHTTPClient() - for suffix in [ - "system", - "counters", - "workers", - "status", - "tasks", - "stealing", - "graph", - "individual-task-stream", - "individual-progress", - "individual-graph", - "individual-nbytes", - "individual-nprocessing", - "individual-profile", - ]: - response = yield http_client.fetch("http://localhost:%d/%s" % (port, suffix)) + for suffix in applications: + response = yield http_client.fetch("http://localhost:%d%s" % (port, suffix)) body = response.body.decode() assert "bokeh" in body.lower() assert not re.search("href=./", body) # no absolute links From b83edbef74a1718d62e51a9cee0379b7617048e1 Mon Sep 17 00:00:00 2001 From: Shayan Amani Date: Mon, 12 Aug 2019 15:37:53 -0400 Subject: [PATCH 0413/1550] Update client.py (#2951) --- distributed/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 678c3b4dbbe..9d05cde049e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4265,7 +4265,7 @@ def default_client(c=None): else: raise ValueError( "No clients found\n" - "Start an client and point it to the scheduler address\n" + "Start a client and point it to the scheduler address\n" " from distributed import Client\n" " client = Client('ip-addr-of-scheduler:8786')\n" ) From 28f300af7136b362a92afc0a93a9de23b374b0c8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 13 Aug 2019 09:15:47 -0400 Subject: [PATCH 0414/1550] Normalize names with str in retire_workers (#2949) This supports cases where names are passed through a CLI and become strings --- distributed/scheduler.py | 8 ++++++-- distributed/tests/test_scheduler.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8fc4a828bdf..ff9560767c0 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3078,9 +3078,11 @@ async def retire_workers( """ with log_errors(): if names is not None: - names = set(names) + if names: + logger.info("Retire worker names %s", names) + names = set(map(str, names)) workers = [ - ws.address for ws in self.workers.values() if ws.name in names + ws.address for ws in self.workers.values() if str(ws.name) in names ] if workers is None: while True: @@ -3098,6 +3100,7 @@ async def retire_workers( workers = {self.workers[w] for w in workers if w in self.workers} if not workers: return [] + logger.info("Retire workers %s", workers) # Keys orphaned by retiring those workers keys = set.union(*[w.has_what for w in workers]) @@ -3106,6 +3109,7 @@ async def retire_workers( other_workers = set(self.workers.values()) - workers if keys: if other_workers: + logger.info("Moving %d keys to other workers", len(keys)) await self.replicate( keys=keys, workers=[ws.address for ws in other_workers], diff --git a/distributed/tests/test_scheduler.py 
b/distributed/tests/test_scheduler.py index 9035fbd8667..943daffde58 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1627,3 +1627,17 @@ async def test_finished(): await s.finished() await w.finished() + + +@pytest.mark.asyncio +async def test_retire_names_str(cleanup): + async with Scheduler(port=0) as s: + async with Worker(s.address, name="0") as a: + async with Worker(s.address, name="1") as b: + async with Client(s.address, asynchronous=True) as c: + futures = c.map(inc, range(10)) + await wait(futures) + assert a.data and b.data + await s.retire_workers(names=[0]) + assert all(f.done() for f in futures) + assert len(b.data) == 10 From 3bab3aaabefa5f958ef44894599d71fffaa59b03 Mon Sep 17 00:00:00 2001 From: Shayan Amani Date: Tue, 13 Aug 2019 10:40:58 -0400 Subject: [PATCH 0415/1550] Update utils_perf.py (#2954) --- distributed/utils_perf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index c2257f38fb0..f21e96d7353 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -50,7 +50,7 @@ def collect(self): if self.last_gc_duration > self.warn_if_longer: self.logger.warning( "gc.collect() took %0.3fs. This is usually" - " a sign that the some tasks handle too" + " a sign that some tasks handle too" " many Python objects at the same time." " Rechunking the work into smaller tasks" " might help.", From 6302175056742d101b8af52336bc6bbf4227da35 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Thu, 15 Aug 2019 15:17:35 +0200 Subject: [PATCH 0416/1550] Allow server_kwargs to override defaults in dashboard (#2955) Fixes #2915 Add a unit test for overriding allow_websocket_origin --- distributed/dashboard/core.py | 6 ++--- .../tests/test_scheduler_bokeh_html.py | 26 ++++++++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index fd6ebef2834..41e7c289c17 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -28,16 +28,16 @@ def listen(self, addr): ip = None for i in range(5): try: - self.server = Server( - self.apps, + server_kwargs = dict( port=port, address=ip, check_unused_sessions_milliseconds=500, allow_websocket_origin=["*"], use_index=False, extra_patterns=[(r"/", web.RedirectHandler, {"url": "/status"})], - **self.server_kwargs ) + server_kwargs.update(self.server_kwargs) + self.server = Server(self.apps, **server_kwargs) self.server.start() handlers = [ diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py index f2a2c880a94..55e4b797b4e 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -6,7 +6,8 @@ pytest.importorskip("bokeh") from tornado.escape import url_escape -from tornado.httpclient import AsyncHTTPClient, HTTPClientError +from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest +from tornado.websocket import websocket_connect from dask.sizeof import sizeof from distributed.utils import is_valid_xml @@ -153,3 +154,26 @@ def test_task_page(c, s, a, b): assert "int" in body assert a.address in body assert "memory" in body + + +@gen_cluster( + client=True, + scheduler_kwargs={ + "services": { + ("dashboard", 0): ( + BokehScheduler, + {"allow_websocket_origin": ["good.invalid"]}, + ) + } + }, +) +def test_allow_websocket_origin(c, s, a, b): + url = ( + 
"ws://localhost:%d/status/ws?bokeh-protocol-version=1.0&bokeh-session-id=1" + % s.services["dashboard"].port + ) + with pytest.raises(HTTPClientError) as err: + yield websocket_connect( + HTTPRequest(url, headers={"Origin": "http://evil.invalid"}) + ) + assert err.value.code == 403 From 2a4bc72a7384b5ed357cef890e7be9eb1a80acfc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 16 Aug 2019 14:47:23 -0400 Subject: [PATCH 0417/1550] Use pytest.warning(Warning) rather than Exception (#2958) --- distributed/deploy/tests/test_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index f434945c3af..31493967d4c 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -195,7 +195,7 @@ def test_Client_solo(loop): def test_duplicate_clients(): pytest.importorskip("bokeh") c1 = yield Client(processes=False, silence_logs=False, dashboard_address=9876) - with pytest.warns(Exception) as info: + with pytest.warns(Warning) as info: c2 = yield Client(processes=False, silence_logs=False, dashboard_address=9876) assert "dashboard" in c1.cluster.scheduler.services From 41a4d41d174c7762d63e17067327518cb8d313e9 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 16 Aug 2019 21:51:17 +0200 Subject: [PATCH 0418/1550] Pass serialization down through small base collections (#2948) This PR adds support for serialization of collections using objects' native types, rather than pickling the entire collection --- distributed/protocol/serialize.py | 80 +++++++++++++++++++ distributed/protocol/tests/test_collection.py | 50 ++++++++++++ .../protocol/tests/test_collection_cuda.py | 66 +++++++++++++++ distributed/protocol/tests/test_serialize.py | 4 +- distributed/tests/test_publish.py | 2 +- 5 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 distributed/protocol/tests/test_collection.py create mode 100644 distributed/protocol/tests/test_collection_cuda.py diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 0069c6a264d..8d1d37a283e 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -135,6 +135,54 @@ def serialize(x, serializers=None, on_error="message", context=None): if isinstance(x, Serialized): return x.header, x.frames + # Determine whether keys are safe to be serialized with msgpack + if type(x) is dict and len(x) <= 5: + try: + msgpack.dumps(list(x.keys())) + except Exception: + dict_safe = False + else: + dict_safe = True + + if ( + type(x) in (list, set, tuple) + and len(x) <= 5 + or type(x) is dict + and len(x) <= 5 + and dict_safe + ): + if isinstance(x, dict): + headers_frames = [] + for k, v in x.items(): + _header, _frames = serialize( + v, serializers=serializers, on_error=on_error, context=context + ) + _header["key"] = k + headers_frames.append((_header, _frames)) + else: + headers_frames = [ + serialize( + obj, serializers=serializers, on_error=on_error, context=context + ) + for obj in x + ] + + frames = [] + lengths = [] + for _header, _frames in headers_frames: + frames.extend(_frames) + length = len(_frames) + lengths.append(length) + + headers = [obj[0] for obj in headers_frames] + headers = { + "sub-headers": headers, + "is-collection": True, + "frame-lengths": lengths, + "type-serialized": type(x).__name__, + } + return headers, frames + tb = "" for name in serializers: @@ -178,6 +226,38 @@ def deserialize(header, frames, deserializers=None): 
-------- serialize """ + if "is-collection" in header: + headers = header["sub-headers"] + lengths = header["frame-lengths"] + cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[ + header["type-serialized"] + ] + + start = 0 + if cls is dict: + d = {} + for _header, _length in zip(headers, lengths): + k = _header.pop("key") + d[k] = deserialize( + _header, + frames[start : start + _length], + deserializers=deserializers, + ) + start += _length + return d + else: + lst = [] + for _header, _length in zip(headers, lengths): + lst.append( + deserialize( + _header, + frames[start : start + _length], + deserializers=deserializers, + ) + ) + start += _length + return cls(lst) + name = header.get("serializer") if deserializers is not None and name not in deserializers: raise TypeError( diff --git a/distributed/protocol/tests/test_collection.py b/distributed/protocol/tests/test_collection.py new file mode 100644 index 00000000000..ddb8a44bd44 --- /dev/null +++ b/distributed/protocol/tests/test_collection.py @@ -0,0 +1,50 @@ +import pytest +from distributed.protocol import serialize, deserialize +import pandas as pd +import numpy as np + + +@pytest.mark.parametrize("collection", [tuple, dict, list]) +@pytest.mark.parametrize( + "y,y_serializer", + [ + (np.arange(50), "dask"), + (pd.DataFrame({"C": ["a", "b", None], "D": [2.5, 3.5, 4.5]}), "pickle"), + (None, "pickle"), + ], +) +def test_serialize_collection(collection, y, y_serializer): + x = np.arange(100) + if issubclass(collection, dict): + header, frames = serialize({"x": x, "y": y}, serializers=("dask", "pickle")) + else: + header, frames = serialize(collection((x, y)), serializers=("dask", "pickle")) + t = deserialize(header, frames, deserializers=("dask", "pickle", "error")) + assert isinstance(t, collection) + + assert header["is-collection"] is True + sub_headers = header["sub-headers"] + + if collection is not dict: + assert sub_headers[0]["serializer"] == "dask" + assert sub_headers[1]["serializer"] == y_serializer + + if collection is dict: + assert (t["x"] == x).all() + assert str(t["y"]) == str(y) + else: + assert (t[0] == x).all() + assert str(t[1]) == str(y) + + +def test_large_collections_serialize_simply(): + header, frames = serialize(tuple(range(1000))) + assert len(frames) == 1 + + +def test_nested_types(): + x = np.ones(5) + header, frames = serialize([[[x]]]) + assert "dask" in str(header) + assert len(frames) == 1 + assert x.data in frames diff --git a/distributed/protocol/tests/test_collection_cuda.py b/distributed/protocol/tests/test_collection_cuda.py new file mode 100644 index 00000000000..e2602795782 --- /dev/null +++ b/distributed/protocol/tests/test_collection_cuda.py @@ -0,0 +1,66 @@ +import pytest + +from distributed.protocol import serialize, deserialize +from dask.dataframe.utils import assert_eq +import pandas as pd + + +@pytest.mark.parametrize("collection", [tuple, dict]) +@pytest.mark.parametrize("y,y_serializer", [(50, "cuda"), (None, "pickle")]) +def test_serialize_cupy(collection, y, y_serializer): + cupy = pytest.importorskip("cupy") + + x = cupy.arange(100) + if y is not None: + y = cupy.arange(y) + if issubclass(collection, dict): + header, frames = serialize( + {"x": x, "y": y}, serializers=("cuda", "dask", "pickle") + ) + else: + header, frames = serialize((x, y), serializers=("cuda", "dask", "pickle")) + t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + + assert header["is-collection"] is True + sub_headers = header["sub-headers"] + assert 
sub_headers[0]["serializer"] == "cuda" + assert sub_headers[1]["serializer"] == y_serializer + assert isinstance(t, collection) + + assert ((t["x"] if isinstance(t, dict) else t[0]) == x).all() + if y is None: + assert (t["y"] if isinstance(t, dict) else t[1]) is None + else: + assert ((t["y"] if isinstance(t, dict) else t[1]) == y).all() + + +@pytest.mark.parametrize("collection", [tuple, dict]) +@pytest.mark.parametrize( + "df2,df2_serializer", + [(pd.DataFrame({"C": [3, 4, 5], "D": [2.5, 3.5, 4.5]}), "cuda"), (None, "pickle")], +) +def test_serialize_pandas_pandas(collection, df2, df2_serializer): + cudf = pytest.importorskip("cudf") + + df1 = cudf.DataFrame({"A": [1, 2, None], "B": [1.0, 2.0, None]}) + if df2 is not None: + df2 = cudf.from_pandas(df2) + if issubclass(collection, dict): + header, frames = serialize( + {"df1": df1, "df2": df2}, serializers=("cuda", "dask", "pickle") + ) + else: + header, frames = serialize((df1, df2), serializers=("cuda", "dask", "pickle")) + t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle")) + + assert header["is-collection"] is True + sub_headers = header["sub-headers"] + assert sub_headers[0]["serializer"] == "cuda" + assert sub_headers[1]["serializer"] == df2_serializer + assert isinstance(t, collection) + + assert_eq(t["df1"] if isinstance(t, dict) else t[0], df1) + if df2 is None: + assert (t["df2"] if isinstance(t, dict) else t[1]) is None + else: + assert_eq(t["df2"] if isinstance(t, dict) else t[1], df2) diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 09297793fc3..6ba70f676a1 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -46,7 +46,7 @@ def deserialize_myobj(header, frames): def test_dumps_serialize(): - for x in [123, [1, 2, 3]]: + for x in [123, [1, 2, 3, 4, 5, 6]]: header, frames = serialize(x) assert header["serializer"] == "pickle" assert len(frames) == 1 @@ -235,7 +235,7 @@ def __getstate__(self): def test_errors(): - msg = {"data": {"foo": to_serialize(inc)}} + msg = {"data": {"foo": to_serialize(inc)}, "a": 1, "b": 2, "c": 3, "d": 4, "e": 5} header, frames = serialize(msg, serializers=["msgpack", "pickle"]) assert header["serializer"] == "pickle" diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index 7c0fd0db6d2..32b2974a738 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -228,7 +228,7 @@ def test_pickle_safe(c, s, a, b): try: yield c2.publish_dataset(x=[1, 2, 3]) result = yield c2.get_dataset("x") - assert result == (1, 2, 3) + assert result == [1, 2, 3] with pytest.raises(TypeError): yield c2.publish_dataset(y=lambda x: x) From 3a2c83534cbacab0ebc6215e9bf5c85a3574255d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 16 Aug 2019 16:53:29 -0400 Subject: [PATCH 0419/1550] Except all exceptions when checking pynvml (#2961) pynvml uses a home-grown exception for this --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 17d56aec79b..3dafe1e14df 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3397,7 +3397,7 @@ async def run(server, comm, function, args=(), kwargs={}, is_coro=None, wait=Tru try: from .diagnostics import nvml -except ImportError: +except Exception: pass else: From e16837c63c35e05764946563e6a397e3ab9597e8 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 16 Aug 2019 16:59:57 -0500 
Subject: [PATCH 0420/1550] bump version to 2.3.0 --- docs/source/changelog.rst | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index b7037d9d6c2..e6c75875765 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,39 @@ Changelog ========= +2.3.0 - 2019-08-16 +------------------ + +- Except all exceptions when checking ``pynvml`` (:pr:`2961`) `Matthew Rocklin`_ +- Pass serialization down through small base collections (:pr:`2948`) `Peter Andreas Entschev`_ +- Use ``pytest.warning(Warning)`` rather than ``Exception`` (:pr:`2958`) `Matthew Rocklin`_ +- Allow ``server_kwargs`` to override defaults in dashboard (:pr:`2955`) `Bruce Merry`_ +- Update ``utils_perf.py`` (:pr:`2954`) `Shayan Amani`_ +- Normalize names with ``str`` in ``retire_workers`` (:pr:`2949`) `Matthew Rocklin`_ +- Update ``client.py`` (:pr:`2951`) `Shayan Amani`_ +- Add ``GPUCurrentLoad`` dashboard plots (:pr:`2944`) `Matthew Rocklin`_ +- Pass GPU diagnostics from worker to scheduler (:pr:`2932`) `Matthew Rocklin`_ +- Import from ``collections.abc`` (:pr:`2938`) `Jim Crist`_ +- Fixes Worker docstring formatting (:pr:`2939`) `James Bourbeau`_ +- Redirect setup docs to docs.dask.org (:pr:`2936`) `Matthew Rocklin`_ +- Wrap offload in ``gen.coroutine`` (:pr:`2934`) `Matthew Rocklin`_ +- Change ``TCP.close`` to a coroutine to avoid task pending warning (:pr:`2930`) `Matthew Rocklin`_ +- Fixup black string normalization (:pr:`2929`) `Jim Crist`_ +- Move core functionality from ``SpecCluster`` to ``Cluster`` (:pr:`2913`) `Matthew Rocklin`_ +- Add aenter/aexit protocols to ``ProcessInterface`` (:pr:`2927`) `Matthew Rocklin`_ +- Add real-time CPU utilization plot to dashboard (:pr:`2922`) `Matthew Rocklin`_ +- Always kill processes in clean tests, even if we don't check (:pr:`2924`) `Matthew Rocklin`_ +- Add timeouts to processes in SSH tests (:pr:`2925`) `Matthew Rocklin`_ +- Add documentation around ``spec.ProcessInterface`` (:pr:`2923`) `Matthew Rocklin`_ +- Cleanup async warnings in tests (:pr:`2920`) `Matthew Rocklin`_ +- Give 404 when requesting nonexistent tasks or workers (:pr:`2921`) `Martin Durant`_ +- Raise informative warning when rescheduling an unknown task (:pr:`2916`) `James Bourbeau`_ +- Fix docstring (:pr:`2917`) `Martin Durant`_ +- Add keep-alive message between worker and scheduler (:pr:`2907`) `Matthew Rocklin`_ +- Rewrite ``Adaptive``/``SpecCluster`` to support slowly arriving workers (:pr:`2904`) `Matthew Rocklin`_ +- Call heartbeat rather than reconnect on disconnection (:pr:`2906`) `Matthew Rocklin`_ + + 2.2.0 - 2019-07-31 ------------------ @@ -1173,3 +1206,4 @@ significantly without many new features. .. _`Christian Hudon`: https://github.com/chrish42 .. _`Gabriel Sailer`: https://github.com/sublinus .. _`Pierre Glaser`: https://github.com/pierreglase +.. 
_`Shayan Amani`: https://github.com/SHi-ON From 31e775447a5642644cb5f5d9001b56000c4a0536 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Sat, 17 Aug 2019 16:09:51 +0100 Subject: [PATCH 0421/1550] Add support for separate external address for SpecCluster scheduler (#2963) --- distributed/deploy/spec.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 70a413fe1c0..feae250512b 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -23,6 +23,7 @@ class ProcessInterface: def __init__(self): self.address = None + self.external_address = None self.lock = asyncio.Lock() self.status = "created" @@ -225,7 +226,7 @@ async def _start(self): self.status = "starting" self.scheduler = await self.scheduler self.scheduler_comm = rpc( - self.scheduler.address, + getattr(self.scheduler, "external_address", None) or self.scheduler.address, connection_args=self.security.get_connection_args("client"), ) await super()._start() From 2bff61d9bee59e0bf655937922d9d4c37e49820a Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Sun, 18 Aug 2019 09:45:36 -0400 Subject: [PATCH 0422/1550] Defer cudf serialization/deserialization to that library (#2881) Fixes #2830 Also log errors in UCX comm --- distributed/comm/tests/test_ucx.py | 32 ++++++++-- distributed/comm/ucx.py | 98 +++++++++++++++--------------- distributed/protocol/cuda.py | 2 - distributed/protocol/cudf.py | 82 +++++-------------------- 4 files changed, 91 insertions(+), 123 deletions(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 4bb4a341552..afc0eee0676 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -1,5 +1,4 @@ import asyncio - import pytest ucp = pytest.importorskip("ucp") @@ -10,6 +9,7 @@ from distributed.comm import ucx, parse_address from distributed.protocol import to_serialize from distributed.deploy.local import LocalCluster +from dask.dataframe.utils import assert_eq from distributed.utils_test import gen_test, loop, inc # noqa: 401 from .test_comms import check_deserialize @@ -35,7 +35,7 @@ async def handle_comm(comm): # Workaround for hanging test in # pytest distributed/comm/tests/test_ucx.py::test_comm_objs -vs --count=2 # on the second time through. 
- ucp._libs.ucp_py.reader_added = 0 + # ucp._libs.ucp_py.reader_added = 0 listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) with listener: @@ -164,21 +164,41 @@ def test_ucx_deserialize(): @pytest.mark.asyncio -async def test_ping_pong_cudf(): +@pytest.mark.parametrize( + "g", + [ + lambda cudf: cudf.Series([1, 2, 3]), + lambda cudf: cudf.Series([]), + lambda cudf: cudf.DataFrame([]), + lambda cudf: cudf.DataFrame([1]).head(0), + lambda cudf: cudf.DataFrame([1.0]).head(0), + lambda cudf: cudf.DataFrame({"a": []}), + lambda cudf: cudf.DataFrame({"a": ["a"]}).head(0), + lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), + lambda cudf: cudf.DataFrame({"a": [1]}).head(0), + lambda cudf: cudf.DataFrame({"a": [1, 2, None], "b": [1.0, 2.0, None]}), + ], +) +async def test_ping_pong_cudf(g): # if this test appears after cupy an import error arises # *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11' # not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12) cudf = pytest.importorskip("cudf") - df = cudf.DataFrame({"A": [1, 2, None], "B": [1.0, 2.0, None]}) + cudf_obj = g(cudf) com, serv_com = await get_comm_pair() - msg = {"op": "ping", "data": to_serialize(df)} + msg = {"op": "ping", "data": to_serialize(cudf_obj)} await com.write(msg) result = await serv_com.read() - data2 = result.pop("data") + + cudf_obj_2 = result.pop("data") assert result["op"] == "ping" + assert_eq(cudf_obj, cudf_obj_2) + + await com.close() + await serv_com.close() @pytest.mark.asyncio diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index eb1c7514133..434c16c35ee 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -13,7 +13,7 @@ from .core import Comm, Connector, Listener, CommClosedError from .registry import Backend, backends from .utils import ensure_concrete_host, to_frames, from_frames -from ..utils import ensure_ip, get_ip, get_ipv6, nbytes +from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors import ucp @@ -94,60 +94,62 @@ async def write( serializers=("cuda", "dask", "pickle", "error"), on_error: str = "message", ): - if serializers is None: - serializers = ("cuda", "dask", "pickle", "error") - # msg can also be a list of dicts when sending batched messages - frames = await to_frames(msg, serializers=serializers, on_error=on_error) - is_gpus = b"".join( - [ - struct.pack("?", hasattr(frame, "__cuda_array_interface__")) - for frame in frames - ] - ) - sizes = b"".join([struct.pack("Q", nbytes(frame)) for frame in frames]) + with log_errors(): + if serializers is None: + serializers = ("cuda", "dask", "pickle", "error") + # msg can also be a list of dicts when sending batched messages + frames = await to_frames(msg, serializers=serializers, on_error=on_error) + is_gpus = b"".join( + [ + struct.pack("?", hasattr(frame, "__cuda_array_interface__")) + for frame in frames + ] + ) + sizes = b"".join([struct.pack("Q", nbytes(frame)) for frame in frames]) - nframes = struct.pack("Q", len(frames)) + nframes = struct.pack("Q", len(frames)) - meta = b"".join([nframes, is_gpus, sizes]) + meta = b"".join([nframes, is_gpus, sizes]) - await self.ep.send_obj(meta) + await self.ep.send_obj(meta) - for frame in frames: - await self.ep.send_obj(frame) - return sum(map(nbytes, frames)) + for frame in frames: + await self.ep.send_obj(frame) + return sum(map(nbytes, frames)) async def read(self, deserializers=("cuda", "dask", "pickle", "error")): - if deserializers is None: - deserializers = ("cuda", 
"dask", "pickle", "error") - resp = await self.ep.recv_future() - obj = ucp.get_obj_from_msg(resp) - (nframes,) = struct.unpack( - "Q", obj[:8] - ) # first eight bytes for number of frames - - gpu_frame_msg = obj[ - 8 : 8 + nframes - ] # next nframes bytes for if they're GPU frames - is_gpus = struct.unpack("{}?".format(nframes), gpu_frame_msg) - - sized_frame_msg = obj[8 + nframes :] # then the rest for frame sizes - sizes = struct.unpack("{}Q".format(nframes), sized_frame_msg) - - frames = [] - - for i, (is_gpu, size) in enumerate(zip(is_gpus, sizes)): - if size > 0: - resp = await self.ep.recv_obj(size, cuda=is_gpu) - else: - resp = await self.ep.recv_future() - frame = ucp.get_obj_from_msg(resp) - frames.append(frame) - - msg = await from_frames( - frames, deserialize=self.deserialize, deserializers=deserializers - ) + with log_errors(): + if deserializers is None: + deserializers = ("cuda", "dask", "pickle", "error") + resp = await self.ep.recv_future() + obj = ucp.get_obj_from_msg(resp) + (nframes,) = struct.unpack( + "Q", obj[:8] + ) # first eight bytes for number of frames + + gpu_frame_msg = obj[ + 8 : 8 + nframes + ] # next nframes bytes for if they're GPU frames + is_gpus = struct.unpack("{}?".format(nframes), gpu_frame_msg) + + sized_frame_msg = obj[8 + nframes :] # then the rest for frame sizes + sizes = struct.unpack("{}Q".format(nframes), sized_frame_msg) + + frames = [] + + for i, (is_gpu, size) in enumerate(zip(is_gpus, sizes)): + if size > 0: + resp = await self.ep.recv_obj(size, cuda=is_gpu) + else: + resp = await self.ep.recv_future() + frame = ucp.get_obj_from_msg(resp) + frames.append(frame) + + msg = await from_frames( + frames, deserialize=self.deserialize, deserializers=deserializers + ) - return msg + return msg def abort(self): if self._ep: diff --git a/distributed/protocol/cuda.py b/distributed/protocol/cuda.py index 13be1d75bb8..51cb3ea42fa 100644 --- a/distributed/protocol/cuda.py +++ b/distributed/protocol/cuda.py @@ -16,8 +16,6 @@ def cuda_dumps(x): raise NotImplementedError(type_name) header, frames = dumps(x) - - header["type"] = type_name header["type-serialized"] = pickle.dumps(type(x)) header["serializer"] = "cuda" header["compression"] = (None,) * len(frames) # no compression for gpu data diff --git a/distributed/protocol/cudf.py b/distributed/protocol/cudf.py index 018596b1560..e072570fe58 100644 --- a/distributed/protocol/cudf.py +++ b/distributed/protocol/cudf.py @@ -1,74 +1,22 @@ +import pickle import cudf +import cudf.groupby.groupby from .cuda import cuda_serialize, cuda_deserialize -from .numba import serialize_numba_ndarray, deserialize_numba_ndarray +from ..utils import log_errors - -# TODO: -# 1. Just use positions -# a. Fixes duplicate columns -# b. Fixes non-msgpack-serializable names -# 2. cudf.Series -# 3. 
Serialize the index - - -@cuda_serialize.register(cudf.DataFrame) +# all (de-)serializtion code lives in the cudf codebase +# here we ammend the returned headers with `is_gpu` for +# UCX buffer consumption +@cuda_serialize.register((cudf.DataFrame, cudf.Series, cudf.groupby.groupby._Groupby)) def serialize_cudf_dataframe(x): - sub_headers = [] - arrays = [] - null_masks = [] - null_headers = [] - null_counts = {} - - for label, col in x.iteritems(): - header, [frame] = serialize_numba_ndarray(col.data.mem) - header["name"] = label - sub_headers.append(header) - arrays.append(frame) - if col.null_count: - header, [frame] = serialize_numba_ndarray(col.nullmask.mem) - header["name"] = label - null_headers.append(header) - null_masks.append(frame) - null_counts[label] = col.null_count - - arrays.extend(null_masks) - - header = { - "is_cuda": len(arrays), - "subheaders": sub_headers, - # TODO: the header must be msgpack (de)serializable. - # See if we can avoid names, and just use integer positions. - "columns": x.columns.tolist(), - "null_counts": null_counts, - "null_subheaders": null_headers, - } + with log_errors(): + header, frames = x.serialize() + return header, frames - return header, arrays - -@cuda_deserialize.register(cudf.DataFrame) +@cuda_deserialize.register((cudf.DataFrame, cudf.Series, cudf.groupby.groupby._Groupby)) def serialize_cudf_dataframe(header, frames): - columns = header["columns"] - n_columns = len(header["columns"]) - n_masks = len(header["null_subheaders"]) - - masks = {} - pairs = [] - - for i in range(n_masks): - subheader = header["null_subheaders"][i] - frame = frames[n_columns + i] - mask = deserialize_numba_ndarray(subheader, [frame]) - masks[subheader["name"]] = mask - - for subheader, frame in zip(header["subheaders"], frames[:n_columns]): - name = subheader["name"] - array = deserialize_numba_ndarray(subheader, [frame]) - - if name in masks: - series = cudf.Series.from_masked_array(array, masks[name]) - else: - series = cudf.Series(array) - pairs.append((name, series)) - - return cudf.DataFrame(pairs) + with log_errors(): + cudf_typ = pickle.loads(header["type"]) + cudf_obj = cudf_typ.deserialize(header, frames) + return cudf_obj From b083b10d64763b38e559096127d6e3e0c0638c31 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 20 Aug 2019 19:39:33 +0200 Subject: [PATCH 0423/1550] Workaround for hanging test now calls ucp.fin() (#2967) --- distributed/comm/tests/test_ucx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index afc0eee0676..1355daf95b8 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -35,7 +35,7 @@ async def handle_comm(comm): # Workaround for hanging test in # pytest distributed/comm/tests/test_ucx.py::test_comm_objs -vs --count=2 # on the second time through. 
- # ucp._libs.ucp_py.reader_added = 0 + ucp._libs.ucp_py.fin() listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) with listener: From de2d529fe113f61979b47f1bc913f9791123c257 Mon Sep 17 00:00:00 2001 From: Pav A Date: Wed, 21 Aug 2019 19:05:10 +0100 Subject: [PATCH 0424/1550] [DOC] Remove unnecessary bullet point (#2972) --- docs/source/limitations.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/limitations.rst b/docs/source/limitations.rst index 3c64d4458f2..e272359e701 100644 --- a/docs/source/limitations.rst +++ b/docs/source/limitations.rst @@ -35,7 +35,6 @@ Dask assumes the following about your functions and your data: - Dask may run your functions multiple times, such as if a worker holding an intermediate result dies. Any side effects should be `idempotent `_. -- Security -------- From 6e0fecfe1bccdebd4db8bde71f5058d5534f3bea Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 21 Aug 2019 14:14:08 -0700 Subject: [PATCH 0425/1550] Directly import progress from diagnostics.progressbar (#2975) For some reason the implicit import behavior has changed. Fixes #2973 --- distributed/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index ca36613c815..d79993dfef7 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -3,7 +3,7 @@ from .actor import Actor, ActorFuture from .core import connect, rpc from .deploy import LocalCluster, Adaptive, SpecCluster -from .diagnostics import progress +from .diagnostics.progressbar import progress from .client import ( Client, Executor, From 4ae027155af116048440f08327a290863b3f5e0e Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 21 Aug 2019 18:42:09 -0500 Subject: [PATCH 0426/1550] Handle buffer protocol objects in ensure_bytes (#2969) --- distributed/protocol/tests/test_arrow.py | 13 +++++++-- distributed/tests/test_utils.py | 15 +++++++++- distributed/utils.py | 37 +++++++++++++++++++----- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/distributed/protocol/tests/test_arrow.py b/distributed/protocol/tests/test_arrow.py index a363ee9511e..37aff3a2644 100644 --- a/distributed/protocol/tests/test_arrow.py +++ b/distributed/protocol/tests/test_arrow.py @@ -3,9 +3,9 @@ pa = pytest.importorskip("pyarrow") +import distributed from distributed.utils_test import gen_cluster -from distributed.protocol import deserialize, serialize - +from distributed.protocol import deserialize, serialize, to_serialize df = pd.DataFrame({"A": list("abc"), "B": [1, 2, 3]}) tbl = pa.Table.from_pandas(df, preserve_index=False) @@ -35,3 +35,12 @@ def run_test(client, scheduler, worker1, worker2): assert obj.equals(result) run_test() + + +def test_dumps_compression(): + # https://github.com/dask/distributed/issues/2966 + # large enough to trigger compression + t = pa.Table.from_pandas(pd.DataFrame({"A": [1] * 10000})) + msg = {"op": "update", "data": to_serialize(t)} + result = distributed.protocol.loads(distributed.protocol.dumps(msg)) + assert result["data"].equals(t) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index e5e18eb393c..81541686baf 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -1,3 +1,4 @@ +import array import datetime from functools import partial import io @@ -275,13 +276,25 @@ def f(): def test_ensure_bytes(): - data = [b"1", "1", memoryview(b"1"), bytearray(b"1")] + data = [b"1", "1", memoryview(b"1"), bytearray(b"1"), 
array.array("b", [49])] for d in data: result = ensure_bytes(d) assert isinstance(result, bytes) assert result == b"1" +def test_ensure_bytes_ndarray(): + result = ensure_bytes(np.arange(12)) + assert isinstance(result, bytes) + + +def test_ensure_bytes_pyarrow_buffer(): + pa = pytest.importorskip("pyarrow") + buf = pa.py_buffer(b"123") + result = ensure_bytes(buf) + assert isinstance(result, bytes) + + def test_nbytes(): def check(obj, expected): assert nbytes(obj) == expected diff --git a/distributed/utils.py b/distributed/utils.py index c8ea8d648eb..65f5c188257 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -904,22 +904,43 @@ def tmpfile(extension=""): def ensure_bytes(s): - """ Turn string or bytes to bytes + """Attempt to turn `s` into bytes. + + Parameters + ---------- + s : Any + The object to be converted. Will correctly handled + + * str + * bytes + * objects implementing the buffer protocol (memoryview, ndarray, etc.) + + Returns + ------- + b : bytes + + Raises + ------ + TypeError + When `s` cannot be converted + + Examples + -------- >>> ensure_bytes('123') b'123' >>> ensure_bytes(b'123') b'123' """ - if isinstance(s, bytes): - return s - if isinstance(s, memoryview): - return s.tobytes() - if isinstance(s, bytearray): # noqa: F821 - return bytes(s) if hasattr(s, "encode"): return s.encode() - raise TypeError("Object %s is neither a bytes object nor has an encode method" % s) + else: + try: + return bytes(s) + except Exception as e: + raise TypeError( + "Object %s is neither a bytes object nor has an encode method" % s + ) from e def divide_n_among_bins(n, bins): From 88fd0d23b1f1f89348e98baa62f6ddc30b75ab09 Mon Sep 17 00:00:00 2001 From: Pav A Date: Thu, 22 Aug 2019 00:43:03 +0100 Subject: [PATCH 0427/1550] Fix documentatation syntax and tree (#2981) --- docs/source/client.rst | 2 +- docs/source/conf.py | 1 + docs/source/ec2.rst | 1 - docs/source/index.rst | 3 ++- docs/source/worker.rst | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) delete mode 100644 docs/source/ec2.rst diff --git a/docs/source/client.rst b/docs/source/client.rst index b955d5ab504..444b651681e 100644 --- a/docs/source/client.rst +++ b/docs/source/client.rst @@ -3,7 +3,7 @@ Client The Client is the primary entry point for users of ``dask.distributed``. -After we :doc:`setup a cluster `, we initialize a ``Client`` by pointing +After we `setup a cluster `_, we initialize a ``Client`` by pointing it to the address of a ``Scheduler``: .. code-block:: python diff --git a/docs/source/conf.py b/docs/source/conf.py index afa33400fdc..bb3361851b8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -391,6 +391,7 @@ # old html, new html ("joblib.html", "https://ml.dask.org/joblib.html"), ("setup.html", "https://docs.dask.org/en/latest/setup.html"), + ("ec2.html", "https://dask.pydata.org/en/latest/setup/cloud.html"), ] diff --git a/docs/source/ec2.rst b/docs/source/ec2.rst deleted file mode 100644 index 71747ba28dc..00000000000 --- a/docs/source/ec2.rst +++ /dev/null @@ -1 +0,0 @@ -See `Dask's cloud deployment documentation `_ for up-to-date documentation for deployment on Amazon's Cloud. diff --git a/docs/source/index.rst b/docs/source/index.rst index 09257be3f58..732c234a53b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,7 +28,7 @@ In particular it meets the following needs: Python standard library. 
Compatible with `dask`_ API for parallel algorithms * **Easy Setup:** As a Pure Python package distributed is ``pip`` installable - and easy to :doc:`set up `_ on your own cluster. + and easy to `set up `_ on your own cluster. .. _`concurrent.futures`: https://www.python.org/dev/peps/pep-3148/ .. _`dask`: https://dask.org @@ -80,6 +80,7 @@ Contents Setup client api + examples-overview faq .. toctree:: diff --git a/docs/source/worker.rst b/docs/source/worker.rst index be288ccf68c..5ff66b613a6 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -239,7 +239,7 @@ Nanny Dask workers are by default launched, monitored, and managed by a small Nanny process. -.. autoclass:: distributed.worker.Nanny +.. autoclass:: distributed.nanny.Nanny API Documentation From 58b3abe6d2e5b9898344ba87f93c20648247507e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 22 Aug 2019 01:45:29 +0200 Subject: [PATCH 0428/1550] Improve get_ip_interface error message when interface does not exist (#2964) --- distributed/tests/test_utils.py | 12 ++++++++++-- distributed/utils.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 81541686baf..bf2d8456681 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -150,8 +150,16 @@ def test_get_ip_interface(): assert get_ip_interface("lo") == "127.0.0.1" else: pytest.skip("test needs to be enhanced for platform %r" % (sys.platform,)) - with pytest.raises(KeyError): - get_ip_interface("__non-existent-interface") + + non_existent_interface = "__non-existent-interface" + expected_error_message = "{!r}.+network interface.+".format(non_existent_interface) + + if sys.platform == "darwin": + expected_error_message += "'lo0'" + elif sys.platform.startswith("linux"): + expected_error_message += "'lo'" + with pytest.raises(ValueError, match=expected_error_message): + get_ip_interface(non_existent_interface) def test_truncate_exception(): diff --git a/distributed/utils.py b/distributed/utils.py index 65f5c188257..b83a2c1f2cf 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -169,7 +169,16 @@ def get_ip_interface(ifname): """ import psutil - for info in psutil.net_if_addrs()[ifname]: + net_if_addrs = psutil.net_if_addrs() + + if ifname not in net_if_addrs: + allowed_ifnames = list(net_if_addrs.keys()) + raise ValueError( + "{!r} is not a valid network interface. " + "Valid network interfaces are: {}".format(ifname, allowed_ifnames) + ) + + for info in net_if_addrs[ifname]: if info.family == socket.AF_INET: return info.address raise ValueError("interface %r doesn't have an IPv4 address" % (ifname,)) From 799b5a0a0c95dc757a69e440e1d70a5e3afdd56f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 22 Aug 2019 05:45:37 -0700 Subject: [PATCH 0429/1550] Add cores= and memory= keywords to scale (#2974) Dask-Jobqueue did this internally. It seems like a decent idea to pull upstream. 
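
For example, with a LocalCluster (roughly what the new test exercises; the
numbers are illustrative):

    from distributed import LocalCluster

    cluster = LocalCluster(n_workers=0, threads_per_worker=2, memory_limit="2GB")

    cluster.scale(2)             # two workers, as before
    cluster.scale(cores=6)       # at least 6 threads in total -> 3 workers
    cluster.scale(memory="6GB")  # at least 6 GB in total      -> 3 workers

Both keywords round up, and scale() takes the maximum of the worker counts they
imply, so they can be combined with a plain worker count.
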
--- distributed/deploy/spec.py | 25 +++++++++++++++++++++++-- distributed/deploy/tests/test_local.py | 22 ++++++++++++++++++++++ distributed/protocol/cudf.py | 1 + 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index feae250512b..22fa6692d65 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -1,13 +1,14 @@ import asyncio import atexit import copy +import math import weakref from tornado import gen from .cluster import Cluster from ..core import rpc, CommClosedError -from ..utils import LoopRunner, silence_logging, ignoring +from ..utils import LoopRunner, silence_logging, ignoring, parse_bytes from ..scheduler import Scheduler from ..security import Security @@ -330,7 +331,27 @@ def __exit__(self, typ, value, traceback): self.close() self._loop_runner.stop() - def scale(self, n): + def scale(self, n=0, memory=None, cores=None): + if memory is not None: + try: + limit = self.new_spec["options"]["memory_limit"] + except KeyError: + raise ValueError( + "to use scale(memory=...) your worker definition must include a memory_limit definition" + ) + else: + n = max(n, int(math.ceil(parse_bytes(memory) / parse_bytes(limit)))) + + if cores is not None: + try: + threads_per_worker = self.new_spec["options"]["nthreads"] + except KeyError: + raise ValueError( + "to use scale(cores=...) your worker definition must include an nthreads= definition" + ) + else: + n = max(n, int(math.ceil(cores / threads_per_worker))) + if len(self.worker_spec) > n: not_yet_launched = set(self.worker_spec) - { v["name"] for v in self.scheduler_info["workers"].values() diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 31493967d4c..ad00e908d61 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -861,3 +861,25 @@ def test_client_cluster_synchronous(loop): with Client(loop=loop, processes=False) as c: assert not c.asynchronous assert not c.cluster.asynchronous + + +@pytest.mark.asyncio +async def test_scale_memory_cores(cleanup): + async with LocalCluster( + n_workers=0, + processes=False, + threads_per_worker=2, + memory_limit="2GB", + asynchronous=True, + ) as cluster: + cluster.scale(cores=4) + assert len(cluster.worker_spec) == 2 + + cluster.scale(memory="6GB") + assert len(cluster.worker_spec) == 3 + + cluster.scale(cores=1) + assert len(cluster.worker_spec) == 1 + + cluster.scale(memory="7GB") + assert len(cluster.worker_spec) == 4 diff --git a/distributed/protocol/cudf.py b/distributed/protocol/cudf.py index e072570fe58..985314f3f2e 100644 --- a/distributed/protocol/cudf.py +++ b/distributed/protocol/cudf.py @@ -4,6 +4,7 @@ from .cuda import cuda_serialize, cuda_deserialize from ..utils import log_errors + # all (de-)serializtion code lives in the cudf codebase # here we ammend the returned headers with `is_gpu` for # UCX buffer consumption From edc094348055e27d8ae45fe1ee397317a92b912f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 22 Aug 2019 09:00:43 -0700 Subject: [PATCH 0430/1550] Make workers robust to bad custom metrics (#2984) --- .pre-commit-config.yaml | 2 +- distributed/tests/test_worker.py | 10 ++++++++++ distributed/worker.py | 11 +++++++---- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6be2fcaa3bc..2c72a38ce93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,6 @@ repos: - id: black language_version: 
python3.7 - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v1.2.3 + rev: v2.3.0 hooks: - id: flake8 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 4dab232487f..3d0844e66e6 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1548,3 +1548,13 @@ async def test_gpu_metrics(s, a, b): assert "gpu" in a.startup_information assert len(s.workers[a.address].extra["gpu"]["name"]) == count + + +@pytest.mark.asyncio +async def test_bad_metrics(cleanup): + def bad_metric(w): + raise Exception("Hello") + + async with Scheduler() as s: + async with Worker(s.address, metrics={"bad": bad_metric}) as w: + assert "bad" not in s.workers[w.address].metrics diff --git a/distributed/worker.py b/distributed/worker.py index 3dafe1e14df..b20f52161fa 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -735,10 +735,13 @@ async def get_metrics(self): ) custom = {} for k, metric in self.metrics.items(): - result = metric(self) - if hasattr(result, "__await__"): - result = await result - custom[k] = result + try: + result = metric(self) + if hasattr(result, "__await__"): + result = await result + custom[k] = result + except Exception: # TODO: log error once + pass return merge(custom, self.monitor.recent(), core) From 4e9d5ecb373f03c4e71928c79cbabdfb3f257d26 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 22 Aug 2019 09:39:30 -0700 Subject: [PATCH 0431/1550] bump version to 2.3.1 --- docs/source/changelog.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index e6c75875765..07e7b8387f5 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,21 @@ Changelog ========= +2.3.1 - 2019-08-22 +------------------ + +- Add support for separate external address for SpecCluster scheduler (:pr:`2963`) `Jacob Tomlinson`_ +- Defer cudf serialization/deserialization to that library (:pr:`2881`) `Benjamin Zaitlen`_ +- Workaround for hanging test now calls ucp.fin() (:pr:`2967`) `Mads R. B. Kristensen`_ +- Remove unnecessary bullet point (:pr:`2972`) `Pav A`_ +- Directly import progress from diagnostics.progressbar (:pr:`2975`) `Matthew Rocklin`_ +- Handle buffer protocol objects in ensure_bytes (:pr:`2969`) `Tom Augspurger`_ +- Fix documentatation syntax and tree (:pr:`2981`) `Pav A`_ +- Improve get_ip_interface error message when interface does not exist (:pr:`2964`) `Loïc Estève`_ +- Add cores= and memory= keywords to scale (:pr:`2974`) `Matthew Rocklin`_ +- Make workers robust to bad custom metrics (:pr:`2984`) `Matthew Rocklin`_ + + 2.3.0 - 2019-08-16 ------------------ @@ -1199,6 +1214,7 @@ significantly without many new features. .. _`Michael Spiegel`: https://github.com/Spiegel0 .. _`Caleb`: https://github.com/calebho .. _`Ben Zaitlen`: https://github.com/quasiben +.. _`Benjamin Zaitlen`: https://github.com/quasiben .. _`Manuel Garrido`: https://github.com/manugarri .. _`Magnus Nord`: https://github.com/magnunor .. _`Sam Grayson`: https://github.com/charmoniumQ @@ -1207,3 +1223,5 @@ significantly without many new features. .. _`Gabriel Sailer`: https://github.com/sublinus .. _`Pierre Glaser`: https://github.com/pierreglase .. _`Shayan Amani`: https://github.com/SHi-ON +.. _`Pav A`: https://github.com/rs2 +.. _`Mads R. B. 
Kristensen`: https://github.com/madsbk From d3b075e7117a2e9a744641e4a9ea9deaca3681ff Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 23 Aug 2019 16:43:47 +0100 Subject: [PATCH 0432/1550] Skip exceptions in startup information (#2991) --- distributed/tests/test_worker.py | 12 ++++++++++++ distributed/worker.py | 12 ++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 3d0844e66e6..13dd92c00e0 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1558,3 +1558,15 @@ def bad_metric(w): async with Scheduler() as s: async with Worker(s.address, metrics={"bad": bad_metric}) as w: assert "bad" not in s.workers[w.address].metrics + + +@pytest.mark.asyncio +async def test_bad_startup(cleanup): + def bad_startup(w): + raise Exception("Hello") + + async with Scheduler() as s: + try: + w = await Worker(s.address, startup_information={"bad": bad_startup}) + except Exception: + pytest.fail("Startup exception was raised") diff --git a/distributed/worker.py b/distributed/worker.py index b20f52161fa..ca4f4121af3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -748,10 +748,14 @@ async def get_metrics(self): async def get_startup_information(self): result = {} for k, f in self.startup_information.items(): - v = f(self) - if hasattr(v, "__await__"): - v = await v - result[k] = v + try: + v = f(self) + if hasattr(v, "__await__"): + v = await v + result[k] = v + except Exception: # TODO: log error once + pass + return result def identity(self, comm=None): From a0d68d42b81fe0215098d7f87783e2d57b2aa703 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 23 Aug 2019 09:13:17 -0700 Subject: [PATCH 0433/1550] bump version to 2.3.2 --- docs/source/changelog.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 07e7b8387f5..7385567467e 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,11 @@ Changelog ========= +2.3.2 - 2019-08-23 +------------------ + +- Skip exceptions in startup information (:pr:`2991`) `Jacob Tomlinson`_ + 2.3.1 - 2019-08-22 ------------------ From 512729206255154f511acbb3d054e0439f7e07c6 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Fri, 23 Aug 2019 15:47:03 -0500 Subject: [PATCH 0434/1550] Fix PyNVML initialization (#2993) --- distributed/diagnostics/nvml.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/distributed/diagnostics/nvml.py b/distributed/diagnostics/nvml.py index 25a11cde6b0..a96a5547598 100644 --- a/distributed/diagnostics/nvml.py +++ b/distributed/diagnostics/nvml.py @@ -1,12 +1,19 @@ import pynvml -pynvml.nvmlInit() -count = pynvml.nvmlDeviceGetCount() +handles = None -handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(count)] + +def _pynvml_handles(): + global handles + if handles is None: + pynvml.nvmlInit() + count = pynvml.nvmlDeviceGetCount() + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(count)] + return handles def real_time(): + handles = _pynvml_handles() return { "utilization": [pynvml.nvmlDeviceGetUtilizationRates(h).gpu for h in handles], "memory-used": [pynvml.nvmlDeviceGetMemoryInfo(h).used for h in handles], @@ -14,6 +21,7 @@ def real_time(): def one_time(): + handles = _pynvml_handles() return { "memory-total": [pynvml.nvmlDeviceGetMemoryInfo(h).total for h in handles], "name": [pynvml.nvmlDeviceGetName(h).decode() for h in 
handles], From ad0c7c23c384e981e0622c36e7fb132429becc34 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 26 Aug 2019 09:40:31 -0700 Subject: [PATCH 0435/1550] Add threads= and memory= to Cluster and Client reprs (#2995) --- distributed/client.py | 8 +++++++- distributed/deploy/cluster.py | 22 +++++++++++++++------- distributed/deploy/local.py | 8 -------- distributed/deploy/spec.py | 7 ------- distributed/deploy/tests/test_local.py | 21 +++++++++++++++++++++ distributed/tests/test_client.py | 8 ++++++-- 6 files changed, 49 insertions(+), 25 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 9d05cde049e..c11257ee74a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -761,12 +761,18 @@ def __repr__(self): workers = info.get("workers", {}) nworkers = len(workers) nthreads = sum(w["nthreads"] for w in workers.values()) - return "<%s: scheduler=%r processes=%d cores=%d>" % ( + text = "<%s: %r processes=%d threads=%d" % ( self.__class__.__name__, addr, nworkers, nthreads, ) + memory = [w["memory_limit"] for w in workers.values()] + if all(memory): + text += ", memory=" + format_bytes(sum(memory)) + text += ">" + return text + elif self.scheduler is not None: return "<%s: scheduler=%r>" % ( self.__class__.__name__, diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index e85ea2bc3dd..0c1a0364405 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -314,13 +314,6 @@ def update(): def _ipython_display_(self, **kwargs): return self._widget()._ipython_display_(**kwargs) - def __repr__(self): - return "%s(%r, workers=%d)" % ( - type(self).__name__, - self.scheduler_address, - len(self.scheduler_info["workers"]), - ) - async def __aenter__(self): await self return self @@ -331,3 +324,18 @@ async def __aexit__(self, typ, value, traceback): @property def scheduler_address(self): return self.scheduler_comm.address + + def __repr__(self): + text = "%s(%r, workers=%d, threads=%d" % ( + getattr(self, "_name", type(self).__name__), + self.scheduler_address, + len(self.workers), + sum(w["nthreads"] for w in self.scheduler_info["workers"].values()), + ) + + memory = [w["memory_limit"] for w in self.scheduler_info["workers"].values()] + if all(memory): + text += ", memory=" + format_bytes(sum(memory)) + + text += ")" + return text diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 20476ad8065..efe5ed03098 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -201,14 +201,6 @@ def __init__( security=security, ) - def __repr__(self): - return "%s(%r, workers=%d, nthreads=%d)" % ( - type(self).__name__, - self.scheduler_address, - len(self.workers), - sum(w.nthreads for w in self.workers.values()), - ) - def start_worker(self, *args, **kwargs): raise NotImplementedError( "The `cluster.start_worker` function has been removed. 
" diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 22fa6692d65..8ea03a8371d 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -401,13 +401,6 @@ async def scale_down(self, workers): scale_up = scale # backwards compatibility - def __repr__(self): - return "%s(%r, workers=%d)" % ( - self._name, - self.scheduler_address, - len(self.workers), - ) - @atexit.register def close_clusters(): diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index ad00e908d61..59c75d545d8 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -883,3 +883,24 @@ async def test_scale_memory_cores(cleanup): cluster.scale(memory="7GB") assert len(cluster.worker_spec) == 4 + + +@pytest.mark.asyncio +async def test_repr(cleanup): + async with LocalCluster( + n_workers=2, + processes=False, + threads_per_worker=2, + memory_limit="2GB", + asynchronous=True, + ) as cluster: + text = repr(cluster) + assert "workers=2" in text + assert cluster.scheduler_address in text + assert "cores=4" in text or "threads=4" in text + assert "GB" in text and "4" in text + + async with LocalCluster( + n_workers=2, processes=False, memory_limit=None, asynchronous=True + ) as cluster: + assert "memory" not in repr(cluster) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7f0036e2a7a..a4e472c882a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1882,12 +1882,16 @@ def test_badly_serialized_input_stderr(capsys, c): def test_repr(loop): funcs = [str, repr, lambda x: x._repr_html_()] - with cluster(nworkers=3) as (s, [a, b, c]): + with cluster(nworkers=3, worker_kwargs={"memory_limit": "2 GB"}) as (s, [a, b, c]): with Client(s["address"], loop=loop) as c: for func in funcs: text = func(c) assert c.scheduler.address in text assert "3" in text + assert "6" in text + assert "GB" in text + if " Date: Mon, 26 Aug 2019 12:41:36 -0400 Subject: [PATCH 0436/1550] Add cuda_ipc to UCX environment for NVLink (#2996) --- distributed/comm/ucx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 434c16c35ee..8631bb18229 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -21,7 +21,7 @@ os.environ.setdefault("UCX_RNDV_SCHEME", "put_zcopy") os.environ.setdefault("UCX_MEMTYPE_CACHE", "n") -os.environ.setdefault("UCX_TLS", "rc,cuda_copy") +os.environ.setdefault("UCX_TLS", "rc,cuda_copy,cuda_ipc") logger = logging.getLogger(__name__) MAX_MSG_LOG = 23 From 71dda4c326cf5ce523a025e7c299e424343a20f2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 26 Aug 2019 13:39:45 -0700 Subject: [PATCH 0437/1550] Permit more keyword options when scaling with cores and memory (#2997) Some cluster managers use options like nthreads or cores. We should be robust to a few common choices. 
--- distributed/deploy/spec.py | 28 ++++++++++++------- distributed/deploy/tests/test_spec_cluster.py | 15 ++++++++++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 8ea03a8371d..08a993def97 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -333,24 +333,32 @@ def __exit__(self, typ, value, traceback): def scale(self, n=0, memory=None, cores=None): if memory is not None: - try: - limit = self.new_spec["options"]["memory_limit"] - except KeyError: + for name in ["memory_limit", "memory"]: + try: + limit = self.new_spec["options"][name] + except KeyError: + pass + else: + n = max(n, int(math.ceil(parse_bytes(memory) / parse_bytes(limit)))) + break + else: raise ValueError( "to use scale(memory=...) your worker definition must include a memory_limit definition" ) - else: - n = max(n, int(math.ceil(parse_bytes(memory) / parse_bytes(limit)))) if cores is not None: - try: - threads_per_worker = self.new_spec["options"]["nthreads"] - except KeyError: + for name in ["nthreads", "ncores", "threads", "cores"]: + try: + threads_per_worker = self.new_spec["options"][name] + except KeyError: + pass + else: + n = max(n, int(math.ceil(cores / threads_per_worker))) + break + else: raise ValueError( "to use scale(cores=...) your worker definition must include an nthreads= definition" ) - else: - n = max(n, int(math.ceil(cores / threads_per_worker))) if len(self.worker_spec) > n: not_yet_launched = set(self.worker_spec) - { diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 64633428a38..27ed20c9f20 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -268,3 +268,18 @@ async def test_widget(cleanup): cluster.scale(5) assert "3 / 5" in cluster._widget_status() + + +@pytest.mark.asyncio +async def test_scale_cores_memory(cleanup): + async with SpecCluster( + scheduler=scheduler, + worker={"cls": Worker, "options": {"nthreads": 1}}, + asynchronous=True, + ) as cluster: + cluster.scale(cores=2) + assert len(cluster.worker_spec) == 2 + with pytest.raises(ValueError) as info: + cluster.scale(memory="5GB") + + assert "memory" in str(info.value) From 83a844e9144e5310b6c9d140c2c644dcd56b8db1 Mon Sep 17 00:00:00 2001 From: Mohammad Noor Date: Wed, 28 Aug 2019 09:14:26 -0500 Subject: [PATCH 0438/1550] Fix minor typo in documentation (#3002) --- docs/source/manage-computation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/manage-computation.rst b/docs/source/manage-computation.rst index 20c01fe2c24..e4d2d4298d3 100644 --- a/docs/source/manage-computation.rst +++ b/docs/source/manage-computation.rst @@ -128,7 +128,7 @@ tasks directly to the cluster with ``client.scatter``, ``client.submit`` or ``cl .. 
code-block:: python futures = client.scatter(args) # Send data - future = client.submit(function, *args, **kwrags) # Send single task + future = client.submit(function, *args, **kwargs) # Send single task futures = client.map(function, sequence, **kwargs) # Send many tasks In this case ``*args`` or ``**kwargs`` can be normal Python objects, like ``1`` From 52323b910b203d496849f61463c48fca8be92ac3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 28 Aug 2019 08:29:05 -0700 Subject: [PATCH 0439/1550] Return dictionaries from new_worker_spec rather than name/worker pairs (#3000) This allows for larger collections of workers, such as when launching many workers in one HPC job. Fixes https://github.com/dask/distributed/issues/2999 --- distributed/deploy/spec.py | 8 +++----- distributed/deploy/tests/test_spec_cluster.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 08a993def97..8b450c6af79 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -375,8 +375,7 @@ def scale(self, n=0, memory=None, cores=None): return while len(self.worker_spec) < n: - k, spec = self.new_worker_spec() - self.worker_spec[k] = spec + self.worker_spec.update(self.new_worker_spec()) self.loop.add_callback(self._correct_state) @@ -385,8 +384,7 @@ def new_worker_spec(self): Returns ------- - name: identifier for worker - spec: dict + d: dict mapping names to worker specs See Also -------- @@ -395,7 +393,7 @@ def new_worker_spec(self): while self._i in self.worker_spec: self._i += 1 - return self._i, self.new_spec + return {self._i: self.new_spec} @property def _supports_scaling(self): diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 27ed20c9f20..2debf9dea68 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -152,7 +152,7 @@ async def test_new_worker_spec(cleanup): class MyCluster(SpecCluster): def new_worker_spec(self): i = len(self.worker_spec) - return i, {"cls": Worker, "options": {"nthreads": i + 1}} + return {i: {"cls": Worker, "options": {"nthreads": i + 1}}} async with MyCluster(asynchronous=True, scheduler=scheduler) as cluster: cluster.scale(3) From 7a48850938c006ad16a2f99dfbcf44fa5df3ce00 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 28 Aug 2019 09:31:23 -0700 Subject: [PATCH 0440/1550] Make spec.ProcessInterface a valid no-op worker (#3004) This allows it to be placed into specs as-is as a placeholder for testing. 
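
For example (mirroring the new test, inside a coroutine), a spec that uses
ProcessInterface directly scales like any other worker spec without starting
real processes:

    from distributed import Scheduler
    from distributed.deploy.spec import ProcessInterface, SpecCluster

    scheduler = {"cls": Scheduler, "options": {"port": 0}}

    async with SpecCluster(
        scheduler=scheduler, worker={"cls": ProcessInterface}, asynchronous=True
    ) as cluster:
        cluster.scale(2)   # two placeholder workers, no real processes launched
        await cluster
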
--- distributed/deploy/spec.py | 2 +- distributed/deploy/tests/test_spec_cluster.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 8b450c6af79..acb7e72368a 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -22,7 +22,7 @@ class ProcessInterface: It should implement the methods below, like ``start`` and ``close`` """ - def __init__(self): + def __init__(self, scheduler=None, name=None): self.address = None self.external_address = None self.lock = asyncio.Lock() diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 2debf9dea68..ea0afb0488d 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -283,3 +283,17 @@ async def test_scale_cores_memory(cleanup): cluster.scale(memory="5GB") assert "memory" in str(info.value) + + +@pytest.mark.asyncio +async def test_ProcessInterfaceValid(cleanup): + async with SpecCluster( + scheduler=scheduler, worker={"cls": ProcessInterface}, asynchronous=True + ) as cluster: + cluster.scale(2) + await cluster + assert len(cluster.worker_spec) == len(cluster.workers) == 2 + + cluster.scale(1) + await cluster + assert len(cluster.worker_spec) == len(cluster.workers) == 1 From bc6f4a6395cbb28b493ff7984b941aa1e0dc4fd5 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Wed, 28 Aug 2019 15:57:03 -0400 Subject: [PATCH 0441/1550] better name for cudf deserialization function name (#3008) --- distributed/protocol/cudf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/protocol/cudf.py b/distributed/protocol/cudf.py index 985314f3f2e..f236a6c1f0c 100644 --- a/distributed/protocol/cudf.py +++ b/distributed/protocol/cudf.py @@ -16,7 +16,7 @@ def serialize_cudf_dataframe(x): @cuda_deserialize.register((cudf.DataFrame, cudf.Series, cudf.groupby.groupby._Groupby)) -def serialize_cudf_dataframe(header, frames): +def deserialize_cudf_dataframe(header, frames): with log_errors(): cudf_typ = pickle.loads(header["type"]) cudf_obj = cudf_typ.deserialize(header, frames) From f5f2aa3c313fc5002be8cb2f7f2af64d9043a6e0 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Thu, 29 Aug 2019 08:47:01 -0500 Subject: [PATCH 0442/1550] Tweak `Logs` styling (#3012) - Not all browsers render unstyled `summary` elements with the "carrot" indicating they can be expanded. We style these so they always look the same across browsers. - We sort log elements when rendering. --- distributed/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index b83a2c1f2cf..9f14d58de0b 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1453,12 +1453,13 @@ class Logs(dict): def _repr_html_(self): summaries = [ - "
<details>\n<summary>{title}</summary>\n{log}\n</details>
          ".format( - title=title, log=log._repr_html_() - ) - for title, log in self.items() + "
<details>\n"
+            "<summary style='display:list-item'>{title}</summary>\n"
+            "{log}\n"
+            "</details>
          ".format(title=title, log=log._repr_html_()) + for title, log in sorted(self.items()) ] - return "\n\n".join(summaries) + return "\n".join(summaries) def cli_keywords(d: dict, cls=None): From cf26e1a559e2c89c4a4b14b6622111eaf0954f12 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 29 Aug 2019 10:04:18 -0700 Subject: [PATCH 0443/1550] Support Spec jobs that generate multiple workers (#3013) Sometimes a single entry in the worker_spec will generate multiple Dask workers. We add an entry, "group", to the spec that shows how worker_spec entries map to dask-workers that connect to the scheduler --- distributed/deploy/adaptive.py | 9 +-- distributed/deploy/adaptive_core.py | 2 +- distributed/deploy/cluster.py | 12 ++++ distributed/deploy/spec.py | 65 +++++++++++++---- distributed/deploy/tests/test_spec_cluster.py | 69 +++++++++++++++++++ 5 files changed, 138 insertions(+), 19 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 2efc18dfe0c..b8c3429a505 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -98,18 +98,15 @@ def scheduler(self): @property def plan(self): - try: - return set(self.cluster.worker_spec) - except AttributeError: - return set(self.cluster.workers) + return self.cluster.plan @property def requested(self): - return set(self.cluster.workers) + return self.cluster.requested @property def observed(self): - return {d["name"] for d in self.cluster.scheduler_info["workers"].values()} + return self.cluster.observed async def target(self): return await self.scheduler.adaptive_target( diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index 6732bb20284..db50f109ce3 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -48,7 +48,7 @@ class AdaptiveCore: Scales the cluster up to a target number of workers, presumably changing at least ``plan`` and hopefully eventually also ``requested`` - scale_down : Set[worker] -> None + scale_down : Set[worker] -> None Closes the provided set of workers Parameters diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 0c1a0364405..5e86c39ce8e 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -339,3 +339,15 @@ def __repr__(self): text += ")" return text + + @property + def plan(self): + return set(self.workers) + + @property + def requested(self): + return set(self.workers) + + @property + def observed(self): + return {d["name"] for d in self.scheduler_info["workers"].values()} diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index acb7e72368a..d897dc6b7df 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -160,6 +160,23 @@ class does handle all of the logic around asynchronously cleanly setting up Also note that uniformity of the specification is not required. Other API could be added externally (in subclasses) that adds workers of different specifications into the same dictionary. + + If a single entry in the spec will generate multiple dask workers then + please provide a `"group"` element to the spec, that includes the suffixes + that will be added to each name (this should be handled by your worker + class). 
+ + >>> cluster.worker_spec + { + 0: {"cls": MultiWorker, "options": {"processes": 3}, "group": ["-0", "-1", -2"]} + 1: {"cls": MultiWorker, "options": {"processes": 2}, "group": ["-0", "-1"]} + } + + These suffixes should correspond to the names used by the workers when + they deploy. + + >>> [ws.name for ws in cluster.scheduler.workers.values()] + ["0-0", "0-1", "0-2", "1-0", "1-1"] """ _instances = weakref.WeakSet() @@ -288,18 +305,6 @@ async def _(): return _().__await__() - async def _wait_for_workers(self): - while { - str(d["name"]) - for d in (await self.scheduler_comm.identity())["workers"].values() - } != set(map(str, self.workers)): - if ( - any(w.status == "closed" for w in self.workers.values()) - and self.scheduler.status == "running" - ): - raise gen.TimeoutError("Worker unexpectedly closed") - await asyncio.sleep(0.1) - async def _close(self): while self.status == "closing": await asyncio.sleep(0.1) @@ -400,6 +405,18 @@ def _supports_scaling(self): return not not self.new_spec async def scale_down(self, workers): + # We may have groups, if so, map worker addresses to job names + if not all(w in self.worker_spec for w in workers): + mapping = {} + for name, spec in self.worker_spec.items(): + if "group" in spec: + for suffix in spec["group"]: + mapping[str(name) + suffix] = name + else: + mapping[name] = name + + workers = {mapping.get(w, w) for w in workers} + for w in workers: if w in self.worker_spec: del self.worker_spec[w] @@ -407,6 +424,30 @@ async def scale_down(self, workers): scale_up = scale # backwards compatibility + @property + def plan(self): + out = set() + for name, spec in self.worker_spec.items(): + if "group" in spec: + out.update({str(name) + suffix for suffix in spec["group"]}) + else: + out.add(name) + return out + + @property + def requested(self): + out = set() + for name in self.workers: + try: + spec = self.worker_spec[name] + except KeyError: + continue + if "group" in spec: + out.update({str(name) + suffix for suffix in spec["group"]}) + else: + out.add(name) + return out + @atexit.register def close_clusters(): diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index ea0afb0488d..efc231ca030 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -297,3 +297,72 @@ async def test_ProcessInterfaceValid(cleanup): cluster.scale(1) await cluster assert len(cluster.worker_spec) == len(cluster.workers) == 1 + + +class MultiWorker(Worker, ProcessInterface): + def __init__(self, *args, n=1, name=None, nthreads=None, **kwargs): + self.workers = [ + Worker( + *args, name=str(name) + "-" + str(i), nthreads=nthreads // n, **kwargs + ) + for i in range(n) + ] + + @property + def status(self): + return self.workers[0].status + + def __str__(self): + return "" % len(self.workers) + + __repr__ = __str__ + + async def start(self): + await asyncio.gather(*self.workers) + + async def close(self): + await asyncio.gather(*[w.close() for w in self.workers]) + + +@pytest.mark.asyncio +async def test_MultiWorker(cleanup): + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 4, "memory_limit": "4 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + s = cluster.scheduler + async with Client(cluster, asynchronous=True) as client: + cluster.scale(2) + await cluster + assert len(cluster.worker_spec) == 2 + await client.wait_for_workers(4) + + cluster.scale(1) + await cluster + 
assert len(s.workers) == 2 + + cluster.scale(memory="6GB") + await cluster + assert len(cluster.worker_spec) == 2 + assert len(s.workers) == 4 + assert cluster.plan == {ws.name for ws in s.workers.values()} + + cluster.scale(cores=10) + await cluster + assert len(cluster.workers) == 3 + + adapt = cluster.adapt(minimum=0, maximum=4) + + for i in range(adapt.wait_count): # relax down to 0 workers + await adapt.adapt() + await cluster + assert not s.workers + + future = client.submit(lambda x: x + 1, 10) + await future + assert len(cluster.workers) == 1 From 1ef4f70dc7f8fe048d11104ee8daad04b82227e9 Mon Sep 17 00:00:00 2001 From: byjott Date: Fri, 30 Aug 2019 18:18:03 +0200 Subject: [PATCH 0444/1550] Fix ConnectionPool limit handling (#3005) Fixes #3001 * convert test_core to 'async def' style --- distributed/core.py | 52 ++++++++----- distributed/tests/test_core.py | 138 ++++++++++++++++++--------------- 2 files changed, 109 insertions(+), 81 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index f97d2df382a..6bda9c9e0be 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -826,7 +826,9 @@ def __init__( self.deserializers = deserializers if deserializers is not None else serializers self.connection_args = connection_args self.timeout = timeout - self.event = Event() + self._n_connecting = 0 + # Invariant: semaphore._value == limit - open - _n_connecting + self.semaphore = asyncio.Semaphore(self.limit) self.server = weakref.ref(server) if server else None self._created = weakref.WeakSet() self._instances.add(self) @@ -840,7 +842,11 @@ def open(self): return self.active + sum(map(len, self.available.values())) def __repr__(self): - return "" % (self.open, self.active) + return "" % ( + self.open, + self.active, + self._n_connecting, + ) def __call__(self, addr=None, ip=None, port=None): """ Cached rpc objects """ @@ -861,10 +867,11 @@ async def connect(self, addr, timeout=None): occupied.add(comm) return comm - while self.open >= self.limit: - self.event.clear() + if self.semaphore.locked(): self.collect() - await self.event.wait() + + self._n_connecting += 1 + await self.semaphore.acquire() try: comm = await connect( @@ -877,11 +884,12 @@ async def connect(self, addr, timeout=None): comm._pool = weakref.ref(self) self._created.add(comm) except Exception: + self.semaphore.release() raise - occupied.add(comm) + finally: + self._n_connecting -= 1 - if self.open >= self.limit: - self.event.clear() + occupied.add(comm) return comm @@ -889,30 +897,34 @@ def reuse(self, addr, comm): """ Reuse an open communication to the given address. For internal use. """ - try: - self.occupied[addr].remove(comm) - except KeyError: - pass + # if the pool is asked to re-use a comm it does not know about, ignore + # this comm: just close it. + if comm not in self.occupied[addr]: + IOLoop.current().add_callback(comm.close) else: + self.occupied[addr].remove(comm) if comm.closed(): - if self.open < self.limit: - self.event.set() + self.semaphore.release() else: self.available[addr].add(comm) + if self.semaphore.locked() and self._n_connecting > 0: + self.collect() def collect(self): """ Collect open but unused communications, to allow opening other ones. """ logger.info( - "Collecting unused comms. open: %d, active: %d", self.open, self.active + "Collecting unused comms. 
open: %d, active: %d, connecting: %d", + self.open, + self.active, + self._n_connecting, ) for addr, comms in self.available.items(): for comm in comms: IOLoop.current().add_callback(comm.close) + self.semaphore.release() comms.clear() - if self.open < self.limit: - self.event.set() def remove(self, addr): """ @@ -923,12 +935,12 @@ def remove(self, addr): comms = self.available.pop(addr) for comm in comms: IOLoop.current().add_callback(comm.close) + self.semaphore.release() if addr in self.occupied: comms = self.occupied.pop(addr) for comm in comms: IOLoop.current().add_callback(comm.close) - if self.open < self.limit: - self.event.set() + self.semaphore.release() def close(self): """ @@ -937,8 +949,10 @@ def close(self): for comms in self.available.values(): for comm in comms: comm.abort() + self.semaphore.release() for comms in self.occupied.values(): for comm in comms: + self.semaphore.release() comm.abort() for comm in self._created: diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index e41866d6741..cbde7ac240b 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -1,10 +1,10 @@ +import asyncio from contextlib import contextmanager import os import socket import threading import weakref -from tornado import gen import pytest import dask @@ -82,8 +82,7 @@ def test_server(loop): Simple Server test. """ - @gen.coroutine - def f(): + async def f(): server = Server({"ping": pingpong}) with pytest.raises(ValueError): server.port @@ -92,20 +91,20 @@ def f(): assert server.address == ("tcp://%s:8881" % get_ip()) for addr in ("127.0.0.1:8881", "tcp://127.0.0.1:8881", server.address): - comm = yield connect(addr) + comm = await connect(addr) - n = yield comm.write({"op": "ping"}) + n = await comm.write({"op": "ping"}) assert isinstance(n, int) assert 4 <= n <= 1000 - response = yield comm.read() + response = await comm.read() assert response == b"pong" - yield comm.write({"op": "ping", "close": True}) - response = yield comm.read() + await comm.write({"op": "ping", "close": True}) + response = await comm.read() assert response == b"pong" - yield comm.close() + await comm.close() server.stop() @@ -113,20 +112,19 @@ def f(): def test_server_raises_on_blocked_handlers(loop): - @gen.coroutine - def f(): + async def f(): server = Server({"ping": pingpong}, blocked_handlers=["ping"]) server.listen(8881) - comm = yield connect(server.address) - yield comm.write({"op": "ping"}) - msg = yield comm.read() + comm = await connect(server.address) + await comm.write({"op": "ping"}) + msg = await comm.read() assert "exception" in msg assert isinstance(msg["exception"], ValueError) assert "'ping' handler has been explicitly disallowed" in repr(msg["exception"]) - yield comm.close() + await comm.close() server.stop() res = loop.run_sync(f) @@ -251,21 +249,20 @@ def listen_on(cls, *args, **kwargs): yield assert_cannot_connect(inproc_addr2) -@gen.coroutine -def check_rpc(listen_addr, rpc_addr=None, listen_args=None, connection_args=None): +async def check_rpc(listen_addr, rpc_addr=None, listen_args=None, connection_args=None): server = Server({"ping": pingpong}) server.listen(listen_addr, listen_args=listen_args) if rpc_addr is None: rpc_addr = server.address with rpc(rpc_addr, connection_args=connection_args) as remote: - response = yield remote.ping() + response = await remote.ping() assert response == b"pong" assert remote.comms - response = yield remote.ping(close=True) + response = await remote.ping(close=True) assert response == b"pong" - response = 
yield remote.ping() + response = await remote.ping() assert response == b"pong" assert not remote.comms @@ -311,8 +308,7 @@ def test_rpc_inputs(): r.close_rpc() -@gen.coroutine -def check_rpc_message_lifetime(*listen_args): +async def check_rpc_message_lifetime(*listen_args): # Issue #956: rpc arguments and result shouldn't be kept alive longer # than necessary server = Server({"echo": echo_serialize}) @@ -324,15 +320,15 @@ def check_rpc_message_lifetime(*listen_args): del obj start = time() while CountedObject.n_instances != 0: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 with rpc(server.address) as remote: obj = CountedObject() - res = yield remote.echo(x=to_serialize(obj)) + res = await remote.echo(x=to_serialize(obj)) assert isinstance(res["result"], CountedObject) # Make sure resource cleanup code in coroutines runs - yield gen.sleep(0.05) + await asyncio.sleep(0.05) w1 = weakref.ref(obj) w2 = weakref.ref(res["result"]) @@ -361,18 +357,17 @@ def test_rpc_message_lifetime_inproc(): yield check_rpc_message_lifetime("inproc://") -@gen.coroutine -def check_rpc_with_many_connections(listen_arg): - @gen.coroutine - def g(): +async def check_rpc_with_many_connections(listen_arg): + async def g(): for i in range(10): - yield remote.ping() + await remote.ping() server = Server({"ping": pingpong}) server.listen(listen_arg) with rpc(server.address) as remote: - yield [g() for i in range(10)] + for i in range(10): + await g() server.stop() @@ -390,19 +385,18 @@ def test_rpc_with_many_connections_inproc(): yield check_rpc_with_many_connections("inproc://") -@gen.coroutine -def check_large_packets(listen_arg): +async def check_large_packets(listen_arg): """ tornado has a 100MB cap by default """ server = Server({"echo": echo}) server.listen(listen_arg) data = b"0" * int(200e6) # slightly more than 100MB conn = rpc(server.address) - result = yield conn.echo(x=data) + result = await conn.echo(x=data) assert result == data d = {"x": data} - result = yield conn.echo(x=d) + result = await conn.echo(x=d) assert result == d conn.close_comms() @@ -420,14 +414,13 @@ def test_large_packets_inproc(): yield check_large_packets("inproc://") -@gen.coroutine -def check_identity(listen_arg): +async def check_identity(listen_arg): server = Server({}) server.listen(listen_arg) with rpc(server.address) as remote: - a = yield remote.identity() - b = yield remote.identity() + a = await remote.identity() + b = await remote.identity() assert a["type"] == "Server" assert a["id"] == b["id"] @@ -489,7 +482,7 @@ def test_errors(): @gen_test() def test_connect_raises(): - with pytest.raises((gen.TimeoutError, IOError)): + with pytest.raises(IOError): yield connect("127.0.0.1:58259", timeout=0.01) @@ -519,10 +512,9 @@ def test_coerce_to_address(): @gen_test() def test_connection_pool(): - @gen.coroutine - def ping(comm, delay=0.1): - yield gen.sleep(delay) - raise gen.Return("pong") + async def ping(comm, delay=0.1): + await asyncio.sleep(delay) + return "pong" servers = [Server({"ping": ping}) for i in range(10)] for server in servers: @@ -553,12 +545,35 @@ def ping(comm, delay=0.1): rpc.collect() start = time() while any(rpc.available.values()): - yield gen.sleep(0.01) + yield asyncio.sleep(0.01) assert time() < start + 2 rpc.close() +@gen_test() +def test_connection_pool_respects_limit(): + + limit = 5 + + async def ping(comm, delay=0.01): + await asyncio.sleep(delay) + return "pong" + + async def do_ping(pool, port): + assert pool.open <= limit + await pool(ip="127.0.0.1", port=port).ping() 
+ assert pool.open <= limit + + servers = [Server({"ping": ping}) for i in range(10)] + for server in servers: + server.listen(0) + + pool = ConnectionPool(limit=limit) + + yield [do_ping(pool, s.port) for s in servers] + + @gen_test() def test_connection_pool_tls(): """ @@ -568,10 +583,9 @@ def test_connection_pool_tls(): connection_args = sec.get_connection_args("client") listen_args = sec.get_listen_args("scheduler") - @gen.coroutine - def ping(comm, delay=0.01): - yield gen.sleep(delay) - raise gen.Return("pong") + async def ping(comm, delay=0.01): + await asyncio.sleep(delay) + return "pong" servers = [Server({"ping": ping}) for i in range(10)] for server in servers: @@ -589,10 +603,9 @@ def ping(comm, delay=0.01): @gen_test() def test_connection_pool_remove(): - @gen.coroutine - def ping(comm, delay=0.01): - yield gen.sleep(delay) - raise gen.Return("pong") + async def ping(comm, delay=0.01): + await asyncio.sleep(delay) + return "pong" servers = [Server({"ping": ping}) for i in range(5)] for server in servers: @@ -617,6 +630,9 @@ def ping(comm, delay=0.01): assert rpc.open == 4 rpc.collect() + + # this pattern of calls (esp. `reuse` after `remove`) + # can happen in case of worker failures: comm = yield rpc.connect(serv.address) rpc.remove(serv.address) rpc.reuse(serv.address, comm) @@ -642,7 +658,7 @@ def test_counters(): @gen_cluster() def test_ticks(s, a, b): pytest.importorskip("crick") - yield gen.sleep(0.1) + yield asyncio.sleep(0.1) c = s.digests["tick-duration"] assert c.size() assert 0.01 < c.components[0].quantile(0.5) < 0.5 @@ -657,7 +673,7 @@ def test_tick_logging(s, a, b): core.tick_maximum_delay = 0.001 try: with captured_logger("distributed.core") as sio: - yield gen.sleep(0.1) + yield asyncio.sleep(0.1) text = sio.getvalue() assert "unresponsive" in text @@ -671,14 +687,13 @@ def test_tick_logging(s, a, b): def test_compression(compression, serialize, loop): with dask.config.set(compression=compression): - @gen.coroutine - def f(): + async def f(): server = Server({"echo": serialize}) server.listen("tcp://") with rpc(server.address) as r: data = b"1" * 1000000 - result = yield r.echo(x=to_serialize(data)) + result = await r.echo(x=to_serialize(data)) assert result == {"result": data} server.stop() @@ -687,17 +702,16 @@ def f(): def test_rpc_serialization(loop): - @gen.coroutine - def f(): + async def f(): server = Server({"echo": echo_serialize}) server.listen("tcp://") with rpc(server.address, serializers=["msgpack"]) as r: with pytest.raises(TypeError): - yield r.echo(x=to_serialize(inc)) + await r.echo(x=to_serialize(inc)) with rpc(server.address, serializers=["msgpack", "pickle"]) as r: - result = yield r.echo(x=to_serialize(inc)) + result = await r.echo(x=to_serialize(inc)) assert result == {"result": inc} server.stop() From 2c148521abef4d0d5d46bd8f6e50da6096eb359b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 3 Sep 2019 08:12:41 -0700 Subject: [PATCH 0445/1550] Avoid collision when using os.environ in dashboard_link (#3021) Fixes #3016 --- distributed/tests/test_utils.py | 16 ++++++++++++++++ distributed/utils.py | 4 +++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index bf2d8456681..ff733a1ad8c 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -2,6 +2,7 @@ import datetime from functools import partial import io +import os import queue import socket import sys @@ -40,6 +41,7 @@ parse_bytes, parse_timedelta, warn_on_duration, + 
format_dashboard_link, ) from distributed.utils_test import loop, loop_in_thread # noqa: F401 from distributed.utils_test import div, has_ipv6, inc, throws, gen_test, captured_logger @@ -582,3 +584,17 @@ def test_is_valid_xml(): assert is_valid_xml("foo") with pytest.raises(Exception): assert is_valid_xml("foo") + + +def test_format_dashboard_link(): + with dask.config.set({"distributed.dashboard.link": "foo"}): + assert format_dashboard_link("host", 1234) == "foo" + + assert "host" in format_dashboard_link("host", 1234) + assert "1234" in format_dashboard_link("host", 1234) + + try: + os.environ["host"] = "hello" + assert "hello" not in format_dashboard_link("host", 1234) + finally: + del os.environ["host"] diff --git a/distributed/utils.py b/distributed/utils.py index 9f14d58de0b..b7f6631ce93 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1434,7 +1434,9 @@ def format_dashboard_link(host, port): scheme = "https" else: scheme = "http" - return template.format(scheme=scheme, host=host, port=port, **os.environ) + return template.format( + **toolz.merge(os.environ, dict(scheme=scheme, host=host, port=port)) + ) def is_coroutine_function(f): From 04a1bb0140d181d36eb27a284cfad2e289b8cda0 Mon Sep 17 00:00:00 2001 From: Abael He Date: Wed, 4 Sep 2019 00:05:45 +0800 Subject: [PATCH 0446/1550] Add support for zstandard compression to comms (#2970) Adds support for zstandard compression to dask comms. --- .../setup_conda_environment.cmd | 1 + continuous_integration/travis/install.sh | 2 +- distributed/distributed.yaml | 4 +++ distributed/protocol/compression.py | 20 +++++++++++ distributed/protocol/tests/test_protocol.py | 36 ++++++++----------- 5 files changed, 40 insertions(+), 23 deletions(-) diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index 6fff1a5ca6a..d09846faedd 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -18,6 +18,7 @@ call deactivate @rem Create test environment @rem (note: no cytoolz as it seems to prevent faulthandler tracebacks on crash) %CONDA% create -n %CONDA_ENV% -q -y ^ + zstandard ^ bokeh ^ click ^ cloudpickle ^ diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 82993032e0b..ead4dbc3002 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -56,7 +56,7 @@ conda install -q \ # For low-level profiler, install libunwind and stacktrace from conda-forge # For stacktrace we use --no-deps to avoid upgrade of python -conda install -c defaults -c conda-forge libunwind +conda install -c defaults -c conda-forge libunwind zstandard conda install --no-deps -c defaults -c numba -c conda-forge stacktrace pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 9ad3e365e78..a0e801f26ac 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -71,6 +71,10 @@ distributed: socket-backlog: 2048 recent-messages-log-length: 0 # number of messages to keep for debugging + zstd: + level: 3 # Compression level, between 1 and 22. + threads: 0 # Threads to use. 0 for single-threaded, -1 to infer from cpu count. 
+ timeouts: connect: 10s # time before connecting fails tcp: 30s # time before calling an unresponsive connection dead diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 5035b465cee..5e81cdbaf1f 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -93,6 +93,26 @@ def _fixed_lz4_decompress(data): } default_compression = "lz4" + +with ignoring(ImportError): + import zstandard + + zstd_compressor = zstandard.ZstdCompressor( + level=dask.config.get("distributed.comm.zstd.level"), + threads=dask.config.get("distributed.comm.zstd.threads"), + ) + + zstd_decompressor = zstandard.ZstdDecompressor() + + def zstd_compress(data): + return zstd_compressor.compress(data) + + def zstd_decompress(data): + return zstd_decompressor.decompress(data) + + compressions["zstd"] = {"compress": zstd_compress, "decompress": zstd_decompress} + + with ignoring(ImportError): import blosc diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 395c1ca7b97..3dd11ecc4d1 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -57,34 +57,26 @@ def test_small_and_big(): # assert loads([big_header, big]) == {'y': d['y']} -def test_maybe_compress(): - pass +@pytest.mark.parametrize( + "lib,compression", + [(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")], +) +def test_maybe_compress(lib, compression): + if lib: + pytest.importorskip(lib) try_converters = [bytes, memoryview] - try_compressions = ["zlib", "lz4"] - payload = b"123" - - with dask.config.set({"distributed.comm.compression": None}): + with dask.config.set({"distributed.comm.compression": compression}): for f in try_converters: + payload = b"123" assert maybe_compress(f(payload)) == (None, payload) - for compression in try_compressions: - try: - __import__(compression) - except ImportError: - continue - - with dask.config.set({"distributed.comm.compression": compression}): - for f in try_converters: - payload = b"123" - assert maybe_compress(f(payload)) == (None, payload) - - payload = b"0" * 10000 - rc, rd = maybe_compress(f(payload)) - # For some reason compressing memoryviews can force blosc... - assert rc in (compression, "blosc") - assert compressions[rc]["decompress"](rd) == payload + payload = b"0" * 10000 + rc, rd = maybe_compress(f(payload)) + # For some reason compressing memoryviews can force blosc... + assert rc in (compression, "blosc") + assert compressions[rc]["decompress"](rd) == payload def test_maybe_compress_sample(): From 7a1a369270557b912cd6fde3f96cce8196672f23 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 3 Sep 2019 14:02:16 -0500 Subject: [PATCH 0447/1550] Add fallback html repr for Cluster (#3023) This PR fixes two things: - Cluster objects don't error when repr'd in notebooks without ipywidgets installed. - If ipywidgets isn't installed, a fallback HTML repr is used instead of the default `repr` string. A test is added to test both of these. 
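
A rough sketch of the fallback behaviour (exact markup aside), in an
environment without ipywidgets:

    from distributed import LocalCluster

    cluster = LocalCluster(n_workers=1, processes=False, dashboard_address=None)
    html = cluster._repr_html_()   # a plain HTML snippet rather than an error
    assert "LocalCluster" in html
    cluster.close()
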
--- distributed/deploy/cluster.py | 40 +++++++++++++++++++++++--- distributed/deploy/tests/test_local.py | 25 ++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 5e86c39ce8e..393be849b88 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -246,7 +246,11 @@ def _widget(self): except AttributeError: pass - from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion + try: + from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion + except ImportError: + self._cached_widget = None + return None layout = Layout(width="150px") @@ -258,7 +262,7 @@ def _widget(self): else: link = "" - title = "

<h2>%s</h2>" % type(self).__name__
+        title = "<h2>%s</h2>
          " % self._cluster_class_name title = HTML(title) dashboard = HTML(link) @@ -311,8 +315,32 @@ def update(): return box + def _repr_html_(self): + if self.dashboard_link: + dashboard = "
<a href='{0}' target='_blank'>{0}</a>".format(
+                self.dashboard_link
+            )
+        else:
+            dashboard = "Not Available"
+        return (
+            "<div style='background-color: #f2f2f2; display: inline-block; "
+            "padding: 10px; border: 1px solid #999999;'>\n"
+            "  <h3>{cls}</h3>\n"
+            "  <ul>\n"
+            "    <li><b>Dashboard: </b>{dashboard}\n"
+            "  </ul>\n"
+            "</div>
          \n" + ).format(cls=self._cluster_class_name, dashboard=dashboard) + def _ipython_display_(self, **kwargs): - return self._widget()._ipython_display_(**kwargs) + widget = self._widget() + if widget is not None: + return widget._ipython_display_(**kwargs) + else: + from IPython.display import display + + data = {"text/plain": repr(self), "text/html": self._repr_html_()} + display(data, raw=True) async def __aenter__(self): await self @@ -325,9 +353,13 @@ async def __aexit__(self, typ, value, traceback): def scheduler_address(self): return self.scheduler_comm.address + @property + def _cluster_class_name(self): + return getattr(self, "_name", type(self).__name__) + def __repr__(self): text = "%s(%r, workers=%d, threads=%d" % ( - getattr(self, "_name", type(self).__name__), + self._cluster_class_name, self.scheduler_address, len(self.workers), sum(w["nthreads"] for w in self.scheduler_info["workers"].values()), diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 59c75d545d8..2ad46035425 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -530,6 +530,31 @@ def test_ipywidgets(loop): assert isinstance(box, ipywidgets.Widget) +def test_no_ipywidgets(loop, monkeypatch): + from unittest.mock import MagicMock + + mock_display = MagicMock() + + monkeypatch.setitem(sys.modules, "ipywidgets", None) + monkeypatch.setitem(sys.modules, "IPython.display", mock_display) + + with LocalCluster( + n_workers=0, + scheduler_port=0, + silence_logs=False, + loop=loop, + dashboard_address=False, + processes=False, + ) as cluster: + cluster._ipython_display_() + args, kwargs = mock_display.display.call_args + res = args[0] + assert kwargs == {"raw": True} + assert isinstance(res, dict) + assert "text/plain" in res + assert "text/html" in res + + def test_scale(loop): """ Directly calling scale both up and down works as expected """ with LocalCluster( From 8d07bf162b1f6434fc982f36360c5cedbac369a6 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Tue, 3 Sep 2019 20:47:44 -0400 Subject: [PATCH 0448/1550] Rely on cudf codebase for cudf serialization (#2998) * reinstate cudf serialization dispatch in distributed to support older versions of cudf -- there may be a distributed release *before* the next release of cudf --- distributed/protocol/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index e30786ab4a5..ef8b5564bbb 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -1,4 +1,5 @@ from functools import partial +from distutils.version import LooseVersion from .compression import compressions, default_compression from .core import dumps, loads, maybe_compress, decompress, msgpack @@ -82,4 +83,9 @@ def _register_numba(): @cuda_serialize.register_lazy("cudf") @cuda_deserialize.register_lazy("cudf") def _register_cudf(): - from . import cudf + import cudf + + if LooseVersion(cudf.__version__) > "0.9": + from cudf.comm import serialize + else: + from . import cudf From adf247c15c32831b10c76a163cf0f769c4339483 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 6 Sep 2019 13:00:51 -0700 Subject: [PATCH 0449/1550] Set the x_range limit of the Meory utilization plot to memory-limit (#3034) This should give a better sense for how close we are to filling up memory, which seems to be the biggest use of the memory plot. 
--- distributed/dashboard/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 3dd108f5775..7f172c879d1 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -483,6 +483,7 @@ def update(self): self.nbytes_figure.title.text = "Bytes stored: " + format_bytes( sum(nbytes) ) + self.nbytes_figure.x_range.end = max_limit update(self.source, result) From 98822d4ab13081a49a3643f216dcf26009ff9496 Mon Sep 17 00:00:00 2001 From: Mikhail Akimov Date: Sat, 7 Sep 2019 00:30:46 +0300 Subject: [PATCH 0450/1550] Replace print statement in Queue.__init__ with debug message (#3035) --- distributed/queues.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/queues.py b/distributed/queues.py index 7174c48a63c..1d0c2c0bdd3 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -47,7 +47,7 @@ def __init__(self, scheduler): self.scheduler.extensions["queues"] = self def create(self, stream=None, name=None, client=None, maxsize=0): - print("name", name) + logger.debug("Queue name: {}".format(name)) if name not in self.queues: self.queues[name] = tornado.queues.Queue(maxsize=maxsize) self.client_refcount[name] = 1 From 8e544142602e05dab57331cc59f86c43550459db Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 7 Sep 2019 08:37:01 -0700 Subject: [PATCH 0451/1550] Clean up test_local.py::test_defaults (#3017) --- distributed/deploy/tests/test_local.py | 112 +++++++++++++++---------- 1 file changed, 70 insertions(+), 42 deletions(-) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 2ad46035425..5459574cccf 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -225,79 +225,112 @@ def test_Client_twice(loop): assert c.cluster.scheduler.port != f.cluster.scheduler.port -@pytest.mark.skipif("sys.version_info[0] == 2", reason="fork issues") -def test_defaults(): - _nthreads = multiprocessing.cpu_count() - - with LocalCluster( - scheduler_port=0, silence_logs=False, dashboard_address=None +@pytest.mark.asyncio +async def test_defaults(cleanup): + async with LocalCluster( + scheduler_port=0, silence_logs=False, dashboard_address=None, asynchronous=True ) as c: - assert sum(w.nthreads for w in c.workers.values()) == _nthreads + assert ( + sum(w.nthreads for w in c.workers.values()) == multiprocessing.cpu_count() + ) assert all(isinstance(w, Nanny) for w in c.workers.values()) - with LocalCluster( - processes=False, scheduler_port=0, silence_logs=False, dashboard_address=None + +@pytest.mark.asyncio +async def test_defaults_2(cleanup): + async with LocalCluster( + processes=False, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + asynchronous=True, ) as c: - assert sum(w.nthreads for w in c.workers.values()) == _nthreads + assert ( + sum(w.nthreads for w in c.workers.values()) == multiprocessing.cpu_count() + ) assert all(isinstance(w, Worker) for w in c.workers.values()) assert len(c.workers) == 1 - with LocalCluster( - n_workers=2, scheduler_port=0, silence_logs=False, dashboard_address=None + +@pytest.mark.asyncio +async def test_defaults_3(cleanup): + async with LocalCluster( + n_workers=2, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + asynchronous=True, ) as c: - if _nthreads % 2 == 0: - expected_total_threads = max(2, _nthreads) + if multiprocessing.cpu_count() % 2 == 0: + expected_total_threads = max(2, 
multiprocessing.cpu_count()) else: # n_workers not a divisor of _nthreads => threads are overcommitted - expected_total_threads = max(2, _nthreads + 1) + expected_total_threads = max(2, multiprocessing.cpu_count() + 1) assert sum(w.nthreads for w in c.workers.values()) == expected_total_threads - with LocalCluster( - threads_per_worker=_nthreads * 2, + +@pytest.mark.asyncio +async def test_defaults_4(cleanup): + async with LocalCluster( + threads_per_worker=multiprocessing.cpu_count() * 2, scheduler_port=0, silence_logs=False, dashboard_address=None, + asynchronous=True, ) as c: assert len(c.workers) == 1 - with LocalCluster( - n_workers=_nthreads * 2, + +@pytest.mark.asyncio +async def test_defaults_5(cleanup): + async with LocalCluster( + n_workers=multiprocessing.cpu_count() * 2, scheduler_port=0, silence_logs=False, dashboard_address=None, + asynchronous=True, ) as c: assert all(w.nthreads == 1 for w in c.workers.values()) - with LocalCluster( + + +@pytest.mark.asyncio +async def test_defaults_6(cleanup): + async with LocalCluster( threads_per_worker=2, n_workers=3, scheduler_port=0, silence_logs=False, dashboard_address=None, + asynchronous=True, ) as c: assert len(c.workers) == 3 assert all(w.nthreads == 2 for w in c.workers.values()) -def test_worker_params(): - with LocalCluster( +@pytest.mark.asyncio +async def test_worker_params(cleanup): + async with LocalCluster( processes=False, n_workers=2, scheduler_port=0, silence_logs=False, dashboard_address=None, memory_limit=500, + asynchronous=True, ) as c: assert [w.memory_limit for w in c.workers.values()] == [500] * 2 -def test_memory_limit_none(): - with LocalCluster( +@pytest.mark.asyncio +async def test_memory_limit_none(cleanup): + async with LocalCluster( n_workers=2, scheduler_port=0, silence_logs=False, processes=False, dashboard_address=None, memory_limit=None, + asynchronous=True, ) as c: w = c.workers[0] assert type(w.data) is dict @@ -364,34 +397,29 @@ def test_blocks_until_full(loop): assert len(c.nthreads()) > 0 -@gen_test() -def test_scale_up_and_down(): - loop = IOLoop.current() - cluster = yield LocalCluster( +@pytest.mark.asyncio +async def test_scale_up_and_down(): + async with LocalCluster( 0, scheduler_port=0, processes=False, silence_logs=False, dashboard_address=None, - loop=loop, asynchronous=True, - ) - c = yield Client(cluster, asynchronous=True) - - assert not cluster.workers + ) as cluster: + async with Client(cluster, asynchronous=True) as c: - cluster.scale(2) - yield cluster - assert len(cluster.workers) == 2 - assert len(cluster.scheduler.nthreads) == 2 + assert not cluster.workers - cluster.scale(1) - yield cluster + cluster.scale(2) + await cluster + assert len(cluster.workers) == 2 + assert len(cluster.scheduler.nthreads) == 2 - assert len(cluster.workers) == 1 + cluster.scale(1) + await cluster - yield c.close() - yield cluster.close() + assert len(cluster.workers) == 1 def test_silent_startup(): From 5735238e6fcd30a64cc16a3c60ba16b045966beb Mon Sep 17 00:00:00 2001 From: Guillaume Eynard-Bontemps Date: Sat, 7 Sep 2019 17:37:41 +0200 Subject: [PATCH 0452/1550] Remove lost workers from SpecCluster.workers (#2990) * Detect workers end in SpecCluster * Make worker deletion checks robust to missing worker * Convert integer names in dask-worker CLI * lint * remove logged warning (this happens under normal operation) * Make sure that we close workers before we delete them * Add a moderate delay before closing worker job --- distributed/cli/dask_worker.py | 6 +++- distributed/cli/tests/test_dask_worker.py 
| 15 ++++++-- distributed/deploy/cluster.py | 22 +++++++----- distributed/deploy/spec.py | 33 +++++++++++++++--- distributed/deploy/tests/test_spec_cluster.py | 34 ++++++++++++++++++- distributed/distributed.yaml | 3 ++ 6 files changed, 96 insertions(+), 17 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 790b8b3a9ab..952ba90984a 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -8,6 +8,7 @@ import click import dask +from dask.utils import ignoring from distributed import Nanny, Worker from distributed.security import Security from distributed.cli.utils import check_python_3, install_signal_handlers @@ -354,6 +355,9 @@ def del_pid_file(): "dask-worker SCHEDULER_ADDRESS:8786" ) + with ignoring(TypeError, ValueError): + name = int(name) + nannies = [ t( scheduler, @@ -367,7 +371,7 @@ def del_pid_file(): port=port, dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, - name=name if nprocs == 1 or not name else name + "-" + str(i), + name=name if nprocs == 1 or not name else str(name) + "-" + str(i), **kwargs ) for i in range(nprocs) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 2dd74737b16..0e871cf1b60 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -1,3 +1,4 @@ +import asyncio import pytest from click.testing import CliRunner @@ -9,11 +10,11 @@ from time import sleep import distributed.cli.dask_worker -from distributed import Client +from distributed import Client, Scheduler from distributed.metrics import time from distributed.utils import sync, tmpfile from distributed.utils_test import popen, terminate_process, wait_for_port -from distributed.utils_test import loop # noqa: F401 +from distributed.utils_test import loop, cleanup # noqa: F401 def test_nanny_worker_ports(loop): @@ -330,3 +331,13 @@ def test_bokeh_deprecation(): except ValueError: # didn't pass scheduler pass + + +@pytest.mark.asyncio +async def test_integer_names(cleanup): + async with Scheduler(port=0) as s: + with popen(["dask-worker", s.address, "--name", "123"]) as worker: + while not s.workers: + await asyncio.sleep(0.01) + [ws] = s.workers.values() + assert ws.name == 123 diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 393be849b88..32ceedc47bd 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -95,18 +95,22 @@ async def _watch_worker_status(self, comm): except OSError: break - for op, msg in msgs: - if op == "add": - workers = msg.pop("workers") - self.scheduler_info["workers"].update(workers) - self.scheduler_info.update(msg) - elif op == "remove": - del self.scheduler_info["workers"][msg] - else: - raise ValueError("Invalid op", op, msg) + with log_errors(): + for op, msg in msgs: + self._update_worker_status(op, msg) await comm.close() + def _update_worker_status(self, op, msg): + if op == "add": + workers = msg.pop("workers") + self.scheduler_info["workers"].update(workers) + self.scheduler_info.update(msg) + elif op == "remove": + del self.scheduler_info["workers"][msg] + else: + raise ValueError("Invalid op", op, msg) + def adapt(self, Adaptive=Adaptive, **kwargs) -> Adaptive: """ Turn on adaptivity diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index d897dc6b7df..487e5192e17 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -1,18 +1,23 @@ 
import asyncio import atexit import copy +import logging import math import weakref +import dask from tornado import gen from .cluster import Cluster from ..core import rpc, CommClosedError -from ..utils import LoopRunner, silence_logging, ignoring, parse_bytes +from ..utils import LoopRunner, silence_logging, ignoring, parse_bytes, parse_timedelta from ..scheduler import Scheduler from ..security import Security +logger = logging.getLogger(__name__) + + class ProcessInterface: """ An interface for Scheduler and Worker processes for use in SpecCluster @@ -201,6 +206,7 @@ def __init__( self._i = 0 self.security = security or Security() self.scheduler_comm = None + self._futures = set() if silence_logs: self._old_logging_level = silence_logging(level=silence_logs) @@ -267,13 +273,14 @@ async def _correct_state_internal(self): if to_close: if self.scheduler.status == "running": await self.scheduler_comm.retire_workers(workers=list(to_close)) - tasks = [self.workers[w].close() for w in to_close] + tasks = [self.workers[w].close() for w in to_close if w in self.workers] await asyncio.wait(tasks) for task in tasks: # for tornado gen.coroutine support with ignoring(RuntimeError): await task for name in to_close: - del self.workers[name] + if name in self.workers: + del self.workers[name] to_open = set(self.worker_spec) - set(self.workers) workers = [] @@ -293,6 +300,22 @@ async def _correct_state_internal(self): await w # for tornado gen.coroutine support self.workers.update(dict(zip(to_open, workers))) + def _update_worker_status(self, op, msg): + if op == "remove": + name = self.scheduler_info["workers"][msg]["name"] + + def f(): + if name in self.workers and msg not in self.scheduler_info: + self._futures.add(asyncio.ensure_future(self.workers[name].close())) + del self.workers[name] + + delay = parse_timedelta( + dask.config.get("distributed.deploy.lost-worker-timeout") + ) + + asyncio.get_event_loop().call_later(delay, f) + super()._update_worker_status(op, msg) + def __await__(self): async def _(): if self.status == "created": @@ -314,13 +337,15 @@ async def _close(self): self.scale(0) await self._correct_state() + for future in self._futures: + await future async with self._lock: with ignoring(CommClosedError): await self.scheduler_comm.close(close_workers=True) await self.scheduler.close() for w in self._created: - assert w.status == "closed" + assert w.status == "closed", w.status if hasattr(self, "_old_logging_level"): silence_logging(self._old_logging_level) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index efc231ca030..679adc0fd60 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,8 +1,9 @@ import asyncio -from time import time +import dask from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny from distributed.deploy.spec import close_clusters, ProcessInterface +from distributed.metrics import time from distributed.utils_test import loop, cleanup # noqa: F401 from distributed.utils import is_valid_xml import toolz @@ -124,6 +125,37 @@ async def test_scale(cleanup): assert len(cluster.workers) == 1 +@pytest.mark.asyncio +async def test_unexpected_closed_worker(cleanup): + worker = {"cls": Worker, "options": {"nthreads": 1}} + with dask.config.set({"distributed.deploy.lost-worker-timeout": "10ms"}): + async with SpecCluster( + asynchronous=True, scheduler=scheduler, worker=worker + ) as cluster: + assert not cluster.workers + assert not 
cluster.worker_spec + + # Scale up + cluster.scale(2) + assert not cluster.workers + assert cluster.worker_spec + + await cluster + assert len(cluster.workers) == 2 + + # Close one + await list(cluster.workers.values())[0].close() + start = time() + while len(cluster.workers) > 1: # wait for messages to flow around + await asyncio.sleep(0.01) + assert time() < start + 2 + assert len(cluster.workers) == 1 + assert len(cluster.worker_spec) == 2 + + await cluster + assert len(cluster.workers) == 2 + + @pytest.mark.asyncio async def test_broken_worker(): with pytest.raises(Exception) as info: diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index a0e801f26ac..f277eb2f90b 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -65,6 +65,9 @@ distributed: client: heartbeat: 5s # time between client heartbeats + deploy: + lost-worker-timeout: 15s # Interval after which to hard-close a lost worker job + comm: compression: auto default-scheme: tcp From f0ccd366124f45da2bc86802d50684464b14fcb8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 8 Sep 2019 15:59:30 -0700 Subject: [PATCH 0453/1550] Support --name 0 and --nprocs keywords in dask-worker cli (#3037) Previously a test for `not name` would incorrectly pass if the user provided `--name 0` --- distributed/cli/dask_worker.py | 4 +++- distributed/cli/tests/test_dask_worker.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 952ba90984a..badecc4dad8 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -371,7 +371,9 @@ def del_pid_file(): port=port, dashboard_address=dashboard_address if dashboard else None, service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, - name=name if nprocs == 1 or not name else str(name) + "-" + str(i), + name=name + if nprocs == 1 or name is None or name == "" + else str(name) + "-" + str(i), **kwargs ) for i in range(nprocs) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 0e871cf1b60..01327d64291 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -181,7 +181,7 @@ def test_nprocs_requires_nanny(loop): def test_nprocs_expands_name(loop): with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( - ["dask-worker", "127.0.0.1:8786", "--nprocs", "2", "--name", "foo"] + ["dask-worker", "127.0.0.1:8786", "--nprocs", "2", "--name", "0"] ) as worker: with popen(["dask-worker", "127.0.0.1:8786", "--nprocs", "2"]) as worker: with Client("tcp://127.0.0.1:8786", loop=loop) as c: @@ -192,7 +192,7 @@ def test_nprocs_expands_name(loop): info = c.scheduler_info() names = [d["name"] for d in info["workers"].values()] - foos = [n for n in names if n.startswith("foo")] + foos = [n for n in names if n.startswith("0-")] assert len(foos) == 2 assert len(set(names)) == 4 From fb733478bd30d71f0889119ec3859993dfa0315f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 8 Sep 2019 17:42:00 -0700 Subject: [PATCH 0454/1550] Redirect configuration doc page (#3038) --- docs/source/conf.py | 3 +- docs/source/configuration.rst | 179 ---------------------------------- 2 files changed, 2 insertions(+), 180 deletions(-) delete mode 100644 docs/source/configuration.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index bb3361851b8..9bd0ce6867e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -391,7 +391,8 @@ # old html, new 
html ("joblib.html", "https://ml.dask.org/joblib.html"), ("setup.html", "https://docs.dask.org/en/latest/setup.html"), - ("ec2.html", "https://dask.pydata.org/en/latest/setup/cloud.html"), + ("ec2.html", "https://docs.dask.org/en/latest/setup/cloud.html"), + ("configuration.html", "https://docs.dask.org/en/latest/configuration.html"), ] diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst deleted file mode 100644 index 8967255f526..00000000000 --- a/docs/source/configuration.rst +++ /dev/null @@ -1,179 +0,0 @@ -.. _configuration: - -============= -Configuration -============= - -As with any distributed computation system, taking full advantage of -Dask distributed sometimes requires configuration. Some options can be -passed as :ref:`API ` parameters and/or command line options to the -various Dask executables. However, some options can also be entered in -the Dask configuration file. - - -User-wide configuration -======================= - -Dask accepts some configuration options in a configuration file, which by -default is a ``.dask/config.yaml`` file located in your home directory. -The file path can be overriden using the ``DASK_CONFIG`` environment variable. -In order to parse this configuration file, the ``pyyaml`` module needs to be -installed. If the ``pyyaml`` module is not installed, the configuration file -is ignored. - -The file is written in the YAML format, which allows for a human-readable -hierarchical key-value configuration. All keys in the configuration file -are optional, though Dask will create a default configuration file for you -on its first launch. - -Here is a synopsis of the configuration file: - -.. code-block:: yaml - - logging: - distributed: info - distributed.client: warning - bokeh: critical - - # Scheduler options - bandwidth: 100000000 # 100 MB/s estimated worker-worker bandwidth - allowed-failures: 3 # number of retries before a task is considered bad - pdb-on-err: False # enter debug mode on scheduling error - transition-log-length: 100000 - - # Worker options - multiprocessing-method: forkserver - - # Communication options - compression: auto - tcp-timeout: 30 # seconds delay before calling an unresponsive connection dead - default-scheme: tcp - require-encryption: False # whether to require encryption on non-local comms - tls: - ca-file: myca.pem - scheduler: - cert: mycert.pem - key: mykey.pem - worker: - cert: mycert.pem - key: mykey.pem - client: - cert: mycert.pem - key: mykey.pem - #ciphers: - #ECDHE-ECDSA-AES128-GCM-SHA256 - - # Bokeh web dashboard - bokeh-export-tool: False - - -We will review some of those options hereafter. - - -Communication options ---------------------- - -``compression`` -""""""""""""""" - -This key configures the desired compression scheme when transferring data -over the network. The default value, "auto", applies heuristics to try and -select the best compression scheme for each piece of data. - - -``default-scheme`` -"""""""""""""""""" - -The :ref:`communication ` scheme used by default. You can -override the default ("tcp") here, but it is recommended to use explicit URIs -for the various endpoints instead (for example ``tls://`` if you want to -enable :ref:`TLS ` communications). - - -``require-encryption`` -"""""""""""""""""""""" - -Whether to require that all non-local communications be encrypted. If true, -then Dask will refuse establishing any clear-text communications (for example -over TCP without TLS), forcing you to use a secure transport such as -:ref:`TLS `. 
- - -``tcp-timeout`` -""""""""""""""" - -The default "timeout" on TCP sockets. If a remote endpoint is unresponsive -(at the TCP layer, not at the distributed layer) for at least the specified -number of seconds, the communication is considered closed. This helps detect -endpoints that have been killed or have disconnected abruptly. - - -``tls`` -""""""" - -This key configures :ref:`TLS ` communications. Several sub-keys are -recognized: - -* ``ca-file`` configures the CA certificate file used to authenticate - and authorize all endpoints. -* ``ciphers`` restricts allowed ciphers on TLS communications. - -Each kind of endpoint has a dedicated endpoint sub-key: ``scheduler``, -``worker`` and ``client``. Each endpoint sub-key also supports several -sub-keys: - -* ``cert`` configures the certificate file for the endpoint. -* ``key`` configures the private key file for the endpoint. - - -Scheduler options ------------------ - -``allowed-failures`` -"""""""""""""""""""" - -The number of retries before a "suspicious" task is considered bad. -A task is considered "suspicious" if the worker died while executing it. - - -``bandwidth`` -""""""""""""" - -The estimated network bandwidth, in bytes per second, from worker to worker. -This value is used to estimate the time it takes to ship data from one node -to another, and balance tasks and data accordingly. - - -Misc options ------------- - -``logging`` -""""""""""" - -This key configures the logging settings. There are two possible formats. -The simple, recommended format configures the desired verbosity level -for each logger. It also sets default values for several loggers such -as ``distributed`` unless explicitly configured. - -A more extended format is possible following the :mod:`logging` module's -`Configuration dictionary schema `_. -To enable this extended format, there must be a ``version`` sub-key as -mandated by the schema. The extended format does not set any default values. - -.. note:: - Python's :mod:`logging` module uses a hierarchical logger tree. - For example, configuring the logging level for the ``distributed`` - logger will also affect its children such as ``distributed.scheduler``, - unless explicitly overriden. - - -``logging-file-config`` -""""""""""""""""""""""" - -As an alternative to the two logging settings formats discussed above, -you can specify a logging config file. -Its format adheres to the :mod:`logging` module's -`Configuration file format `_. - -.. note:: - The configuration options `logging-file-config` and `logging` are mutually exclusive. \ No newline at end of file From d419d41eebdc64d05336284ade8104a49d503e55 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Mon, 9 Sep 2019 22:51:21 +0200 Subject: [PATCH 0455/1550] Drop joblib shim module in distributed (#3040) This already just raises an `ImportError` and has done so for a while. Seems reasonable to just drop the module and let Python raise the `ImportError` that it would anyways. Likely everyone has migrated to this convention by now. --- distributed/joblib.py | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 distributed/joblib.py diff --git a/distributed/joblib.py b/distributed/joblib.py deleted file mode 100644 index fd81a8b078e..00000000000 --- a/distributed/joblib.py +++ /dev/null @@ -1,19 +0,0 @@ -msg = """ It is no longer necessary to `import dask_ml.joblib` or -`import distributed.joblib`. - -This functionality has moved into the core Joblib codebase. 
- -To use Joblib's Dask backend with Scikit-Learn >= 0.20.0 - - from dask.distributed import Client - client = Client() - - from sklearn.externals import joblib - - with joblib.parallel_backend('dask'): - # your scikit-learn code - -See http://ml.dask.org/joblib.html for more information.""" - - -raise ImportError(msg) From 810808e9ce4b3efeff1afc99e8fd8a9430f08a8c Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 9 Sep 2019 18:04:08 -0500 Subject: [PATCH 0456/1550] Move task deserialization to immediately before task execution (#3015) This commit moves task deserialization on workers to just before the task is marked for execution on the worker executor. This has a couple of benefits: - In the case that a task is stolen, there won't be any unnecessary deserialization - In the case of single-threaded workers, this will help reduce issues with task deserialization interfering with concurrently running tasks on the executor (ref #2965). This is because tasks are only transitioned to executing when the number of executing tasks is less than the number of executor theads https://github.com/dask/distributed/blob/cf26e1a559e2c89c4a4b14b6622111eaf0954f12/distributed/worker.py#L2361 IIUC this means, for single-threaded executors, tasks will only transition to executing when there are no currently executing tasks On the other hand, a drawback to this approach is that deserialization errors are not immediately raised. --- distributed/worker.py | 54 ++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index ca4f4121af3..06102a98405 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1,6 +1,6 @@ import asyncio import bisect -from collections import defaultdict, deque +from collections import defaultdict, deque, namedtuple from collections.abc import MutableMapping from datetime import timedelta import heapq @@ -86,6 +86,8 @@ DEFAULT_STARTUP_INFORMATION = {} +SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"]) + class Worker(ServerNode): """ Worker node in a Dask distributed cluster @@ -1330,23 +1332,9 @@ def add_task( return self.log.append((key, "new")) - try: - start = time() - self.tasks[key] = _deserialize(function, args, kwargs, task) - if actor: - self.actors[key] = None - stop = time() - - if stop - start > 0.010: - self.startstops[key].append(("deserialize", start, stop)) - except Exception as e: - logger.warning("Could not deserialize task", exc_info=True) - emsg = error_message(e) - emsg["key"] = key - emsg["op"] = "task-erred" - self.batched_stream.send(emsg) - self.log.append((key, "deserialize-error")) - return + self.tasks[key] = SerializedTask(function, args, kwargs, task) + if actor: + self.actors[key] = None self.priorities[key] = priority self.durations[key] = duration @@ -2344,6 +2332,26 @@ def meets_resource_constraints(self, key): return True + def _maybe_deserialize_task(self, key): + if not isinstance(self.tasks[key], SerializedTask): + return self.tasks[key] + try: + start = time() + function, args, kwargs = _deserialize(*self.tasks[key]) + stop = time() + + if stop - start > 0.010: + self.startstops[key].append(("deserialize", start, stop)) + return function, args, kwargs + except Exception as e: + logger.warning("Could not deserialize task", exc_info=True) + emsg = error_message(e) + emsg["key"] = key + emsg["op"] = "task-erred" + self.batched_stream.send(emsg) + self.log.append((key, "deserialize-error")) + raise + def 
ensure_computing(self): if self.paused: return @@ -2355,12 +2363,22 @@ def ensure_computing(self): continue if self.meets_resource_constraints(key): self.constrained.popleft() + try: + # Ensure task is deserialized prior to execution + self.tasks[key] = self._maybe_deserialize_task(key) + except Exception: + continue self.transition(key, "executing") else: break while self.ready and len(self.executing) < self.nthreads: _, key = heapq.heappop(self.ready) if self.task_state.get(key) in READY: + try: + # Ensure task is deserialized prior to execution + self.tasks[key] = self._maybe_deserialize_task(key) + except Exception: + continue self.transition(key, "executing") except Exception as e: logger.exception(e) From eecf25bd55b8bd6b58a3ee9c43e6e65f784ec4a4 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 10 Sep 2019 13:26:18 -0500 Subject: [PATCH 0457/1550] Use cgroups resource limits to determine default threads and memory (#3039) This adds support for detecting resource (CPU and memory) limits set using `cgroups`. This makes dask resource detection play nicer with container systems (e.g. `docker`), rather than detecting the host memory and cpus avaialable. This also centralizes all queries about the host platform to a single module (`distributed.platform`), with top-level constants defined for common usage. --- distributed/cli/dask_worker.py | 4 +- distributed/comm/tcp.py | 12 +--- distributed/core.py | 11 --- distributed/deploy/local.py | 12 ++-- distributed/deploy/tests/test_local.py | 25 +++---- distributed/nanny.py | 4 +- distributed/protocol/tests/test_numpy.py | 10 +-- distributed/protocol/tests/test_protocol.py | 13 ++-- distributed/system.py | 79 +++++++++++++++++++++ distributed/tests/test_steal.py | 6 +- distributed/tests/test_system.py | 32 +++++++++ distributed/tests/test_worker.py | 30 +++----- distributed/utils_test.py | 11 ++- distributed/worker.py | 28 ++------ 14 files changed, 168 insertions(+), 109 deletions(-) create mode 100644 distributed/system.py create mode 100644 distributed/tests/test_system.py diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index badecc4dad8..35ca11da34c 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -1,6 +1,5 @@ import atexit import logging -import multiprocessing import gc import os from sys import exit @@ -11,6 +10,7 @@ from dask.utils import ignoring from distributed import Nanny, Worker from distributed.security import Security +from distributed.system import CPU_COUNT from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.comm import get_address_host_port from distributed.preloading import validate_preload_argv @@ -316,7 +316,7 @@ def main( port = worker_port if not nthreads: - nthreads = multiprocessing.cpu_count() // nprocs + nthreads = CPU_COUNT // nprocs if pid_file: with open(pid_file, "w") as f: diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index bd76d0e6946..d0322e151d7 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -18,6 +18,7 @@ from tornado.tcpclient import TCPClient from tornado.tcpserver import TCPServer +from ..system import MEMORY_LIMIT from ..threadpoolexecutor import ThreadPoolExecutor from ..utils import ( ensure_bytes, @@ -38,16 +39,7 @@ logger = logging.getLogger(__name__) -def get_total_physical_memory(): - try: - import psutil - - return psutil.virtual_memory().total / 2 - except ImportError: - return 2e9 - - -MAX_BUFFER_SIZE = get_total_physical_memory() +MAX_BUFFER_SIZE = 
MEMORY_LIMIT / 2 def set_tcp_timeout(stream): diff --git a/distributed/core.py b/distributed/core.py index 6bda9c9e0be..5da103802d2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -47,15 +47,6 @@ class RPCClosed(IOError): logger = logging.getLogger(__name__) -def get_total_physical_memory(): - try: - import psutil - - return psutil.virtual_memory().total / 2 - except ImportError: - return 2e9 - - def raise_later(exc): def _raise(*args, **kwargs): raise exc @@ -63,8 +54,6 @@ def _raise(*args, **kwargs): return _raise -MAX_BUFFER_SIZE = get_total_physical_memory() - tick_maximum_delay = parse_timedelta( dask.config.get("distributed.admin.tick.limit"), default="ms" ) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index efe5ed03098..29c344f6719 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -1,12 +1,12 @@ import atexit import logging import math -import multiprocessing import warnings import weakref from dask.utils import factors +from .. import system from .spec import SpecCluster from ..nanny import Nanny from ..scheduler import Scheduler @@ -146,14 +146,12 @@ def __init__( n_workers, threads_per_worker = nprocesses_nthreads() else: n_workers = 1 - threads_per_worker = multiprocessing.cpu_count() + threads_per_worker = system.CPU_COUNT if n_workers is None and threads_per_worker is not None: - n_workers = max(1, multiprocessing.cpu_count() // threads_per_worker) + n_workers = max(1, system.CPU_COUNT // threads_per_worker) if n_workers and threads_per_worker is None: # Overcommit threads per worker, rather than undercommit - threads_per_worker = max( - 1, int(math.ceil(multiprocessing.cpu_count() / n_workers)) - ) + threads_per_worker = max(1, int(math.ceil(system.CPU_COUNT / n_workers))) if n_workers and "memory_limit" not in worker_kwargs: worker_kwargs["memory_limit"] = parse_memory_limit("auto", 1, n_workers) @@ -208,7 +206,7 @@ def start_worker(self, *args, **kwargs): ) -def nprocesses_nthreads(n=multiprocessing.cpu_count()): +def nprocesses_nthreads(n=system.CPU_COUNT): """ The default breakdown of processes and threads for a given number of cores diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 5459574cccf..6611bfccc38 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -1,6 +1,5 @@ from functools import partial import gc -import multiprocessing import subprocess import sys from time import sleep @@ -15,6 +14,7 @@ from distributed import Client, Worker, Nanny, get_client from distributed.deploy.local import LocalCluster, nprocesses_nthreads from distributed.metrics import time +from distributed.system import CPU_COUNT, MEMORY_LIMIT from distributed.utils_test import ( # noqa: F401 clean, cleanup, @@ -29,7 +29,6 @@ ) from distributed.utils_test import loop # noqa: F401 from distributed.utils import sync -from distributed.worker import TOTAL_MEMORY from distributed.deploy.utils_test import ClusterTest @@ -230,9 +229,7 @@ async def test_defaults(cleanup): async with LocalCluster( scheduler_port=0, silence_logs=False, dashboard_address=None, asynchronous=True ) as c: - assert ( - sum(w.nthreads for w in c.workers.values()) == multiprocessing.cpu_count() - ) + assert sum(w.nthreads for w in c.workers.values()) == CPU_COUNT assert all(isinstance(w, Nanny) for w in c.workers.values()) @@ -245,9 +242,7 @@ async def test_defaults_2(cleanup): dashboard_address=None, asynchronous=True, ) as c: - assert ( - sum(w.nthreads for w in 
c.workers.values()) == multiprocessing.cpu_count() - ) + assert sum(w.nthreads for w in c.workers.values()) == CPU_COUNT assert all(isinstance(w, Worker) for w in c.workers.values()) assert len(c.workers) == 1 @@ -261,18 +256,18 @@ async def test_defaults_3(cleanup): dashboard_address=None, asynchronous=True, ) as c: - if multiprocessing.cpu_count() % 2 == 0: - expected_total_threads = max(2, multiprocessing.cpu_count()) + if CPU_COUNT % 2 == 0: + expected_total_threads = max(2, CPU_COUNT) else: # n_workers not a divisor of _nthreads => threads are overcommitted - expected_total_threads = max(2, multiprocessing.cpu_count() + 1) + expected_total_threads = max(2, CPU_COUNT + 1) assert sum(w.nthreads for w in c.workers.values()) == expected_total_threads @pytest.mark.asyncio async def test_defaults_4(cleanup): async with LocalCluster( - threads_per_worker=multiprocessing.cpu_count() * 2, + threads_per_worker=CPU_COUNT * 2, scheduler_port=0, silence_logs=False, dashboard_address=None, @@ -284,7 +279,7 @@ async def test_defaults_4(cleanup): @pytest.mark.asyncio async def test_defaults_5(cleanup): async with LocalCluster( - n_workers=multiprocessing.cpu_count() * 2, + n_workers=CPU_COUNT * 2, scheduler_port=0, silence_logs=False, dashboard_address=None, @@ -473,7 +468,7 @@ def test_memory(loop, n_workers): dashboard_address=None, loop=loop, ) as cluster: - assert sum(w.memory_limit for w in cluster.workers.values()) <= TOTAL_MEMORY + assert sum(w.memory_limit for w in cluster.workers.values()) <= MEMORY_LIMIT @pytest.mark.parametrize("n_workers", [None, 3]) @@ -489,7 +484,7 @@ def test_memory_nanny(loop, n_workers): with Client(cluster.scheduler_address, loop=loop) as c: info = c.scheduler_info() assert ( - sum(w["memory_limit"] for w in info["workers"].values()) <= TOTAL_MEMORY + sum(w["memory_limit"] for w in info["workers"].values()) <= MEMORY_LIMIT ) diff --git a/distributed/nanny.py b/distributed/nanny.py index b6d8dadbf9a..c017eb54af7 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -1,7 +1,6 @@ from datetime import timedelta import logging from multiprocessing.queues import Empty -import multiprocessing import os import psutil import shutil @@ -23,6 +22,7 @@ from .process import AsyncProcess from .proctitle import enable_proctitle_on_children from .security import Security +from .system import CPU_COUNT from .utils import ( get_ip, mp_context, @@ -110,7 +110,7 @@ def __init__( nthreads = ncores self._given_worker_port = worker_port - self.nthreads = nthreads or multiprocessing.cpu_count() + self.nthreads = nthreads or CPU_COUNT self.reconnect = reconnect self.validate = validate self.resources = resources diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index b334683b661..70d57a2e74f 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -14,10 +14,11 @@ msgpack, ) from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE -from distributed.utils import tmpfile, nbytes -from distributed.utils_test import gen_cluster from distributed.protocol.numpy import itemsize from distributed.protocol.compression import maybe_compress +from distributed.system import MEMORY_LIMIT +from distributed.utils import tmpfile, nbytes +from distributed.utils_test import gen_cluster def test_serialize(): @@ -151,9 +152,8 @@ def test_memmap(): @pytest.mark.slow def test_dumps_serialize_numpy_large(): - psutil = pytest.importorskip("psutil") - if psutil.virtual_memory().total < 2e9: - return + if 
MEMORY_LIMIT < 2e9: + pytest.skip("insufficient memory") x = np.random.random(size=int(BIG_BYTES_SHARD_SIZE * 2 // 8)).view("u1") assert x.nbytes == BIG_BYTES_SHARD_SIZE * 2 frames = dumps([to_serialize(x)]) diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 3dd11ecc4d1..bf16aecf2f4 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -6,6 +6,7 @@ from distributed.protocol import loads, dumps, msgpack, maybe_compress, to_serialize from distributed.protocol.compression import compressions from distributed.protocol.serialize import Serialize, Serialized, serialize, deserialize +from distributed.system import MEMORY_LIMIT from distributed.utils import nbytes @@ -102,13 +103,9 @@ def test_large_bytes(): @pytest.mark.slow def test_large_messages(): np = pytest.importorskip("numpy") - psutil = pytest.importorskip("psutil") pytest.importorskip("lz4") - if psutil.virtual_memory().total < 8e9: - return - - if sys.version_info.major == 2: - return 2 + if MEMORY_LIMIT < 8e9: + pytest.skip("insufficient memory") x = np.random.randint(0, 255, size=200000000, dtype="u1") @@ -126,9 +123,7 @@ def test_large_messages(): def test_large_messages_map(): - import psutil - - if psutil.virtual_memory().total < 8e9: + if MEMORY_LIMIT < 8e9: pytest.skip("insufficient memory") x = {i: "mystring_%d" % i for i in range(100000)} diff --git a/distributed/system.py b/distributed/system.py new file mode 100644 index 00000000000..e0735e1e34a --- /dev/null +++ b/distributed/system.py @@ -0,0 +1,79 @@ +import os +import sys + +import psutil + +__all__ = ("memory_limit", "cpu_count", "MEMORY_LIMIT", "CPU_COUNT") + + +def memory_limit(): + """Get the memory limit (in bytes) for this system. + + Takes the minimum value from the following locations: + + - Total system host memory + - Cgroups limit (if set) + - RSS rlimit (if set) + """ + limit = psutil.virtual_memory().total + + # Check cgroups if available + if sys.platform == "linux": + try: + with open("/sys/fs/cgroup/memory/memory.limit_in_bytes") as f: + cgroups_limit = int(f.read()) + if cgroups_limit > 0: + limit = min(limit, cgroups_limit) + except Exception: + pass + + # Check rlimit if available + try: + import resource + + hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1] + if hard_limit > 0: + limit = min(limit, hard_limit) + except (ImportError, OSError): + pass + + return limit + + +def cpu_count(): + """Get the available CPU count for this system. + + Takes the minimum value from the following locations: + + - Total system cpus available on the host. 
+ - CPU Affinity (if set) + - Cgroups limit (if set) + """ + count = os.cpu_count() + + # Check CPU affinity if available + try: + affinity_count = len(psutil.Process().cpu_affinity()) + if affinity_count > 0: + count = min(count, affinity_count) + except Exception: + pass + + # Check cgroups if available + if sys.platform == "linux": + try: + with open("/sys/fs/cgroup/cpuacct,cpu/cpu.cfs_quota_us") as f: + quota = int(f.read()) + with open("/sys/fs/cgroup/cpuacct,cpu/cpu.cfs_period_us") as f: + period = int(f.read()) + cgroups_count = int(quota / period) + if cgroups_count > 0: + count = min(count, cgroups_count) + except Exception: + pass + + return count + + +MEMORY_LIMIT = memory_limit() +CPU_COUNT = cpu_count() diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index d7c396bb63f..6d98e662034 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -13,6 +13,7 @@ from distributed.config import config from distributed.metrics import time from distributed.scheduler import key_split +from distributed.system import MEMORY_LIMIT from distributed.utils_test import ( slowinc, slowadd, @@ -22,7 +23,6 @@ captured_logger, ) from distributed.utils_test import nodebug_setup_module, nodebug_teardown_module -from distributed.worker import TOTAL_MEMORY import pytest @@ -170,7 +170,7 @@ def test_new_worker_steals(c, s, a): while len(a.task_state) < 10: yield gen.sleep(0.01) - b = yield Worker(s.address, loop=s.loop, nthreads=1, memory_limit=TOTAL_MEMORY) + b = yield Worker(s.address, loop=s.loop, nthreads=1, memory_limit=MEMORY_LIMIT) result = yield total assert result == sum(map(inc, range(100))) @@ -335,7 +335,7 @@ def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1)] * 10, - worker_kwargs={"memory_limit": TOTAL_MEMORY}, + worker_kwargs={"memory_limit": MEMORY_LIMIT}, ) def test_steal_when_more_tasks(c, s, a, *rest): s.extensions["stealing"]._pc.callback_time = 20 diff --git a/distributed/tests/test_system.py b/distributed/tests/test_system.py new file mode 100644 index 00000000000..d0f00f495e3 --- /dev/null +++ b/distributed/tests/test_system.py @@ -0,0 +1,32 @@ +import os + +import psutil +import pytest + +from distributed.system import cpu_count, memory_limit + + +def test_cpu_count(): + count = cpu_count() + assert isinstance(count, int) + assert count <= os.cpu_count() + assert count >= 1 + + +def test_memory_limit(): + limit = memory_limit() + assert isinstance(limit, int) + assert limit <= psutil.virtual_memory().total + assert limit >= 1 + + +def test_rlimit(): + resource = pytest.importorskip("resource") + + # decrease memory limit by one byte + new_limit = memory_limit() - 1 + try: + resource.setrlimit(resource.RLIMIT_RSS, (new_limit, new_limit)) + assert memory_limit() == new_limit + except OSError: + pytest.skip("resource could not set the RSS limit") diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 13dd92c00e0..aa0de9fd90f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2,7 +2,6 @@ from datetime import timedelta import importlib import logging -import multiprocessing from numbers import Number from operator import add import os @@ -30,6 +29,7 @@ get_worker, Reschedule, wait, + system, ) from distributed.compatibility import WINDOWS from distributed.core import rpc @@ -62,7 +62,7 @@ def test_worker_nthreads(): w = Worker("127.0.0.1", 8019) try: - assert w.executor._max_workers == 
multiprocessing.cpu_count() + assert w.executor._max_workers == system.CPU_COUNT finally: shutil.rmtree(w.local_directory) @@ -500,7 +500,7 @@ def test_memory_limit_auto(): assert isinstance(a.memory_limit, Number) assert isinstance(b.memory_limit, Number) - if multiprocessing.cpu_count() > 1: + if system.CPU_COUNT > 1: assert a.memory_limit < b.memory_limit assert c.memory_limit == d.memory_limit @@ -1436,26 +1436,14 @@ def test_host_address(c, s): yield n.close() -def test_resource_limit(): +def test_resource_limit(monkeypatch): assert parse_memory_limit("250MiB", 1, total_cores=1) == 1024 * 1024 * 250 - # get current limit - resource = pytest.importorskip("resource") - try: - hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1] - except OSError: - pytest.skip("resource could not get the RSS limit") - memory_limit = psutil.virtual_memory().total - if hard_limit > memory_limit or hard_limit < 0: - hard_limit = memory_limit - - # decrease memory limit by one byte - new_limit = hard_limit - 1 - try: - resource.setrlimit(resource.RLIMIT_RSS, (new_limit, new_limit)) - assert parse_memory_limit(hard_limit, 1, total_cores=1) == new_limit - except OSError: - pytest.skip("resource could not set the RSS limit") + new_limit = 1024 * 1024 * 200 + import distributed.worker + + monkeypatch.setattr(distributed.system, "MEMORY_LIMIT", new_limit) + assert parse_memory_limit("250MiB", 1, total_cores=1) == new_limit @pytest.mark.asyncio diff --git a/distributed/utils_test.py b/distributed/utils_test.py index e6b11ce2898..8725e3eb7e0 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -38,6 +38,7 @@ from tornado.gen import TimeoutError from tornado.ioloop import IOLoop +from . import system from .client import default_client, _global_clients, Client from .compatibility import WINDOWS from .comm import Comm @@ -61,7 +62,7 @@ thread_state, _offload_executor, ) -from .worker import Worker, TOTAL_MEMORY +from .worker import Worker from .nanny import Nanny try: @@ -636,7 +637,11 @@ def cluster( q = mp_context.Queue() fn = "_test_worker-%s" % uuid.uuid4() kwargs = merge( - {"nthreads": 1, "local_directory": fn, "memory_limit": TOTAL_MEMORY}, + { + "nthreads": 1, + "local_directory": fn, + "memory_limit": system.MEMORY_LIMIT, + }, worker_kwargs, ) proc = mp_context.Process( @@ -860,7 +865,7 @@ def test_foo(scheduler, worker1, worker2): nthreads = ncores worker_kwargs = merge( - {"memory_limit": TOTAL_MEMORY, "death_timeout": 5}, worker_kwargs + {"memory_limit": system.MEMORY_LIMIT, "death_timeout": 5}, worker_kwargs ) def _(func): diff --git a/distributed/worker.py b/distributed/worker.py index 06102a98405..bb00158ced8 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -5,7 +5,6 @@ from datetime import timedelta import heapq import logging -import multiprocessing import os from pickle import PicklingError import random @@ -14,7 +13,6 @@ import uuid import warnings import weakref -import psutil import dask from dask.core import istask @@ -28,7 +26,7 @@ from tornado import gen from tornado.ioloop import IOLoop -from . import profile, comm +from . 
import profile, comm, system from .batched import BatchedSend from .comm import get_address_host, connect from .comm.addressing import address_from_user_args @@ -72,8 +70,6 @@ no_value = "--no-value-sentinel--" -TOTAL_MEMORY = psutil.virtual_memory().total - IN_PLAY = ("waiting", "ready", "executing", "long-running") PENDING = ("waiting", "ready", "constrained") PROCESSING = ("waiting", "ready", "constrained", "executing", "long-running") @@ -242,7 +238,7 @@ class Worker(ServerNode): memory_limit: int, float, string Number of bytes of memory that this worker should use. Set to zero for no limit. Set to 'auto' to calculate - as TOTAL_MEMORY * min(1, nthreads / total_cores) + as system.MEMORY_LIMIT * min(1, nthreads / total_cores) Use strings or numbers like 5GB or 5e9 memory_target_fraction: float Fraction of memory to try to stay beneath @@ -458,7 +454,7 @@ def __init__( warnings.warn("the ncores= parameter has moved to nthreads=") nthreads = ncores - self.nthreads = nthreads or multiprocessing.cpu_count() + self.nthreads = nthreads or system.CPU_COUNT self.total_resources = resources or {} self.available_resources = (resources or {}).copy() self.death_timeout = parse_timedelta(death_timeout) @@ -3042,33 +3038,23 @@ class Reschedule(Exception): pass -def parse_memory_limit(memory_limit, nthreads, total_cores=multiprocessing.cpu_count()): +def parse_memory_limit(memory_limit, nthreads, total_cores=system.CPU_COUNT): if memory_limit is None: return None if memory_limit == "auto": - memory_limit = int(TOTAL_MEMORY * min(1, nthreads / total_cores)) + memory_limit = int(system.MEMORY_LIMIT * min(1, nthreads / total_cores)) with ignoring(ValueError, TypeError): memory_limit = float(memory_limit) if isinstance(memory_limit, float) and memory_limit <= 1: - memory_limit = int(memory_limit * TOTAL_MEMORY) + memory_limit = int(memory_limit * system.MEMORY_LIMIT) if isinstance(memory_limit, str): memory_limit = parse_bytes(memory_limit) else: memory_limit = int(memory_limit) - # should be less than hard RSS limit - try: - import resource - - hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1] - if hard_limit > 0: - memory_limit = min(memory_limit, hard_limit) - except (ImportError, OSError): - pass - - return memory_limit + return min(memory_limit, system.MEMORY_LIMIT) async def get_data_from_worker( From 019f7a63464a73dcb50a246ce8bb6f2f69750bc2 Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Thu, 12 Sep 2019 11:04:03 -0400 Subject: [PATCH 0458/1550] Use mock from unittest standard library. (#3049) Since distributed depends on Python 3.5+, there's no need to use the external mock package any more. 
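For reference, the switch is purely an import change -- `unittest.mock` exposes the same
`Mock`, `MagicMock`, and `patch` objects as the third-party package. A small sketch of the
pattern the updated tests use (the function and test below are invented for illustration and
are not part of the distributed test suite):

    from unittest import mock   # previously: import mock

    def fetch_status(client):
        return client.get("/status").status_code

    def test_fetch_status():
        fake_client = mock.Mock()
        fake_client.get.return_value = mock.Mock(status_code=200)
        assert fetch_status(fake_client) == 200
        fake_client.get.assert_called_once_with("/status")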
--- continuous_integration/setup_conda_environment.cmd | 1 - continuous_integration/travis/install.sh | 1 - dev-requirements.txt | 1 - distributed/tests/test_diskutils.py | 2 +- distributed/tests/test_ipython.py | 2 +- distributed/tests/test_submit_cli.py | 2 +- distributed/utils_test.py | 2 +- 7 files changed, 4 insertions(+), 7 deletions(-) diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index d09846faedd..87e37751548 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -30,7 +30,6 @@ call deactivate ipywidgets ^ joblib ^ jupyter_client ^ - mock ^ msgpack-python ^ prometheus_client ^ psutil ^ diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index ead4dbc3002..f0c4a07be67 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -39,7 +39,6 @@ conda install -q \ ipywidgets \ joblib \ jupyter_client \ - mock \ netcdf4 \ paramiko \ prometheus_client \ diff --git a/dev-requirements.txt b/dev-requirements.txt index 8cc8f7d256d..3c4cf7954a3 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,4 @@ joblib >= 0.10.2 -mock >= 2.0.0 pandas >= 0.19.2 numpy >= 1.11.0 bokeh >= 0.12.3 diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index c5cca9d5824..e12fb324341 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -6,8 +6,8 @@ import subprocess import sys from time import sleep +from unittest import mock -import mock import pytest import dask diff --git a/distributed/tests/test_ipython.py b/distributed/tests/test_ipython.py index 8f2a40e45eb..aa4a3e4092e 100644 --- a/distributed/tests/test_ipython.py +++ b/distributed/tests/test_ipython.py @@ -1,4 +1,4 @@ -import mock +from unittest import mock import pytest from toolz import first diff --git a/distributed/tests/test_submit_cli.py b/distributed/tests/test_submit_cli.py index 9273261dc94..edc16e0a61e 100644 --- a/distributed/tests/test_submit_cli.py +++ b/distributed/tests/test_submit_cli.py @@ -1,4 +1,4 @@ -from mock import Mock +from unittest.mock import Mock from tornado import gen from tornado.ioloop import IOLoop diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 8725e3eb7e0..cd53fe6f86e 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -191,7 +191,7 @@ def pristine_loop(): @contextmanager def mock_ipython(): - import mock + from unittest import mock from distributed._ipython_utils import remote_magic ip = mock.Mock() From bb127bba8e3a554ce7ca02675ec425d1336a30b2 Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Thu, 12 Sep 2019 11:04:32 -0400 Subject: [PATCH 0459/1550] Add missing test data to sdist tarball. (#3050) Fixes #2700. 
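A quick way to verify a MANIFEST.in fix like this is to build the sdist and check that the
data files actually land in the tarball. A small verification sketch follows; the tarball path
is hypothetical and depends on the version being packaged, while the two file names come from
the MANIFEST.in entries in the diff below.

    import tarfile

    # Hypothetical path produced by `python setup.py sdist`
    with tarfile.open("dist/distributed-2.4.0.tar.gz") as tar:
        names = tar.getnames()

    expected = [
        "distributed/tests/testegg-1.0.0-py3.4.egg",
        "distributed/tests/mytest.pyz",
    ]
    missing = [f for f in expected if not any(n.endswith(f) for n in names)]
    assert not missing, "sdist is missing test data: %s" % missing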
--- MANIFEST.in | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index a6c03274f24..b7a3764c87a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -12,7 +12,8 @@ include README.rst include LICENSE.txt include MANIFEST.in include requirements.txt -include distributed/tests/mytestegg-1.0.0-py3.4.egg +include distributed/tests/testegg-1.0.0-py3.4.egg +include distributed/tests/mytest.pyz include distributed/tests/*.pem prune docs/_build From 0ed877267febb5c2eb2b759b7f77d5de56a146c9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 12 Sep 2019 15:32:38 -0700 Subject: [PATCH 0460/1550] Remove six (#3045) --- distributed/cfexecutor.py | 5 ++-- distributed/client.py | 37 ++++++++++++++---------- distributed/comm/addressing.py | 4 +-- distributed/comm/core.py | 9 +++--- distributed/comm/registry.py | 6 ++-- distributed/core.py | 10 +++---- distributed/protocol/tests/test_numpy.py | 2 +- distributed/scheduler.py | 3 +- distributed/utils.py | 4 +-- distributed/utils_test.py | 6 ++-- requirements.txt | 1 - 11 files changed, 41 insertions(+), 46 deletions(-) diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index 34350462f8b..373a3c4eb28 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -1,8 +1,6 @@ import concurrent.futures as cf import weakref -import six - from toolz import merge from tornado import gen @@ -27,7 +25,8 @@ def _cascade_future(future, cf_future): cf_future.set_running_or_notify_cancel() else: try: - six.reraise(*result) + typ, exc, tb = result + raise exc.with_traceback(tb) except BaseException as exc: cf_future.set_exception(exc) diff --git a/distributed/client.py b/distributed/client.py index c11257ee74a..63d8213c33f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -17,7 +17,6 @@ import sys import uuid import threading -import six import socket from queue import Queue as pyQueue import warnings @@ -219,7 +218,8 @@ def result(self, timeout=None): # shorten error traceback result = self.client.sync(self._result, callback_timeout=timeout, raiseit=False) if self.status == "error": - six.reraise(*result) + typ, exc, tb = result + raise exc.with_traceback(tb) elif self.status == "cancelled": raise result else: @@ -230,7 +230,8 @@ async def _result(self, raiseit=True): if self.status == "error": exc = clean_exception(self._state.exception, self._state.traceback) if raiseit: - six.reraise(*exc) + typ, exc, tb = exc + raise exc.with_traceback(tb) else: return exc elif self.status == "cancelled": @@ -1145,7 +1146,8 @@ async def _handle_report(self): logger.debug("Client receives message %s", msg) if "status" in msg and "error" in msg["status"]: - six.reraise(*clean_exception(**msg)) + typ, exc, tb = clean_exception(**msg) + raise exc.with_traceback(tb) op = msg.pop("op") @@ -1431,7 +1433,7 @@ def submit( if allow_other_workers and workers is None: raise ValueError("Only use allow_other_workers= if using workers=") - if isinstance(workers, six.string_types + (Number,)): + if isinstance(workers, (str, Number)): workers = [workers] if workers is not None: restrictions = {skey: workers} @@ -1577,7 +1579,7 @@ def map( } ) - if isinstance(workers, six.string_types + (Number,)): + if isinstance(workers, (str, Number)): workers = [workers] if isinstance(workers, (list, set)): if workers and isinstance(first(workers), (list, set)): @@ -1671,7 +1673,7 @@ async def wait(k): except (KeyError, AttributeError): exc = CancelledError(key) else: - six.reraise(type(exception), exception, traceback) + raise 
exception.with_traceback(traceback) raise exc if errors == "skip": bad_keys.add(key) @@ -1830,7 +1832,7 @@ async def _scatter( ): if timeout == no_default: timeout = self._timeout - if isinstance(workers, six.string_types + (Number,)): + if isinstance(workers, (str, Number)): workers = [workers] if isinstance(data, dict) and not all( isinstance(k, (bytes, str)) for k in data @@ -2196,7 +2198,8 @@ async def _run_on_scheduler(self, function, *args, wait=True, **kwargs): function=dumps(function), args=dumps(args), kwargs=dumps(kwargs), wait=wait ) if response["status"] == "error": - six.reraise(*clean_exception(**response)) + typ, exc, tb = clean_exception(**response) + raise exc.with_traceback(tb) else: return response["result"] @@ -2251,7 +2254,8 @@ async def _run( if resp["status"] == "OK": results[key] = resp["result"] elif resp["status"] == "error": - six.reraise(*clean_exception(**resp)) + typ, exc, tb = clean_exception(**resp) + raise exc.with_traceback(tb) if wait: return results @@ -3200,7 +3204,7 @@ def profile( >>> client.profile() # call on collections >>> client.profile(filename='dask-profile.html') # save to html file """ - if isinstance(workers, six.string_types + (Number,)): + if isinstance(workers, (str, Number)): workers = [workers] return self.sync( @@ -3224,7 +3228,7 @@ async def _profile( plot=False, filename=None, ): - if isinstance(workers, six.string_types + (Number,)): + if isinstance(workers, (str, Number)): workers = [workers] state = await self.scheduler.profile( @@ -3561,12 +3565,12 @@ def start_ipython_workers( -------- Client.start_ipython_scheduler: start ipython on the scheduler """ - if isinstance(workers, six.string_types + (Number,)): + if isinstance(workers, (str, Number)): workers = [workers] (workers, info_dict) = sync(self.loop, self._start_ipython_workers, workers) - if magic_names and isinstance(magic_names, six.string_types): + if magic_names and isinstance(magic_names, str): if "*" in magic_names: magic_names = [ magic_names.replace("*", str(i)) for i in range(len(workers)) @@ -3872,7 +3876,7 @@ async def _register_worker_plugin(self, plugin=None, name=None): exc = response["exception"] typ = type(exc) tb = response["traceback"] - six.reraise(typ, exc, tb) + raise exc.with_traceback(tb) return responses def register_worker_plugin(self, plugin=None, name=None): @@ -4180,7 +4184,8 @@ def _get_and_raise(self): if self.with_results: future, result = res if self.raise_errors and future.status == "error": - six.reraise(*result) + typ, exc, tb = result + raise exc.with_traceback(tb) return res def __next__(self): diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 21a23e1ef6e..35d5e1c3407 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -1,5 +1,3 @@ -import six - import dask from . import registry @@ -18,7 +16,7 @@ def parse_address(addr, strict=False): If strict is set to true the address must have a scheme. 
""" - if not isinstance(addr, six.string_types): + if not isinstance(addr, str): raise TypeError("expected str, got %r" % addr.__class__.__name__) scheme, sep, loc = addr.rpartition("://") if strict and not sep: diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 602b3161657..1bbc043f52d 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -1,10 +1,9 @@ -from abc import ABCMeta, abstractmethod, abstractproperty +from abc import ABC, abstractmethod, abstractproperty from datetime import timedelta import logging import weakref import dask -from six import with_metaclass from tornado import gen from ..metrics import time @@ -24,7 +23,7 @@ class FatalCommClosedError(CommClosedError): pass -class Comm(with_metaclass(ABCMeta)): +class Comm(ABC): """ A message-oriented communication object, representing an established communication channel. There should be only one reader and one @@ -129,7 +128,7 @@ def __repr__(self): ) -class Listener(with_metaclass(ABCMeta)): +class Listener(ABC): @abstractmethod def start(self): """ @@ -165,7 +164,7 @@ def __exit__(self, *exc): self.stop() -class Connector(with_metaclass(ABCMeta)): +class Connector(ABC): @abstractmethod def connect(self, address, deserialize=True): """ diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index b7fcca912cd..369f2415c35 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -1,9 +1,7 @@ -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod -from six import with_metaclass - -class Backend(with_metaclass(ABCMeta)): +class Backend(ABC): """ A communication backend, selected by a given URI scheme (e.g. 'tcp'). """ diff --git a/distributed/core.py b/distributed/core.py index 5da103802d2..d0f3c13aa97 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -3,14 +3,12 @@ from concurrent.futures import CancelledError from functools import partial import logging -import six import threading import traceback import uuid import weakref import dask -from six import string_types from toolz import merge from tornado import gen from tornado.ioloop import IOLoop @@ -305,7 +303,7 @@ def listen(self, port_or_addr=None, listen_args=None): addr = unparse_host_port(*port_or_addr) else: addr = port_or_addr - assert isinstance(addr, string_types) + assert isinstance(addr, str) self.listener = listen( addr, self.handle_comm, @@ -545,7 +543,8 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw if isinstance(response, dict) and response.get("status") == "uncaught-error": if comm.deserialize: - six.reraise(*clean_exception(**response)) + typ, exc, tb = clean_exception(**response) + raise exc.with_traceback(tb) else: raise Exception(response["text"]) return response @@ -969,7 +968,6 @@ def error_message(e, status="error"): See Also -------- clean_exception: deserialize and unpack message into exception/traceback - six.reraise: raise exception/traceback """ tb = get_traceback() e2 = truncate_exception(e, 1000) @@ -1011,6 +1009,6 @@ def clean_exception(exception, traceback, **kwargs): traceback = protocol.pickle.loads(traceback) except (TypeError, AttributeError): traceback = None - elif isinstance(traceback, string_types): + elif isinstance(traceback, str): traceback = None # happens if the traceback failed serializing return type(exception), exception, traceback diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 70d57a2e74f..4fb20d58631 100644 --- 
a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -118,7 +118,7 @@ def test_serialize_numpy_ma_masked(): def test_dumps_serialize_numpy_custom_dtype(): - from six.moves import builtins + import builtins test_rational = pytest.importorskip("numpy.core.test_rational") rational = test_rational.rational diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ff9560767c0..6bf1ba0adb4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -12,7 +12,6 @@ import os import pickle import random -import six import warnings import weakref @@ -4518,7 +4517,7 @@ def coerce_address(self, addr, resolve=True): addr = self.aliases[addr] if isinstance(addr, tuple): addr = unparse_host_port(*addr) - if not isinstance(addr, six.string_types): + if not isinstance(addr, str): raise TypeError("addresses should be strings or tuples, got %r" % (addr,)) if resolve: diff --git a/distributed/utils.py b/distributed/utils.py index b7f6631ce93..015e4dbfb61 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -25,7 +25,6 @@ import warnings import weakref import pkgutil -import six import tblib.pickling_support import xml.etree.ElementTree @@ -334,7 +333,8 @@ def f(): while not e.is_set(): e.wait(10) if error[0]: - six.reraise(*error[0]) + typ, exc, tb = error[0] + raise exc.with_traceback(tb) else: return result[0] diff --git a/distributed/utils_test.py b/distributed/utils_test.py index cd53fe6f86e..97cbe783318 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -5,6 +5,7 @@ from datetime import timedelta import functools from glob import glob +import io import itertools import logging import logging.config @@ -30,7 +31,6 @@ ssl = None import pytest -import six import dask from toolz import merge, memoize, assoc @@ -1226,7 +1226,7 @@ def captured_logger(logger, level=logging.INFO, propagate=None): if propagate is not None: orig_propagate = logger.propagate logger.propagate = propagate - sio = six.StringIO() + sio = io.StringIO() logger.handlers[:] = [logging.StreamHandler(sio)] logger.setLevel(level) try: @@ -1244,7 +1244,7 @@ def captured_handler(handler): """ assert isinstance(handler, logging.StreamHandler) orig_stream = handler.stream - handler.stream = six.StringIO() + handler.stream = io.StringIO() try: yield handler.stream finally: diff --git a/requirements.txt b/requirements.txt index e376b2a50cc..804bdfe9637 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ cloudpickle >= 0.2.2 dask >= 2 msgpack psutil >= 5.0 -six sortedcontainers !=2.0.0, !=2.0.1 tblib toolz >= 0.7.4 From 7d017c467590c758fa4b8cb2b1193205fe5aa7ad Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 13 Sep 2019 15:34:02 -0500 Subject: [PATCH 0461/1550] bump version to 2.4.0 --- docs/source/changelog.rst | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 7385567467e..6bed266f32b 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,11 +1,44 @@ Changelog ========= +2.4.0 - 2019-09-13 +------------------ + +- Remove six (:pr:`3045`) `Matthew Rocklin`_ +- Add missing test data to sdist tarball (:pr:`3050`) `Elliott Sales de Andrade`_ +- Use mock from unittest standard library (:pr:`3049`) `Elliott Sales de Andrade`_ +- Use cgroups resource limits to determine default threads and memory (:pr:`3039`) `Jim Crist`_ +- Move task deserialization to immediately before task execution (:pr:`3015`) `James 
Bourbeau`_ +- Drop joblib shim module in distributed (:pr:`3040`) `John Kirkham`_ +- Redirect configuration doc page (:pr:`3038`) `Matthew Rocklin`_ +- Support ``--name 0`` and ``--nprocs`` keywords in dask-worker cli (:pr:`3037`) `Matthew Rocklin`_ +- Remove lost workers from ``SpecCluster.workers`` (:pr:`2990`) `Guillaume Eynard-Bontemps`_ +- Clean up ``test_local.py::test_defaults`` (:pr:`3017`) `Matthew Rocklin`_ +- Replace print statement in ``Queue.__init__`` with debug message (:pr:`3035`) `Mikhail Akimov`_ +- Set the ``x_range`` limit of the Memory utilization plot to memory-limit (:pr:`3034`) `Matthew Rocklin`_ +- Rely on cudf codebase for cudf serialization (:pr:`2998`) `Benjamin Zaitlen`_ +- Add fallback html repr for Cluster (:pr:`3023`) `Jim Crist`_ +- Add support for zstandard compression to comms (:pr:`2970`) `Abael He`_ +- Avoid collision when using ``os.environ`` in ``dashboard_link`` (:pr:`3021`) `Matthew Rocklin`_ +- Fix ``ConnectionPool`` limit handling (:pr:`3005`) `byjott`_ +- Support Spec jobs that generate multiple workers (:pr:`3013`) `Matthew Rocklin`_ +- Tweak ``Logs`` styling (:pr:`3012`) `Jim Crist`_ +- Better name for cudf deserialization function (:pr:`3008`) `Benjamin Zaitlen`_ +- Make ``spec.ProcessInterface`` a valid no-op worker (:pr:`3004`) `Matthew Rocklin`_ +- Return dictionaries from ``new_worker_spec`` rather than name/worker pairs (:pr:`3000`) `Matthew Rocklin`_ +- Fix minor typo in documentation (:pr:`3002`) `Mohammad Noor`_ +- Permit more keyword options when scaling with cores and memory (:pr:`2997`) `Matthew Rocklin`_ +- Add ``cuda_ipc`` to UCX environment for NVLink (:pr:`2996`) `Benjamin Zaitlen`_ +- Add ``threads=`` and ``memory=`` to Cluster and Client reprs (:pr:`2995`) `Matthew Rocklin`_ +- Fix PyNVML initialization (:pr:`2993`) `Richard J Zamora`_ + + 2.3.2 - 2019-08-23 ------------------ - Skip exceptions in startup information (:pr:`2991`) `Jacob Tomlinson`_ + 2.3.1 - 2019-08-22 ------------------ @@ -1230,3 +1263,8 @@ significantly without many new features. .. _`Shayan Amani`: https://github.com/SHi-ON .. _`Pav A`: https://github.com/rs2 .. _`Mads R. B. Kristensen`: https://github.com/madsbk +.. _`Mikhail Akimov`: https://github.com/roveo +.. _`Abael He`: https://github.com/abaelhe +.. _`byjott`: https://github.com/byjott +.. _`Mohammad Noor`: https://github.com/MdSalih +.. _`Richard J Zamora`: https://github.com/rjzamora From 86d7c0385967e530208eeb18e8544564cda05c2b Mon Sep 17 00:00:00 2001 From: Chris White Date: Sun, 15 Sep 2019 09:55:55 -0700 Subject: [PATCH 0462/1550] Add blurb about disabling work stealing (#3055) --- docs/source/work-stealing.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/work-stealing.rst b/docs/source/work-stealing.rst index cf5a4bc48c1..83afd795b7f 100644 --- a/docs/source/work-stealing.rst +++ b/docs/source/work-stealing.rst @@ -125,3 +125,12 @@ the task and sends a response to the scheduler: This avoids redundant work, and also the duplication of side effects for more exotic tasks. However, concurrent or repeated execution of the same task *is still possible* in the event of worker death or a disrupted network connection. 
+ + +Disabling Work Stealing +--------------------------- + +Work stealing is a toggleable setting on the Dask Scheduler; to disable +work stealing, you can toggle the scheduler ``work-stealing`` configuration +option to ``"False"`` either by setting ``DASK_DISTRIBUTED__SCHEDULER__WORK_STEALING="False"`` +or through your `Dask configuration file `_ From 8a41770e8ba219b4b114027d5ed61806ff8e8612 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Mon, 16 Sep 2019 11:15:53 -0500 Subject: [PATCH 0463/1550] Check multiple cgroups dirs, ceil fractional cpus (#3056) A few fixes for resource detection using cgroups: - The directory for determining cpu availability isn't standardized across linux distros, could be either `cpuacct,cpu`, or `cpu,cpuacct`. We now check for both. - When allotted fractional cpus (e.g. 1.5), we now round up. Also adds tests for both CPU and memory limit detection under cgroups, by monkeypatching in fake files. Fixes #3053 --- distributed/system.py | 25 ++++++++------ distributed/tests/test_system.py | 57 ++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 10 deletions(-) diff --git a/distributed/system.py b/distributed/system.py index e0735e1e34a..291248ddded 100644 --- a/distributed/system.py +++ b/distributed/system.py @@ -1,3 +1,4 @@ +import math import os import sys @@ -61,16 +62,20 @@ def cpu_count(): # Check cgroups if available if sys.platform == "linux": - try: - with open("/sys/fs/cgroup/cpuacct,cpu/cpu.cfs_quota_us") as f: - quota = int(f.read()) - with open("/sys/fs/cgroup/cpuacct,cpu/cpu.cfs_period_us") as f: - period = int(f.read()) - cgroups_count = int(quota / period) - if cgroups_count > 0: - count = min(count, cgroups_count) - except Exception: - pass + # The directory name isn't standardized across linux distros, check both + for dirname in ["cpuacct,cpu", "cpu,cpuacct"]: + try: + with open("/sys/fs/cgroup/%s/cpu.cfs_quota_us" % dirname) as f: + quota = int(f.read()) + with open("/sys/fs/cgroup/%s/cpu.cfs_period_us" % dirname) as f: + period = int(f.read()) + # We round up on fractional CPUs + cgroups_count = math.ceil(quota / period) + if cgroups_count > 0: + count = min(count, cgroups_count) + break + except Exception: + pass return count diff --git a/distributed/tests/test_system.py b/distributed/tests/test_system.py index d0f00f495e3..d276613b520 100644 --- a/distributed/tests/test_system.py +++ b/distributed/tests/test_system.py @@ -1,4 +1,7 @@ +import builtins +import io import os +import sys import psutil import pytest @@ -13,6 +16,44 @@ def test_cpu_count(): assert count >= 1 +@pytest.mark.parametrize("dirname", ["cpuacct,cpu", "cpu,cpuacct", None]) +def test_cpu_count_cgroups(dirname, monkeypatch): + def mycpu_count(): + # Absurdly high, unlikely to match real value + return 250 + + monkeypatch.setattr(os, "cpu_count", mycpu_count) + + class MyProcess(object): + def cpu_affinity(self): + # No affinity set + return [] + + monkeypatch.setattr(psutil, "Process", MyProcess) + + if dirname: + paths = { + "/sys/fs/cgroup/%s/cpu.cfs_quota_us" % dirname: io.StringIO("2005"), + "/sys/fs/cgroup/%s/cpu.cfs_period_us" % dirname: io.StringIO("10"), + } + builtin_open = builtins.open + + def myopen(path, *args, **kwargs): + if path in paths: + return paths.get(path) + return builtin_open(path, *args, **kwargs) + + monkeypatch.setattr(builtins, "open", myopen) + monkeypatch.setattr(sys, "platform", "linux") + + count = cpu_count() + if dirname: + # Rounds up + assert count == 201 + else: + assert count == 250 + + def test_memory_limit(): limit = 
memory_limit() assert isinstance(limit, int) @@ -20,6 +61,22 @@ def test_memory_limit(): assert limit >= 1 +def test_memory_limit_cgroups(monkeypatch): + builtin_open = builtins.open + + def myopen(path, *args, **kwargs): + if path == "/sys/fs/cgroup/memory/memory.limit_in_bytes": + # Absurdly low, unlikely to match real value + return io.StringIO("20") + return builtin_open(path, *args, **kwargs) + + monkeypatch.setattr(builtins, "open", myopen) + monkeypatch.setattr(sys, "platform", "linux") + + limit = memory_limit() + assert limit == 20 + + def test_rlimit(): resource = pytest.importorskip("resource") From 1f5cc12654d743c15947b2a997eaba8770922a2a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 16 Sep 2019 12:39:14 -0700 Subject: [PATCH 0464/1550] Allow full script in preload inputs (#3052) Previously we allowed either a script or a module in Scheduler or Worker preload arguments. worker = Worker(..., preload='myfile.py') However for simple scripts we sometimes don't want to create a full file, but just want to include the script as a multi-line text value. script = """def dask_setup(worker):\n worker.foo=123""" worker = Worker(..., preload=script) --- distributed/preloading.py | 78 ++++++++++++++++--------------- distributed/tests/test_preload.py | 17 ++++++- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/distributed/preloading.py b/distributed/preloading.py index a5e67c1611a..9b276b4337f 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -8,6 +8,8 @@ import click +from dask.utils import tmpfile + from .utils import import_file logger = logging.getLogger(__name__) @@ -29,7 +31,7 @@ def validate_preload_argv(ctx, param, value): % ("s" if len(value) > 1 else "", " ".join(value)) ) - preload_modules = _import_modules(ctx.params.get("preload")) + preload_modules = {name: _import_module(name) for name in ctx.params.get("preload")} preload_commands = [ m["dask_setup"] @@ -58,17 +60,16 @@ def validate_preload_argv(ctx, param, value): return value -def _import_modules(names, file_dir=None): - """ Imports modules and extracts preload interface functions. +def _import_module(name, file_dir=None): + """ Imports module and extract preload interface functions. - Imports modules specified by names and extracts 'dask_setup' + Import modules specified by name and extract 'dask_setup' and 'dask_teardown' if present. - Parameters ---------- - names: list of strings - Module names or file paths + name: str + Module name, file path, or text of module or script file_dir: string Path of a directory where files should be copied @@ -77,36 +78,37 @@ def _import_modules(names, file_dir=None): Nest dict of names to extracted module interface components if present in imported module. 
""" - result_modules = {} - - for name in names: - # import - if name.endswith(".py"): - # name is a file path - if file_dir is not None: - basename = os.path.basename(name) - copy_dst = os.path.join(file_dir, basename) - if os.path.exists(copy_dst): - if not filecmp.cmp(name, copy_dst): - logger.error("File name collision: %s", basename) - shutil.copy(name, copy_dst) - module = import_file(copy_dst)[0] - else: - module = import_file(name)[0] - + if name.endswith(".py"): + # name is a file path + if file_dir is not None: + basename = os.path.basename(name) + copy_dst = os.path.join(file_dir, basename) + if os.path.exists(copy_dst): + if not filecmp.cmp(name, copy_dst): + logger.error("File name collision: %s", basename) + shutil.copy(name, copy_dst) + module = import_file(copy_dst)[0] else: - # name is a module name - if name not in sys.modules: - import_module(name) - module = sys.modules[name] + module = import_file(name)[0] - logger.info("Import preload module: %s", name) - result_modules[name] = { - attrname: getattr(module, attrname, None) - for attrname in ("dask_setup", "dask_teardown") - } + elif " " not in name: + # name is a module name + if name not in sys.modules: + import_module(name) + module = sys.modules[name] - return result_modules + else: + # not a name, actually the text of the script + with tmpfile(extension=".py") as fn: + with open(fn, mode="w") as f: + f.write(name) + return _import_module(fn, file_dir=file_dir) + + logger.info("Import preload module: %s", name) + return { + attrname: getattr(module, attrname, None) + for attrname in ("dask_setup", "dask_teardown") + } def preload_modules(names, parameter=None, file_dir=None, argv=None): @@ -123,10 +125,12 @@ def preload_modules(names, parameter=None, file_dir=None, argv=None): file_dir: string Path of a directory where files should be copied """ + if isinstance(names, str): + names = [names] - imported_modules = _import_modules(names, file_dir=file_dir) + for name in names: + interface = _import_module(name, file_dir=file_dir) - for name, interface in imported_modules.items(): dask_setup = interface.get("dask_setup", None) dask_teardown = interface.get("dask_teardown", None) @@ -140,5 +144,5 @@ def preload_modules(names, parameter=None, file_dir=None, argv=None): dask_setup(parameter) logger.info("Run preload setup function: %s", name) - if interface["dask_teardown"]: + if dask_teardown: atexit.register(interface["dask_teardown"], parameter) diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index 07ee56d85a6..9ce804b752a 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -2,10 +2,11 @@ import shutil import sys import tempfile +import pytest -from distributed import Client +from distributed import Client, Scheduler, Worker from distributed.utils_test import cluster -from distributed.utils_test import loop # noqa F401 +from distributed.utils_test import cleanup, loop # noqa F401 PRELOAD_TEXT = """ @@ -42,6 +43,18 @@ def check_worker(): shutil.rmtree(tmpdir) +@pytest.mark.asyncio +async def test_worker_preload_text(cleanup): + text = """ +def dask_setup(worker): + worker.foo = 'setup' +""" + async with Scheduler(port=0, preload=text) as s: + assert s.foo == "setup" + async with Worker(s.address, preload=[text]) as w: + assert w.foo == "setup" + + def test_worker_preload_module(loop): def check_worker(): import worker_info From d16aabc012442cee0187e7af0c8ca30dff01e9a8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 18 Sep 2019 11:01:03 
-0500 Subject: [PATCH 0465/1550] Allow SpecCluster to scale by memory and cores (#3057) ```python >>> cluster.adapt(minimum_cores=10, maximum_memory="100 GiB") ``` --- distributed/deploy/spec.py | 115 ++++++++++++++---- distributed/deploy/tests/test_adaptive.py | 30 +++++ .../deploy/tests/test_slow_adaptive.py | 2 +- 3 files changed, 122 insertions(+), 25 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 487e5192e17..87336f96184 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -8,6 +8,7 @@ import dask from tornado import gen +from .adaptive import Adaptive from .cluster import Cluster from ..core import rpc, CommClosedError from ..utils import LoopRunner, silence_logging, ignoring, parse_bytes, parse_timedelta @@ -361,34 +362,39 @@ def __exit__(self, typ, value, traceback): self.close() self._loop_runner.stop() + def _threads_per_worker(self) -> int: + """ Return the number of threads per worker for new workers """ + if not self.new_spec: + raise ValueError("To scale by cores= you must specify cores per worker") + + for name in ["nthreads", "ncores", "threads", "cores"]: + with ignoring(KeyError): + return self.new_spec["options"][name] + + if not self.new_spec: + raise ValueError("To scale by cores= you must specify cores per worker") + + def _memory_per_worker(self) -> int: + """ Return the memory limit per worker for new workers """ + if not self.new_spec: + raise ValueError( + "to scale by memory= your worker definition must include a memory_limit definition" + ) + + for name in ["memory_limit", "memory"]: + with ignoring(KeyError): + return parse_bytes(self.new_spec["options"][name]) + + raise ValueError( + "to use scale(memory=...) your worker definition must include a memory_limit definition" + ) + def scale(self, n=0, memory=None, cores=None): if memory is not None: - for name in ["memory_limit", "memory"]: - try: - limit = self.new_spec["options"][name] - except KeyError: - pass - else: - n = max(n, int(math.ceil(parse_bytes(memory) / parse_bytes(limit)))) - break - else: - raise ValueError( - "to use scale(memory=...) your worker definition must include a memory_limit definition" - ) + n = max(n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker()))) if cores is not None: - for name in ["nthreads", "ncores", "threads", "cores"]: - try: - threads_per_worker = self.new_spec["options"][name] - except KeyError: - pass - else: - n = max(n, int(math.ceil(cores / threads_per_worker))) - break - else: - raise ValueError( - "to use scale(cores=...) your worker definition must include an nthreads= definition" - ) + n = max(n, int(math.ceil(cores / self._threads_per_worker()))) if len(self.worker_spec) > n: not_yet_launched = set(self.worker_spec) - { @@ -473,6 +479,67 @@ def requested(self): out.add(name) return out + def adapt( + self, + *args, + minimum=0, + maximum=math.inf, + minimum_cores: int = None, + maximum_cores: int = None, + minimum_memory: str = None, + maximum_memory: str = None, + **kwargs + ) -> Adaptive: + """ Turn on adaptivity + + This scales Dask clusters automatically based on scheduler activity. 
+ + Parameters + ---------- + minimum : int + Minimum number of workers + maximum : int + Maximum number of workers + minimum_cores : int + Minimum number of cores/threads to keep around in the cluster + maximum_cores : int + Maximum number of cores/threads to keep around in the cluster + minimum_memory : str + Minimum amount of memory to keep around in the cluster + Expressed as a string like "100 GiB" + maximum_cores : int + Maximum amount of memory to keep around in the cluster + Expressed as a string like "100 GiB" + + Examples + -------- + >>> cluster.adapt(minimum=0, maximum_memory="100 GiB", interval='500ms') + + See Also + -------- + dask.distributed.Adaptive : for more keyword arguments + """ + if minimum_cores is not None: + minimum = max( + minimum or 0, math.ceil(minimum_cores / self._threads_per_worker()) + ) + if minimum_memory is not None: + minimum = max( + minimum or 0, + math.ceil(parse_bytes(minimum_memory) / self._memory_per_worker()), + ) + if maximum_cores is not None: + maximum = min( + maximum, math.floor(maximum_cores / self._threads_per_worker()) + ) + if maximum_memory is not None: + maximum = min( + maximum, + math.floor(parse_bytes(maximum_memory) / self._memory_per_worker()), + ) + + return super().adapt(*args, minimum=minimum, maximum=maximum, **kwargs) + @atexit.register def close_clusters(): diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 261b4355251..af198747822 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -397,3 +397,33 @@ def key(ws): names = {ws.name for ws in cluster.scheduler.workers.values()} assert names == {"a-1", "a-2"} or names == {"b-1", "b-2"} + + +@pytest.mark.asyncio +async def test_adapt_cores_memory(cleanup): + async with LocalCluster( + 0, + threads_per_worker=2, + memory_limit="3 GB", + scheduler_port=0, + silence_logs=False, + processes=False, + dashboard_address=None, + asynchronous=True, + ) as cluster: + adapt = cluster.adapt(minimum_cores=3, maximum_cores=9) + assert adapt.minimum == 2 + assert adapt.maximum == 4 + + adapt = cluster.adapt(minimum_memory="7GB", maximum_memory="20 GB") + assert adapt.minimum == 3 + assert adapt.maximum == 6 + + adapt = cluster.adapt( + minimum_cores=1, + minimum_memory="7GB", + maximum_cores=10, + maximum_memory="1 TB", + ) + assert adapt.minimum == 3 + assert adapt.maximum == 5 diff --git a/distributed/deploy/tests/test_slow_adaptive.py b/distributed/deploy/tests/test_slow_adaptive.py index 4f565a78289..09113fe3b23 100644 --- a/distributed/deploy/tests/test_slow_adaptive.py +++ b/distributed/deploy/tests/test_slow_adaptive.py @@ -47,7 +47,7 @@ async def test_startup(cleanup): ) as cluster: assert len(cluster.workers) == len(cluster.worker_spec) == 3 assert time() < start + 5 - assert 1 <= len(cluster.scheduler_info["workers"]) <= 2 + assert 0 <= len(cluster.scheduler_info["workers"]) <= 2 async with Client(cluster, asynchronous=True) as client: await client.wait_for_workers(n_workers=2) From 4bfe42d8b00b7553609a1797bdb88b93c7efd062 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 18 Sep 2019 11:01:20 -0500 Subject: [PATCH 0466/1550] Use Cluster.scheduler_info for workers= value in repr (#3058) Previously we would use `Cluster.workers`, which might differ from the number of Dask workers if a single job generated several Dask workers. Now, we use the reported number of workers from the scheduler. 
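As a rough illustration of why those two counts can disagree, here is a minimal sketch with made-up spec and scheduler entries (hypothetical data, not the real `Cluster` internals): a single job entry can expand into several Dask workers.

```python
# Hypothetical example data -- not the real Cluster internals.
# One entry in the worker spec (one "job") expands into two Dask workers.
worker_spec = {
    "job-0": {"cls": "MultiWorker", "options": {"n": 2}},
    "job-1": {"cls": "MultiWorker", "options": {"n": 2}},
}
scheduler_info = {
    "workers": {"tcp://10.0.0.%d:9000" % i: {"nthreads": 1} for i in range(4)}
}

len(worker_spec)                # 2 jobs (what the repr used to count)
len(scheduler_info["workers"])  # 4 Dask workers (what the repr reports now)
```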
--- distributed/deploy/cluster.py | 2 +- distributed/deploy/tests/test_spec_cluster.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 32ceedc47bd..ad47881ba22 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -365,7 +365,7 @@ def __repr__(self): text = "%s(%r, workers=%d, threads=%d" % ( self._cluster_class_name, self.scheduler_address, - len(self.workers), + len(self.scheduler_info["workers"]), sum(w["nthreads"] for w in self.scheduler_info["workers"].values()), ) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 679adc0fd60..d8c155c7c69 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -374,6 +374,8 @@ async def test_MultiWorker(cleanup): assert len(cluster.worker_spec) == 2 await client.wait_for_workers(4) + assert "workers=4" in repr(cluster) + cluster.scale(1) await cluster assert len(s.workers) == 2 From 71c7e4a6305942ae038db05cd55a29e184f998ce Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Wed, 18 Sep 2019 20:37:10 +0100 Subject: [PATCH 0467/1550] Allow specification of worker type in SSHCLuster (#3061) --- distributed/deploy/ssh2.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py index 064e580b111..cb6b967d544 100644 --- a/distributed/deploy/ssh2.py +++ b/distributed/deploy/ssh2.py @@ -57,6 +57,8 @@ class Worker(Process): The address of the scheduler address: str The hostname where we should run this worker + worker_module: str + The python module to run to start the worker. connect_kwargs: dict kwargs to be passed to asyncssh connections kwargs: dict @@ -70,11 +72,13 @@ def __init__( address: str, connect_kwargs: dict, kwargs: dict, + worker_module="distributed.cli.dask_worker", loop=None, name=None, ): self.address = address self.scheduler = scheduler + self.worker_module = worker_module self.connect_kwargs = connect_kwargs self.kwargs = kwargs self.name = name @@ -88,7 +92,7 @@ async def start(self): [ sys.executable, "-m", - "distributed.cli.dask_worker", + self.worker_module, self.scheduler, "--name", str(self.name), @@ -158,7 +162,12 @@ async def start(self): def SSHCluster( - hosts, connect_kwargs={}, worker_kwargs={}, scheduler_kwargs={}, **kwargs + hosts, + connect_kwargs={}, + worker_kwargs={}, + scheduler_kwargs={}, + worker_module="distributed.cli.dask_worker", + **kwargs ): """ Deploy a Dask cluster using SSH @@ -174,10 +183,12 @@ def SSHCluster( key presented during the SSH handshake. If this is not specified, the keys will be looked up in the file .ssh/known_hosts. If this is explicitly set to None, server host key validation will be disabled. - scheduler_kwargs: - Keywords to pass on to dask-scheduler worker_kwargs: Keywords to pass on to dask-worker + scheduler_kwargs: + Keywords to pass on to dask-scheduler + worker_module: + Python module to call to start the worker Examples -------- @@ -189,6 +200,18 @@ def SSHCluster( ... worker_kwargs={"nthreads": 2}, ... scheduler_kwargs={"port": 0, "dashboard_address": ":8797"}) >>> client = Client(cluster) + + Running GPU workers (requires ``dask_cuda`` to be installed on all hosts) + + >>> from dask.distributed import Client + >>> from distributed.deploy.ssh2 import SSHCluster # experimental for now + >>> cluster = SSHCluster( + ... 
["localhost", "hostwithgpus", "anothergpuhost"], + ... connect_kwargs={"known_hosts": None}, + ... scheduler_kwargs={"port": 0, "dashboard_address": ":8797"}, + ... worker_module='dask_cuda.dask_cuda_worker') + >>> client = Client(cluster) + """ scheduler = { "cls": Scheduler, @@ -205,6 +228,7 @@ def SSHCluster( "address": host, "connect_kwargs": connect_kwargs, "kwargs": worker_kwargs, + "worker_module": worker_module, }, } for i, host in enumerate(hosts[1:]) From face9e8273ef755aa658d89d8f3dd80b1c02cd7c Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 20 Sep 2019 16:16:41 +0200 Subject: [PATCH 0468/1550] Protocol of cupy and numba handles serialization exclusively (#3047) --- distributed/protocol/cupy.py | 32 +++--------- distributed/protocol/numba.py | 62 ++++++------------------ distributed/protocol/tests/test_cupy.py | 6 ++- distributed/protocol/tests/test_numba.py | 19 ++++++++ 4 files changed, 44 insertions(+), 75 deletions(-) create mode 100644 distributed/protocol/tests/test_numba.py diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index f8d08ee3a1e..d85f37d8a1e 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -7,36 +7,18 @@ @cuda_serialize.register(cupy.ndarray) def serialize_cupy_ndarray(x): - # TODO: handle non-contiguous - # TODO: Handle order='K' ravel - # TODO: 0d + # Making sure `x` is behaving + if not x.flags.c_contiguous: + x = cupy.array(x, copy=True) - if x.flags.c_contiguous or x.flags.f_contiguous: - strides = x.strides - data = x.ravel() # order='K' - else: - x = cupy.ascontiguousarray(x) - strides = x.strides - data = x.ravel() - - dtype = (0, x.dtype.str) - - # used in the ucx comms for gpu/cpu message passing - # 'lengths' set by dask header = x.__cuda_array_interface__.copy() - header["is_cuda"] = 1 - header["dtype"] = dtype - return header, [data] + return header, [x] @cuda_deserialize.register(cupy.ndarray) def deserialize_cupy_array(header, frames): (frame,) = frames - # TODO: put this in ucx... 
as a kind of "fixup" - try: - frame.typestr = header["typestr"] - frame.shape = header["shape"] - except AttributeError: - pass - arr = cupy.asarray(frame) + arr = cupy.ndarray( + header["shape"], dtype=header["typestr"], memptr=cupy.asarray(frame).data + ) return arr diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py index aa56a682b95..ddf43adc182 100644 --- a/distributed/protocol/numba.py +++ b/distributed/protocol/numba.py @@ -1,61 +1,27 @@ +import numpy as np import numba.cuda from .cuda import cuda_serialize, cuda_deserialize @cuda_serialize.register(numba.cuda.devicearray.DeviceNDArray) def serialize_numba_ndarray(x): - # TODO: handle non-contiguous - # TODO: handle 2d - # TODO: 0d - - if x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]: - strides = x.strides - if x.ndim > 1: - data = x.ravel() # order='K' - else: - data = x - else: - raise ValueError("Array must be contiguous") - x = numba.ascontiguousarray(x) - strides = x.strides - if x.ndim > 1: - data = x.ravel() - else: - data = x - - dtype = (0, x.dtype.str) - nbytes = data.dtype.itemsize * data.size - - # used in the ucx comms for gpu/cpu message passing - # 'lengths' set by dask + # Making sure `x` is behaving + if not x.is_c_contiguous(): + shape = x.shape + t = numba.cuda.device_array(shape, dtype=x.dtype) + t.copy_to_device(x) + x = t header = x.__cuda_array_interface__.copy() - header["is_cuda"] = 1 - header["dtype"] = dtype - return header, [data] + return header, [x] @cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray) def deserialize_numba_ndarray(header, frames): (frame,) = frames - # TODO: put this in ucx... as a kind of "fixup" - if isinstance(frame, bytes): - import numpy as np - - arr2 = np.frombuffer(frame, header["typestr"]) - return numba.cuda.to_device(arr2) - - frame.typestr = header["typestr"] - frame.shape = header["shape"] - - # numba & cupy don't properly roundtrip length-zero arrays. - if frame.shape[0] == 0: - arr = numba.cuda.device_array( - header["shape"], - header["typestr"] - # strides? - # order? 
- ) - return arr - - arr = numba.cuda.as_cuda_array(frame) + arr = numba.cuda.devicearray.DeviceNDArray( + header["shape"], + header["strides"], + np.dtype(header["typestr"]), + gpu_data=numba.cuda.as_cuda_array(frame).gpu_data, + ) return arr diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index 26940597f81..10335d14338 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -4,8 +4,10 @@ cupy = pytest.importorskip("cupy") -def test_serialize_cupy(): - x = cupy.arange(100) +@pytest.mark.parametrize("size", [0, 10]) +@pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) +def test_serialize_cupy(size, dtype): + x = cupy.arange(size, dtype=dtype) header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) diff --git a/distributed/protocol/tests/test_numba.py b/distributed/protocol/tests/test_numba.py new file mode 100644 index 00000000000..794db58b3c9 --- /dev/null +++ b/distributed/protocol/tests/test_numba.py @@ -0,0 +1,19 @@ +from distributed.protocol import serialize, deserialize +import pytest + +cuda = pytest.importorskip("numba.cuda") +np = pytest.importorskip("numpy") + + +@pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) +def test_serialize_cupy(dtype): + ary = np.arange(100, dtype=dtype) + x = cuda.to_device(ary) + header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) + y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + + hx = np.empty_like(ary) + hy = np.empty_like(ary) + x.copy_to_host(hx) + y.copy_to_host(hy) + assert (hx == hy).all() From 2759f34b3de7e17ee97d07ab09e73e2d741ca6f8 Mon Sep 17 00:00:00 2001 From: Arpit Solanki Date: Sat, 21 Sep 2019 00:50:17 +0530 Subject: [PATCH 0469/1550] Add monitoring with dask cluster docs (#3072) --- docs/source/index.rst | 1 + docs/source/prometheus.rst | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 docs/source/prometheus.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 732c234a53b..ee32738f826 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -112,6 +112,7 @@ Contents local-cluster ipython Joblib Integration + prometheus publish queues resources diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst new file mode 100644 index 00000000000..5000858a045 --- /dev/null +++ b/docs/source/prometheus.rst @@ -0,0 +1,36 @@ +Prometheus Monitoring +----------------------- + +Prometheus_ is a widely popular tool for monitoring and alerting a wide variety of systems. Dask.distributed exposes +scheduler and worker metrics in a prometheus text based format. Metrics are available at ``http://scheduler-address:8787/metrics``. + +.. _Prometheus: https://prometheus.io + +Available metrics are as following + ++---------------------------------------------+----------------------------------------------+ +| Metric name | Description | ++=========================+===================+==============================================+ +| dask_scheduler_workers | Number of workers connected. | ++---------------------------------------------+----------------------------------------------+ +| dask_scheduler_clients | Number of clients connected. 
| ++---------------------------------------------+----------------------------------------------+ +| dask_scheduler_received_tasks | Number of tasks received at scheduler | ++---------------------------------------------+----------------------------------------------+ +| dask_scheduler_unrunnable_tasks | Number of unrunnable tasks at scheduler | ++---------------------------------------------+----------------------------------------------+ +| dask_worker_tasks | Number of tasks at worker. | ++---------------------------------------------+----------------------------------------------+ +| dask_worker_connections | Number of task connections to other workers. | ++---------------------------------------------+----------------------------------------------+ +| dask_worker_threads | Number of worker threads. | ++---------------------------------------------+----------------------------------------------+ +| dask_worker_latency_seconds | Latency of worker connection. | ++---------------------------------------------+----------------------------------------------+ +| dask_worker_tick_duration_median_seconds | Median tick duration at worker. | ++---------------------------------------------+----------------------------------------------+ +| dask_worker_task_duration_median_seconds | Median task runtime at worker. | ++---------------------------------------------+----------------------------------------------+ +| dask_worker_transfer_bandwidth_median_bytes | Bandwidth for transfer at worker in Bytes. | ++---------------------------------------------+----------------------------------------------+ + From 386ee6c181272d398881e0a749176671e5c3ed61 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 23 Sep 2019 08:21:23 -0500 Subject: [PATCH 0470/1550] Respect Cluster.dashboard_link in Client._repr_html_ if it exists (#3077) --- distributed/client.py | 20 +++++++++++++------- distributed/tests/test_client.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 63d8213c33f..08fda7539cf 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -811,13 +811,19 @@ def _repr_html_(self): text += "
        • Scheduler: not connected
        • \n" if info and "dashboard" in info["services"]: - protocol, rest = scheduler.address.split("://") - port = info["services"]["dashboard"] - if protocol == "inproc": - host = "localhost" - else: - host = rest.split(":")[0] - address = format_dashboard_link(host, port) + try: + address = self.cluster.dashboard_link + except AttributeError: + protocol, rest = scheduler.address.split("://") + + port = info["services"]["dashboard"] + if protocol == "inproc": + host = "localhost" + else: + host = rest.split(":")[0] + + address = format_dashboard_link(host, port) + text += ( "
        • Dashboard: %(web)s\n" % {"web": address} diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index a4e472c882a..dfb0eb59386 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -91,6 +91,7 @@ from distributed.utils_test import ( # noqa: F401 client as c, client_secondary as c2, + cleanup, cluster_fixture, loop, loop_in_thread, @@ -5618,5 +5619,17 @@ async def test_file_descriptors_dont_leak(Worker): assert time() < begin + 5, (start, proc.num_fds()) +@pytest.mark.asyncio +async def test_dashboard_link_cluster(cleanup): + class MyCluster(LocalCluster): + @property + def dashboard_link(self): + return "http://foo.com" + + async with MyCluster(processes=False, asynchronous=True) as cluster: + async with Client(cluster, asynchronous=True) as client: + assert "http://foo.com" in client._repr_html_() + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 From 549660e07c0c70fdb17e07c6a18ca438933bd8ba Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 23 Sep 2019 08:21:39 -0500 Subject: [PATCH 0471/1550] Have Client get Security from passed Cluster (#3079) --- distributed/client.py | 17 ++++++++++------- distributed/deploy/tests/test_local.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 08fda7539cf..5ca99b7f156 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -637,14 +637,7 @@ def __init__( self._gather_future = None # Communication - self.security = security or Security() self.scheduler_comm = None - assert isinstance(self.security, Security) - - if name == "worker": - self.connection_args = self.security.get_connection_args("worker") - else: - self.connection_args = self.security.get_connection_args("client") if address is None: address = dask.config.get("scheduler-address", None) @@ -658,6 +651,16 @@ def __init__( self.cluster = address with ignoring(AttributeError): loop = address.loop + if security is None: + security = self.cluster.security + + self.security = security or Security() + assert isinstance(self.security, Security) + + if name == "worker": + self.connection_args = self.security.get_connection_args("worker") + else: + self.connection_args = self.security.get_connection_args("client") self._connecting_to_scheduler = False self._asynchronous = asynchronous diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 6611bfccc38..520c99eb268 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -26,6 +26,7 @@ assert_can_connect_from_everywhere_4, assert_can_connect_from_everywhere_4_6, captured_logger, + tls_only_security, ) from distributed.utils_test import loop # noqa: F401 from distributed.utils import sync @@ -952,3 +953,18 @@ async def test_repr(cleanup): n_workers=2, processes=False, memory_limit=None, asynchronous=True ) as cluster: assert "memory" not in repr(cluster) + + +@pytest.mark.asyncio +async def test_capture_security(cleanup): + security = tls_only_security() + async with LocalCluster( + n_workers=0, + silence_logs=False, + security=security, + asynchronous=True, + dashboard_address=False, + host="tls://0.0.0.0", + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + assert client.security == cluster.security From 4953f81d7b6856f6a22f00e27476d4a554a120c2 Mon Sep 17 00:00:00 2001 From: Daniel Farrell Date: Wed, 25 Sep 2019 09:23:52 -0700 
Subject: [PATCH 0472/1550] Add configuation option for longer error tracebacks (#3086) --- distributed/core.py | 5 +++-- distributed/distributed.yaml | 1 + distributed/tests/test_worker.py | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index d0f3c13aa97..32b509dc170 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -969,8 +969,9 @@ def error_message(e, status="error"): -------- clean_exception: deserialize and unpack message into exception/traceback """ + MAX_ERROR_LEN = dask.config.get("distributed.admin.max-error-length") tb = get_traceback() - e2 = truncate_exception(e, 1000) + e2 = truncate_exception(e, MAX_ERROR_LEN) try: e3 = protocol.pickle.dumps(e2) protocol.pickle.loads(e3) @@ -982,7 +983,7 @@ def error_message(e, status="error"): except Exception: tb = tb2 = "".join(traceback.format_tb(tb)) - if len(tb2) > 10000: + if len(tb2) > MAX_ERROR_LEN: tb_result = None else: tb_result = protocol.to_serialize(tb) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index f277eb2f90b..7d012a2f68b 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -117,6 +117,7 @@ distributed: interval: 20ms # time between event loop health checks limit: 3s # time allowed before triggering a warning + max-error-length: 10000 # Maximum size traceback after error to return log-length: 10000 # default length of logs to keep in memory log-format: '%(name)s - %(levelname)s - %(message)s' pdb-on-err: False # enter debug mode on scheduling error diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index aa0de9fd90f..eb4c0f86c7b 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -356,6 +356,22 @@ def __str__(self): msg = error_message(MyException("Hello", "World!")) assert "Hello" in str(msg["exception"]) + max_error_len = 100 + with dask.config.set({"distributed.admin.max-error-length": max_error_len}): + msg = error_message(RuntimeError("-" * max_error_len)) + assert len(msg["text"]) <= max_error_len + assert len(msg["text"]) < max_error_len * 2 + msg = error_message(RuntimeError("-" * max_error_len * 20)) + cut_text = msg["text"].replace("('Long error message', '", "")[:-2] + assert len(cut_text) == max_error_len + + max_error_len = 1000000 + with dask.config.set({"distributed.admin.max-error-length": max_error_len}): + msg = error_message(RuntimeError("-" * max_error_len * 2)) + cut_text = msg["text"].replace("('Long error message', '", "")[:-2] + assert len(cut_text) == max_error_len + assert len(msg["text"]) > 10100 # default + 100 + @gen_cluster() def test_gather(s, a, b): From dec8abead0c460fece9b9d15f216bee5ca2e0d11 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 26 Sep 2019 14:42:24 +0100 Subject: [PATCH 0473/1550] Make Client.get_versions async friendly (#3064) --- distributed/client.py | 13 +++++++------ distributed/tests/test_client.py | 5 +++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 5ca99b7f156..381516da5c6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3463,18 +3463,19 @@ def get_versions(self, check=False, packages=[]): >>> c.get_versions(packages=['sklearn', 'geopandas']) # doctest: +SKIP """ + return self.sync(self._get_versions, check=check, packages=packages) + + async def _get_versions(self, check=False, packages=[]): client = get_versions(packages=packages) try: - scheduler = 
sync(self.loop, self.scheduler.versions, packages=packages) + scheduler = await self.scheduler.versions(packages=packages) except KeyError: scheduler = None except TypeError: # packages keyword not supported - scheduler = sync(self.loop, self.scheduler.versions) # this raises + scheduler = await self.scheduler.versions() # this raises - workers = sync( - self.loop, - self.scheduler.broadcast, - msg={"op": "versions", "packages": packages}, + workers = await self.scheduler.broadcast( + msg={"op": "versions", "packages": packages} ) result = {"scheduler": scheduler, "workers": workers, "client": client} diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index dfb0eb59386..a95548689d1 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3725,6 +3725,11 @@ def test_get_versions(c): assert dict(v["client"]["packages"]["optional"])["requests"] == requests.__version__ +@gen_cluster(client=True) +async def test_async_get_versions(c, s, a, b): + await c.get_versions(check=True) + + def test_threaded_get_within_distributed(c): import dask.multiprocessing From f7f6bd77ab8d28d8811b94aa84528d3a470f1ff1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 26 Sep 2019 16:26:04 +0200 Subject: [PATCH 0474/1550] Fix widget with spec that generates multiple workers (#3067) --- distributed/deploy/cluster.py | 5 ++++- distributed/deploy/tests/test_spec_cluster.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index ad47881ba22..033f6877684 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -207,7 +207,10 @@ def dashboard_link(self): def _widget_status(self): workers = len(self.scheduler_info["workers"]) if hasattr(self, "worker_spec"): - requested = len(self.worker_spec) + requested = sum( + 1 if "group" not in each else len(each["group"]) + for each in self.worker_spec.values() + ) elif hasattr(self, "workers"): requested = len(self.workers) else: diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index d8c155c7c69..485ae1989ea 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,4 +1,5 @@ import asyncio +import re import dask from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny @@ -375,6 +376,8 @@ async def test_MultiWorker(cleanup): await client.wait_for_workers(4) assert "workers=4" in repr(cluster) + workers_line = re.search("(Workers.+)", cluster._widget_status()).group(1) + assert re.match("Workers.*4", workers_line) cluster.scale(1) await cluster From 031e3a29edd9b4979b910c4773586881b2b26401 Mon Sep 17 00:00:00 2001 From: byjott Date: Thu, 26 Sep 2019 20:18:42 +0200 Subject: [PATCH 0475/1550] Fix worker preload config (#3027) * Fix preload option handling Fixes #3026 * Get preload from config in nanny --- distributed/nanny.py | 8 ++++++-- distributed/tests/test_preload.py | 20 ++++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index c017eb54af7..06b5a27dc79 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -75,8 +75,8 @@ def __init__( resources=None, silence_logs=None, death_timeout=None, - preload=(), - preload_argv=[], + preload=None, + preload_argv=None, security=None, contact_address=None, listen_address=None, @@ -116,7 +116,11 @@ def __init__( self.resources = 
resources self.death_timeout = parse_timedelta(death_timeout) self.preload = preload + if self.preload is None: + self.preload = dask.config.get("distributed.worker.preload") self.preload_argv = preload_argv + if self.preload_argv is None: + self.preload_argv = dask.config.get("distributed.worker.preload-argv") self.Worker = Worker if worker_class is None else worker_class self.env = env or {} worker_kwargs.update( diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index 9ce804b752a..d3171ed6842 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -1,12 +1,14 @@ import os +import pytest import shutil import sys import tempfile import pytest -from distributed import Client, Scheduler, Worker +import dask +from distributed import Client, Scheduler, Worker, Nanny from distributed.utils_test import cluster -from distributed.utils_test import cleanup, loop # noqa F401 +from distributed.utils_test import loop, cleanup # noqa F401 PRELOAD_TEXT = """ @@ -55,6 +57,20 @@ def dask_setup(worker): assert w.foo == "setup" +@pytest.mark.asyncio +async def test_worker_preload_config(cleanup): + text = """ +def dask_setup(worker): + worker.foo = 'setup' +""" + with dask.config.set({"distributed.worker.preload": text}): + async with Scheduler(port=0) as s: + async with Nanny(s.address) as w: + async with Client(s.address, asynchronous=True) as c: + d = await c.run(lambda dask_worker: dask_worker.foo) + assert d == {w.worker_address: "setup"} + + def test_worker_preload_module(loop): def check_worker(): import worker_info From 316aedbae2ee9d8b14b6dbfb220e191bd15aa4c2 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 27 Sep 2019 15:41:46 +0200 Subject: [PATCH 0476/1550] Use the new UCX Python bindings (#3059) See also https://github.com/rapidsai/ucx-py/pull/180 --- distributed/comm/tests/test_ucx.py | 52 +--------- distributed/comm/ucx.py | 157 ++++++++++++++--------------- 2 files changed, 82 insertions(+), 127 deletions(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 1355daf95b8..28348369899 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -32,11 +32,6 @@ async def get_comm_pair( async def handle_comm(comm): await q.put(comm) - # Workaround for hanging test in - # pytest distributed/comm/tests/test_ucx.py::test_comm_objs -vs --count=2 - # on the second time through. 
- ucp._libs.ucp_py.fin() - listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) with listener: comm = await connect( @@ -93,6 +88,7 @@ async def handle_comm(comm): msg = await comm.read() msg["op"] = "pong" await comm.write(msg) + await comm.read() assert comm.closed() is False await comm.close() assert comm.closed @@ -118,11 +114,9 @@ async def client_communicate(key, delay=0): await asyncio.sleep(delay) msg = await comm.read() assert msg == {"op": "pong", "data": key} + await comm.write({"op": "client closed"}) l.append(key) return comm - assert comm.closed() is False - await comm.close() - assert comm.closed comm = await client_communicate(key=1234, delay=0.5) @@ -177,6 +171,7 @@ def test_ucx_deserialize(): lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), lambda cudf: cudf.DataFrame({"a": [1]}).head(0), lambda cudf: cudf.DataFrame({"a": [1, 2, None], "b": [1.0, 2.0, None]}), + lambda cudf: cudf.DataFrame({"a": ["Check", "str"], "b": ["Sup", "port"]}), ], ) async def test_ping_pong_cudf(g): @@ -264,24 +259,18 @@ async def test_ping_pong_numba(): assert result["op"] == "ping" -@pytest.mark.skip(reason="hangs") @pytest.mark.parametrize("processes", [True, False]) def test_ucx_localcluster(loop, processes): if processes: - kwargs = {"env": {"UCX_MEMTYPE_CACHE": "n"}} - else: - kwargs = {} + pytest.skip("Known bug, processes=True doesn't work currently") - ucx_addr = ucp.get_address() with LocalCluster( protocol="ucx", - interface="ib0", dashboard_address=None, n_workers=2, threads_per_worker=1, processes=processes, loop=loop, - **kwargs ) as cluster: with Client(cluster) as client: x = client.submit(inc, 1) @@ -292,47 +281,16 @@ def test_ucx_localcluster(loop, processes): assert len(cluster.scheduler.workers) == 2 -def test_tcp_localcluster(loop): - ucx_addr = "127.0.0.1" - port = 13337 - env = {"UCX_MEMTYPE_CACHE": "n"} - with LocalCluster( - 2, - scheduler_port=port, - ip=ucx_addr, - processes=True, - threads_per_worker=1, - dashboard_address=None, - silence_logs=False, - env=env, - ) as cluster: - pass - # with Client(cluster) as e: - # x = e.submit(inc, 1) - # x.result() - # assert x.key in c.scheduler.tasks - # assert any(w.data == {x.key: 2} for w in c.workers) - # assert e.loop is c.loop - # print(c.scheduler.workers) - - @pytest.mark.slow @pytest.mark.asyncio async def test_stress(): - from distributed.utils import get_ip_interface - - try: # this check should be removed once UCX + TCP works - get_ip_interface("ib0") - except Exception: - pytest.skip("ib0 interface not found") - import dask.array as da from distributed import wait chunksize = "10 MB" async with LocalCluster( - protocol="ucx", interface="ib0", asynchronous=True + protocol="ucx", dashboard_address=None, asynchronous=True, processes=False ) as cluster: async with Client(cluster, asynchronous=True) as client: rs = da.random.RandomState() diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 8631bb18229..bbf2451e323 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -5,9 +5,7 @@ .. 
_UCX: https://github.com/openucx/ucx """ -import asyncio import logging -import struct from .addressing import parse_host_port, unparse_host_port from .core import Comm, Connector, Listener, CommClosedError @@ -15,7 +13,10 @@ from .utils import ensure_concrete_host, to_frames, from_frames from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors +from tornado.ioloop import IOLoop import ucp +import numpy as np +import numba.cuda import os @@ -95,80 +96,100 @@ async def write( on_error: str = "message", ): with log_errors(): + if self.closed(): + raise CommClosedError("Endpoint is closed -- unable to send message") + if serializers is None: serializers = ("cuda", "dask", "pickle", "error") # msg can also be a list of dicts when sending batched messages frames = await to_frames(msg, serializers=serializers, on_error=on_error) - is_gpus = b"".join( - [ - struct.pack("?", hasattr(frame, "__cuda_array_interface__")) - for frame in frames - ] - ) - sizes = b"".join([struct.pack("Q", nbytes(frame)) for frame in frames]) - - nframes = struct.pack("Q", len(frames)) - - meta = b"".join([nframes, is_gpus, sizes]) - - await self.ep.send_obj(meta) + # Send meta data + await self.ep.send(np.array([len(frames)], dtype=np.uint64)) + await self.ep.send( + np.array( + [hasattr(f, "__cuda_array_interface__") for f in frames], + dtype=np.bool, + ) + ) + await self.ep.send(np.array([nbytes(f) for f in frames], dtype=np.uint64)) + # Send frames for frame in frames: - await self.ep.send_obj(frame) + if nbytes(frame) > 0: + if hasattr(frame, "__array_interface__") or hasattr( + frame, "__cuda_array_interface__" + ): + await self.ep.send(frame) + else: + await self.ep.send(frame) return sum(map(nbytes, frames)) async def read(self, deserializers=("cuda", "dask", "pickle", "error")): with log_errors(): + if self.closed(): + raise CommClosedError("Endpoint is closed -- unable to read message") + if deserializers is None: deserializers = ("cuda", "dask", "pickle", "error") - resp = await self.ep.recv_future() - obj = ucp.get_obj_from_msg(resp) - (nframes,) = struct.unpack( - "Q", obj[:8] - ) # first eight bytes for number of frames - - gpu_frame_msg = obj[ - 8 : 8 + nframes - ] # next nframes bytes for if they're GPU frames - is_gpus = struct.unpack("{}?".format(nframes), gpu_frame_msg) - - sized_frame_msg = obj[8 + nframes :] # then the rest for frame sizes - sizes = struct.unpack("{}Q".format(nframes), sized_frame_msg) - - frames = [] - - for i, (is_gpu, size) in enumerate(zip(is_gpus, sizes)): - if size > 0: - resp = await self.ep.recv_obj(size, cuda=is_gpu) - else: - resp = await self.ep.recv_future() - frame = ucp.get_obj_from_msg(resp) - frames.append(frame) - - msg = await from_frames( - frames, deserialize=self.deserialize, deserializers=deserializers - ) - return msg + try: + # Recv meta data + nframes = np.empty(1, dtype=np.uint64) + await self.ep.recv(nframes) + is_cudas = np.empty(nframes[0], dtype=np.bool) + await self.ep.recv(is_cudas) + sizes = np.empty(nframes[0], dtype=np.uint64) + await self.ep.recv(sizes) + except (ucp.exceptions.UCXCanceled, ucp.exceptions.UCXCloseError): + if self._ep is not None and not self._ep.closed(): + await self._ep.shutdown() + self._ep.close() + self._ep = None + raise CommClosedError("While reading, the connection was canceled") + else: + # Recv frames + frames = [] + for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()): + if size > 0: + if is_cuda: + frame = numba.cuda.device_array((size,), dtype=np.uint8) + else: + frame = np.empty(size, 
dtype=np.uint8) + await self.ep.recv(frame) + if is_cuda: + frames.append(frame) + else: + frames.append(frame.data) + else: + if is_cuda: + frames.append(numba.cuda.device_array((0,), dtype=np.uint8)) + else: + frames.append(b"") + msg = await from_frames( + frames, deserialize=self.deserialize, deserializers=deserializers + ) + return msg + + async def close(self): + if self._ep is not None: + if not self._ep.closed(): + await self._ep.signal_shutdown() + self._ep.close() + self._ep = None def abort(self): - if self._ep: - ucp.destroy_ep(self._ep) + if self._ep is not None: logger.debug("Destroyed UCX endpoint") + IOLoop.current().add_callback(self._ep.signal_shutdown) self._ep = None @property def ep(self): - if self._ep: + if self._ep is not None: return self._ep else: raise CommClosedError("UCX Endpoint is closed") - async def close(self): - # TODO: Handle in-flight messages? - # sleep is currently used to help flush buffer - self.abort() - def closed(self): return self._ep is None @@ -180,9 +201,8 @@ class UCXConnector(Connector): async def connect(self, address: str, deserialize=True, **connection_args) -> UCX: logger.debug("UCXConnector.connect: %s", address) - ucp.init() ip, port = parse_host_port(address) - ep = await ucp.get_endpoint(ip.encode(), port) + ep = await ucp.create_endpoint(ip, port) return self.comm_class( ep, local_addr=None, @@ -206,12 +226,8 @@ def __init__( self.comm_handler = comm_handler self.deserialize = deserialize self._ep = None # type: ucp.Endpoint - self.listener_instance = None # type: ucp.ListenerFuture self.ucp_server = None - self._task = None - self.connection_args = connection_args - self._task = None @property def port(self): @@ -222,39 +238,20 @@ def address(self): return "ucx://" + self.ip + ":" + str(self.port) def start(self): - async def serve_forever(client_ep, listener_instance): + async def serve_forever(client_ep): ucx = UCX( client_ep, local_addr=self.address, peer_addr=self.address, # TODO: https://github.com/Akshay-Venkatesh/ucx-py/issues/111 deserialize=self.deserialize, ) - self.listener_instance = listener_instance if self.comm_handler: await self.comm_handler(ucx) - ucp.init() - self.ucp_server = ucp.start_listener( - serve_forever, listener_port=self._input_port, is_coroutine=True - ) - - try: - loop = asyncio.get_running_loop() - except (RuntimeError, AttributeError): - loop = asyncio.get_event_loop() - - t = loop.create_task(self.ucp_server.coroutine) - self._task = t + self.ucp_server = ucp.create_listener(serve_forever, port=self._input_port) def stop(self): - # What all should this do? - if self._task: - self._task.cancel() - - if self._ep: - ucp.destroy_ep(self._ep) - # if self.listener_instance: - # ucp.stop_listener(self.listener_instance) + self.ucp_server = None def get_host_port(self): # TODO: TCP raises if this hasn't started yet. From 50f11f3cbca6b3b8d42c95dd9431f191a9680c89 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 27 Sep 2019 10:47:37 -0500 Subject: [PATCH 0477/1550] bump version to 2.5.0 --- docs/source/changelog.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 6bed266f32b..a55a6923a29 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,25 @@ Changelog ========= +2.5.0 - 2019-09-27 +------------------ + +- Use the new UCX Python bindings (:pr:`3059`) `Mads R. B. 
Kristensen`_ +- Fix worker preload config (:pr:`3027`) `byjott`_ +- Fix widget with spec that generates multiple workers (:pr:`3067`) `Loïc Estève`_ +- Make Client.get_versions async friendly (:pr:`3064`) `Jacob Tomlinson`_ +- Add configuation option for longer error tracebacks (:pr:`3086`) `Daniel Farrell`_ +- Have Client get Security from passed Cluster (:pr:`3079`) `Matthew Rocklin`_ +- Respect Cluster.dashboard_link in Client._repr_html_ if it exists (:pr:`3077`) `Matthew Rocklin`_ +- Add monitoring with dask cluster docs (:pr:`3072`) `Arpit Solanki`_ +- Protocol of cupy and numba handles serialization exclusively (:pr:`3047`) `Mads R. B. Kristensen`_ +- Allow specification of worker type in SSHCLuster (:pr:`3061`) `Jacob Tomlinson`_ +- Use Cluster.scheduler_info for workers= value in repr (:pr:`3058`) `Matthew Rocklin`_ +- Allow SpecCluster to scale by memory and cores (:pr:`3057`) `Matthew Rocklin`_ +- Allow full script in preload inputs (:pr:`3052`) `Matthew Rocklin`_ +- Check multiple cgroups dirs, ceil fractional cpus (:pr:`3056`) `Jim Crist`_ +- Add blurb about disabling work stealing (:pr:`3055`) `Chris White`_ + 2.4.0 - 2019-09-13 ------------------ @@ -1268,3 +1287,4 @@ significantly without many new features. .. _`byjott`: https://github.com/byjott .. _`Mohammad Noor`: https://github.com/MdSalih .. _`Richard J Zamora`: https://github.com/rjzamora +.. _`Arpit Solanki`: https://github.com/arpit1997 From 1739811c48a73058e0db5eed38e8d46d0864d6d6 Mon Sep 17 00:00:00 2001 From: Daniel Farrell Date: Fri, 27 Sep 2019 10:11:00 -0700 Subject: [PATCH 0478/1550] Set known task durations with configuration (#3085) --- distributed/distributed.yaml | 1 + distributed/scheduler.py | 4 ++++ distributed/tests/test_scheduler.py | 25 ++++++++++++++++++++++++- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 7d012a2f68b..48484be12a6 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -23,6 +23,7 @@ distributed: pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] preload-argv: [] + default-task-durations: {} # How long we expect function names to run ("1h", "1s") (helps for long tasks) dashboard: status: task-stream-length: 1000 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6bf1ba0adb4..bdee7c7bdac 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -964,6 +964,10 @@ def __init__( # Prefix-keyed containers self.task_duration = {prefix: 0.00001 for prefix in fast_tasks} + for k, v in dask.config.get( + "distributed.scheduler.default-task-durations", {} + ).items(): + self.task_duration[k] = parse_timedelta(v) self.unknown_durations = defaultdict(set) # Client state diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 943daffde58..71a19eba54f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -17,7 +17,7 @@ from distributed import Nanny, Worker, Client, wait, fire_and_forget from distributed.core import connect, rpc -from distributed.scheduler import Scheduler +from distributed.scheduler import Scheduler, TaskState from distributed.client import wait from distributed.metrics import time from distributed.protocol.pickle import dumps @@ -1641,3 +1641,26 @@ async def test_retire_names_str(cleanup): await s.retire_workers(names=[0]) assert all(f.done() for f in futures) assert len(b.data) == 10 + + +def test_get_task_duration(): + with 
dask.config.set( + {"distributed.scheduler.default-task-durations": {"prefix_1": 100}} + ): + s = Scheduler(port=0) + assert "prefix_1" in s.task_duration + assert s.task_duration["prefix_1"] == 100 + + ts_pref1 = TaskState("prefix_1-abcdefab", None) + assert s.get_task_duration(ts_pref1) == 100 + + # make sure get_task_duration adds TaskStates to unknown dict + assert len(s.unknown_durations) == 0 + ts_pref2 = TaskState("prefix_2-abcdefab", None) + assert s.get_task_duration(ts_pref2) == 0.5 # default + assert len(s.unknown_durations) == 1 + assert len(s.unknown_durations["prefix_2"]) == 1 + ts_pref2_2 = TaskState("prefix_2-accdefab", None) + assert s.get_task_duration(ts_pref2_2) == 0.5 # default + assert len(s.unknown_durations) == 1 + assert len(s.unknown_durations["prefix_2"]) == 2 From 95a2f4cbec9b26b37a85120841479232afc48432 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 27 Sep 2019 16:47:12 -0500 Subject: [PATCH 0479/1550] Fix tornado typo in asynchronous docs (#3101) --- docs/source/asynchronous.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/asynchronous.rst b/docs/source/asynchronous.rst index 9d38a8b04fe..1ffc8d4f5c3 100644 --- a/docs/source/asynchronous.rst +++ b/docs/source/asynchronous.rst @@ -108,7 +108,7 @@ Python 2/3 with Tornado future = client.submit(lambda x: x + 1, 10) result = yield future yield client.close() - raise gen.Result(result) + raise gen.Return(result) from tornado.ioloop import IOLoop IOLoop().run_sync(f) From 9bea7d78aaf4f830f9ef905fe193908d8da806cf Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 27 Sep 2019 17:21:06 -0500 Subject: [PATCH 0480/1550] Support clusters that don't have .security or ._close methods (#3100) Previously clients checked for attributes that may not be universally present. Now we only check on scheduler, scheduler_comm, close, workers, and other elements defined in the `Cluster` interface. Fixes https://github.com/dask/dask-jobqueue/issues/341 --- distributed/client.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 381516da5c6..75f1fd03d9f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -652,7 +652,7 @@ def __init__( with ignoring(AttributeError): loop = address.loop if security is None: - security = self.cluster.security + security = getattr(self.cluster, "security", None) self.security = security or Security() assert isinstance(self.security, Security) @@ -917,9 +917,7 @@ async def _start(self, timeout=no_default, **kwargs): if self.cluster is not None: # Ensure the cluster is started (no-op if already running) try: - await self.cluster._start() - except AttributeError: # Some clusters don't have this method - pass + await self.cluster except Exception: logger.info( "Tried to start cluster and received an error. 
Proceeding.", @@ -1266,7 +1264,7 @@ async def _close(self, fast=False): self._release_key(key=key) if self._start_arg is None: with ignoring(AttributeError): - await self.cluster._close() + await self.cluster.close() self.rpc.close() self.status = "closed" if _get_global_client() is self: From 8bc04e559d3c4d744b0453cf4aada4eece28c35a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 27 Sep 2019 17:22:21 -0500 Subject: [PATCH 0481/1550] bump version to 2.5.1 --- docs/source/changelog.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index a55a6923a29..aff09d2c3c7 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,13 @@ Changelog ========= + +2.5.1 - 2019-09-27 +------------------ + +- Support clusters that don't have .security or ._close methods (:pr:`3100`) `Matthew Rocklin`_ + + 2.5.0 - 2019-09-27 ------------------ @@ -20,6 +27,7 @@ Changelog - Check multiple cgroups dirs, ceil fractional cpus (:pr:`3056`) `Jim Crist`_ - Add blurb about disabling work stealing (:pr:`3055`) `Chris White`_ + 2.4.0 - 2019-09-13 ------------------ From a76bd8a2db4358aeb02e5f97e053b63ebc280eca Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Mon, 30 Sep 2019 15:48:46 +0100 Subject: [PATCH 0482/1550] Check if self.cluster.scheduler is a local scheduler (#3099) --- distributed/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 75f1fd03d9f..94165b4c380 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -786,10 +786,12 @@ def __repr__(self): return "<%s: not connected>" % (self.__class__.__name__,) def _repr_html_(self): + from .scheduler import Scheduler + if ( self.cluster and hasattr(self.cluster, "scheduler") - and self.cluster.scheduler + and isinstance(self.cluster.scheduler, Scheduler) ): info = self.cluster.scheduler.identity() scheduler = self.cluster.scheduler From 430760b0e45ad8bbaeac40de34a497e554df1389 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Mon, 30 Sep 2019 17:12:02 +0200 Subject: [PATCH 0483/1550] Lower default bokeh log level (#3087) --- distributed/config.py | 2 +- distributed/deploy/spec.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/distributed/config.py b/distributed/config.py index 7e6075125fd..a313f18416b 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -76,7 +76,7 @@ def _initialize_logging_old_style(config): loggers = { # default values "distributed": "info", "distributed.client": "warning", - "bokeh": "critical", + "bokeh": "error", "tornado": "critical", "tornado.application": "error", } diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 87336f96184..5cc0722ef72 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -211,6 +211,9 @@ def __init__( if silence_logs: self._old_logging_level = silence_logging(level=silence_logs) + self._old_bokeh_logging_level = silence_logging( + level=silence_logs, root="bokeh" + ) self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop @@ -350,6 +353,8 @@ async def _close(self): if hasattr(self, "_old_logging_level"): silence_logging(self._old_logging_level) + if hasattr(self, "_old_bokeh_logging_level"): + silence_logging(self._old_bokeh_logging_level, root="bokeh") await super()._close() From 8b4cbe2ba6b4015bf9f7fd46730e3c2bc2b3464e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 30 Sep 2019 
10:12:27 -0500 Subject: [PATCH 0484/1550] Remove utils.py functions for their dask/utils.py equivalents (#3042) --- distributed/client.py | 3 +- distributed/utils.py | 168 ++---------------------------------------- distributed/worker.py | 3 +- requirements.txt | 2 +- 4 files changed, 11 insertions(+), 165 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 94165b4c380..d00f545b161 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -27,7 +27,7 @@ from dask.core import flatten, get_dependencies from dask.optimization import SubgraphCallable from dask.compatibility import apply -from dask.utils import ensure_dict, format_bytes +from dask.utils import ensure_dict, format_bytes, funcname try: from cytoolz import first, groupby, merge, valmap, keymap @@ -69,7 +69,6 @@ from .utils import ( All, sync, - funcname, ignoring, tokey, log_errors, diff --git a/distributed/utils.py b/distributed/utils.py index 015e4dbfb61..fbac950df43 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -10,7 +10,6 @@ import json import logging import multiprocessing -from numbers import Number import os import re import shutil @@ -37,7 +36,14 @@ from dask import istask # provide format_bytes here for backwards compatibility -from dask.utils import format_bytes # noqa +from dask.utils import ( # noqa + format_bytes, + funcname, + format_time, + parse_bytes, + parse_timedelta, +) + import toolz import tornado from tornado import gen @@ -80,16 +86,6 @@ def _initialize_mp_context(): mp_context = _initialize_mp_context() -def funcname(func): - """Get the name of a function.""" - while hasattr(func, "func"): - func = func.func - try: - return func.__name__ - except AttributeError: - return str(func) - - def has_arg(func, argname): """ Whether the function takes an argument with the given name. 
@@ -1086,135 +1082,6 @@ def __reduce__(self): return (itemgetter, (self.index,)) -byte_sizes = { - "kB": 10 ** 3, - "MB": 10 ** 6, - "GB": 10 ** 9, - "TB": 10 ** 12, - "PB": 10 ** 15, - "KiB": 2 ** 10, - "MiB": 2 ** 20, - "GiB": 2 ** 30, - "TiB": 2 ** 40, - "PiB": 2 ** 50, - "B": 1, - "": 1, -} -byte_sizes = {k.lower(): v for k, v in byte_sizes.items()} -byte_sizes.update({k[0]: v for k, v in byte_sizes.items() if k and "i" not in k}) -byte_sizes.update({k[:-1]: v for k, v in byte_sizes.items() if k and "i" in k}) - - -def parse_bytes(s): - """ Parse byte string to numbers - - >>> parse_bytes('100') - 100 - >>> parse_bytes('100 MB') - 100000000 - >>> parse_bytes('100M') - 100000000 - >>> parse_bytes('5kB') - 5000 - >>> parse_bytes('5.4 kB') - 5400 - >>> parse_bytes('1kiB') - 1024 - >>> parse_bytes('1e6') - 1000000 - >>> parse_bytes('1e6 kB') - 1000000000 - >>> parse_bytes('MB') - 1000000 - """ - if isinstance(s, (int, float)): - return int(s) - s = s.replace(" ", "") - if not s[0].isdigit(): - s = "1" + s - - for i in range(len(s) - 1, -1, -1): - if not s[i].isalpha(): - break - index = i + 1 - - prefix = s[:index] - suffix = s[index:] - - n = float(prefix) - - multiplier = byte_sizes[suffix.lower()] - - result = n * multiplier - return int(result) - - -timedelta_sizes = { - "s": 1, - "ms": 1e-3, - "us": 1e-6, - "ns": 1e-9, - "m": 60, - "h": 3600, - "d": 3600 * 24, -} - -tds2 = { - "second": 1, - "minute": 60, - "hour": 60 * 60, - "day": 60 * 60 * 24, - "millisecond": 1e-3, - "microsecond": 1e-6, - "nanosecond": 1e-9, -} -tds2.update({k + "s": v for k, v in tds2.items()}) -timedelta_sizes.update(tds2) -timedelta_sizes.update({k.upper(): v for k, v in timedelta_sizes.items()}) - - -def parse_timedelta(s, default="seconds"): - """ Parse timedelta string to number of seconds - - Examples - -------- - >>> parse_timedelta('3s') - 3 - >>> parse_timedelta('3.5 seconds') - 3.5 - >>> parse_timedelta('300ms') - 0.3 - >>> parse_timedelta(timedelta(seconds=3)) # also supports timedeltas - 3.0 - """ - if s is None: - return None - if isinstance(s, timedelta): - return s.total_seconds() - if isinstance(s, Number): - s = str(s) - s = s.replace(" ", "") - if not s[0].isdigit(): - s = "1" + s - - for i in range(len(s) - 1, -1, -1): - if not s[i].isalpha(): - break - index = i + 1 - - prefix = s[:index] - suffix = s[index:] or default - - n = float(prefix) - - multiplier = timedelta_sizes[suffix.lower()] - - result = n * multiplier - if int(result) == result: - result = int(result) - return result - - def asciitable(columns, rows): """Formats an ascii table for given columns and rows. 
@@ -1282,25 +1149,6 @@ def json_load_robust(fn, load=json.load): sleep(0.1) -def format_time(n): - """ format integers as time - - >>> format_time(1) - '1.00 s' - >>> format_time(0.001234) - '1.23 ms' - >>> format_time(0.00012345) - '123.45 us' - >>> format_time(123.456) - '123.46 s' - """ - if n >= 1: - return "%.2f s" % n - if n >= 1e-3: - return "%.2f ms" % (n * 1e3) - return "%.2f us" % (n * 1e6) - - class DequeHandler(logging.Handler): """ A logging.Handler that records records into a deque """ diff --git a/distributed/worker.py b/distributed/worker.py index bb00158ced8..2bd345ea0c5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -17,7 +17,7 @@ import dask from dask.core import istask from dask.compatibility import apply -from dask.utils import format_bytes +from dask.utils import format_bytes, funcname try: from cytoolz import pluck, partial, merge, first @@ -43,7 +43,6 @@ from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede from .utils import ( get_ip, - funcname, typename, has_arg, _maybe_complex, diff --git a/requirements.txt b/requirements.txt index 804bdfe9637..e599ab0de93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ click >= 6.6 cloudpickle >= 0.2.2 -dask >= 2 +dask >= 2.3 msgpack psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 From 790a4c032128f651c67b671b18a732034346f291 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 30 Sep 2019 11:26:05 -0500 Subject: [PATCH 0485/1550] Add favicon of logo to the dashboard (#3095) * Add favicon to dashboard * Include in package data * Use relative links --- MANIFEST.in | 1 + distributed/dashboard/scheduler_html.py | 26 +++++++++++++----- .../dashboard/static/images/favicon.ico | Bin 0 -> 15406 bytes distributed/dashboard/templates/base.html | 1 + distributed/dashboard/templates/main.html | 1 + 5 files changed, 22 insertions(+), 7 deletions(-) create mode 100755 distributed/dashboard/static/images/favicon.ico diff --git a/MANIFEST.in b/MANIFEST.in index b7a3764c87a..2a8f054e213 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,7 @@ recursive-include distributed *.coffee recursive-include distributed *.html recursive-include distributed *.css recursive-include distributed *.svg +recursive-include distributed *.ico recursive-include distributed *.yaml recursive-include docs *.rst diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py index 1377b037173..6ac13915523 100644 --- a/distributed/dashboard/scheduler_html.py +++ b/distributed/dashboard/scheduler_html.py @@ -12,6 +12,8 @@ func.__name__: func for func in [format_bytes, format_time, datetime.fromtimestamp] } +rel_path_statics = {"rel_path_statics": "../../"} + class Workers(RequestHandler): def get(self): @@ -20,7 +22,7 @@ def get(self): "workers.html", title="Workers", scheduler=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra) + **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) @@ -36,7 +38,7 @@ def get(self, worker): title="Worker: " + worker, scheduler=self.server, Worker=worker, - **toolz.merge(self.server.__dict__, ns, self.extra) + **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) @@ -52,7 +54,7 @@ def get(self, task): title="Task: " + task, Task=task, scheduler=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra) + **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) @@ -60,7 +62,12 @@ class Logs(RequestHandler): def get(self): with log_errors(): logs = 
self.server.get_logs() - self.render("logs.html", title="Logs", logs=logs, **self.extra) + self.render( + "logs.html", + title="Logs", + logs=logs, + **toolz.merge(self.extra, rel_path_statics), + ) class WorkerLogs(RequestHandler): @@ -69,7 +76,12 @@ async def get(self, worker): worker = escape.url_unescape(worker) logs = await self.server.get_worker_logs(workers=[worker]) logs = logs[worker] - self.render("logs.html", title="Logs: " + worker, logs=logs, **self.extra) + self.render( + "logs.html", + title="Logs: " + worker, + logs=logs, + **toolz.merge(self.extra, rel_path_statics), + ) class WorkerCallStacks(RequestHandler): @@ -82,7 +94,7 @@ async def get(self, worker): "call-stack.html", title="Call Stacks: " + worker, call_stack=call_stack, - **self.extra + **toolz.merge(self.extra, rel_path_statics), ) @@ -101,7 +113,7 @@ async def get(self, key): "call-stack.html", title="Call Stack: " + key, call_stack=call_stack, - **self.extra + **toolz.merge(self.extra, rel_path_statics), ) diff --git a/distributed/dashboard/static/images/favicon.ico b/distributed/dashboard/static/images/favicon.ico new file mode 100755 index 0000000000000000000000000000000000000000..eac169d0c7ef3caad9b48ca6c142a5351e9f7dcc GIT binary patch literal 15406 zcmeHO30zZWwh!7)7h7v>-Dd1y9k)_JAc2IiMD~3NBm_u8AV4520xGy06&D1=x}mrr zim1e;&WxRo)5VS(TQ))<5J({VhJe_1-g`4|^3KJgs7OU)-<$WF`}^JB&As>g&N<(A zzwa#nbAm$IpeCcHO+&$T9xCHC6v`EaLd~1^;C{*bC{z_Zb9DUa{UQ|V&I}Z43ABMO z(Br{#xJ>ByB(#2^6Q%wW6vr}YLLGl%G z4jae+GnOKp7mzQRPTDQ6qwZI(MR7P|j%NY2VJeQ!ug50yumP*2=x|PFk}HMx2Mkr< zf&#t~yQS|Y9#;N2^@KibY#C4}ER7%Mo-924;5~P;3#s{AmjuBo52mmT&EmaIt#iKB5YH~PUKsppV9Ay{`9f> z!!sLqrs#r4su*qWecQ`9ys6EZ%tr(7x1+XoU!@&X1*O&)aM>437f=?9%~&+tlLX}+ zEHN2{=DqF|*653;3%((+#p}YhcR42=(R|A|IWR5zf|18KiF~K04;(|7x123*4K9>k zjoK<-8L&>g#5ai$Dc;-9hjDKLR*qrrQH^@tNRXppT>z@ zs~52Jf2NyfosE5 z=jf)*m^j`#m$=4Hkv;V6*01D^4Ti`o!o~B3K&PlY(nZHSb6Nu)=*L&W?jHY0mP$HU%5Y zy8t>+|0$*EBSL25?7%f`xe=RX^5`AiT1u&+0I>@bl>rU9)s@CCbW0TI5SxW(2u+@; zBB^&)>zDp3TZ@C%c94VD$`?8Ki>6>myka-1V8E5yfJ68Y2-jZ=-qh|7{O?fqDEbrj zt2X+V>ZeRp20L=&Y+R!7iW@`l50FC%{BL<@iB*9s+Xh0`b__&vWTGgJjOjvdUW23Z zhp==($Ri#Wn=nq7Go`HOue1Zb=ETG5Tmt9%g!d{sj)!$i(TcHUm|{=GCb+1 zTA5a@yODC*xL_3Z=lbbD7JZDP3+nO7B8eMQwDhq~hz*Wx-U9pwG4cEuES+Bmb5(*( z5ziWFhxirAm8#9ae<-6?w=-zpME8Wn#+Axj4?50^4_RUQ_slCHKYfxmI*FQdn5kbefVq%+?0$sIdn&=By z=Tk?TBL0w5L{q&hnI~}sU4SO@e(RPbI^&bwIvBF9qtbWhPd)_LVrG?w$v&&UmsMvd znW$~HBVT_9_Tg2pv=*&bQrp52{-1i^Tn0Xr9uwVk$^mo%@_ZRd_#5AxR&Lk^=^*UE zjRejEKPdTxVSYxfu_@<*sVVbfzwHQ(=eoi=gf+O%BUN0|GGC*LWw@kghq z#=SELl7~K{SGr(f*!qqeq0Uhw*kVA3wRihfA-b>d;wZ;(E$T0>Me z*EpDW$+(@&mHg^?c_ZyylX#vUDMCHs>yGsM0DtS;mpghd3e$eB)L^hPclsw{;f|Vv4XGJxdY~sXAGwbdB+~+7EMT z`le>p85?pho7-4$KMuR6Ub)4IoL~=jd!@Gw_-40$9kg1qFQl+7C3w9Q zTV$V;b}5tz<-LRSgQ{G|+Yo&6Bk6r`TeC1!C=HOVJ*GuX9;gE<$B z{Tb))PFN3(ArvYor(tUNjt+N9dCxxJuVNlk8`Dl|jz#b8@($bFaU;5@OHC}5ld`$G zboK>tt@@6u~@-#~n&OgU?KXd_;f=UDyXE43f8PU<$nJWDca^>I-}o#nA5@`1Rr z9wx0qKNDgj1?b>g>{=xt6R+kV?NPl2F1Yb3!L1bgGNgwJ{h7>Cv0rJdB?OFykmc~*S=0<50dCFoC6{JMs0N}re zc0f6V=m#s1oFEB{O>9{&HQ}?KuSol3fWH0m2f3ujF}6$*_@a2-%Sue(SYAW3EU&*T zo&T@>5u5}dT2vC8*2PJH;4=>-MjspRd}f#bGcPQY76;bPbBL(N0+zgZNmL`?g|8lS zu-@=Fuh|2B>JroZl}l{%Rfw_junBzh(zyCrsGvs}gZ21XpVHE>+e=(xcwb=Scmsfe z7(9}Ne2+BoUVMrm1Ou_UogH^VKECxO9G~LSaPcgc82(l)rP++9@&6nA`@eyIZSu}; z75d~#4*BJ_272bT&lw%_j!*HIeLE+N_qKBsFAU;uIi4!`FL$P}&of=520VSxZ>8j3 z&>E>Ecx}ff|8?C4xOB6wc z?caxP=+K36WLLvE?O~y75ub7Fyu{D*`k59;+)9S^ikzt7737op?p zJCI{JnB<Wzd3_&{T;yJS%3ldxX}b! 
z7dkHt!KEIIkvrLcL&wKqoQ{=IMKTe&L@^Xs-qV|~Pq{6+ymS7ua($!c=lOnHCOMJ$ z3jlvR2C@TQEBGgz|8QprhCqHZf(Za7WcJT%QwJBcw}r0l_)FM^&f;*6EX+HTZ|h8K zB;w-vosbVT0M^Gqa*Yr9!aRqpmduXaCQBgimj4yxA4=HY+f6;FDxe&KwPY>(4Sc`` zc>>P#jT>F?H+-V-2N+8;q;`vmb*u^k~(@=i>#A~vp6UQgZE`vbi~DNU->X44@ycsSOV{9z0E zVGNy5hcR>jZharH(0bQIo;Q{uoadR<;taS6#eY@%j*vAS*CID{c13ON5=Cw5WML?R zV8|m#aIlB548bY{>mN&$0FGojxl|Dg^Dh8Qdnoy+wuM=#ilV&$zCIRxLt|X zIkRX89{Hqi2e>t0G!P5rFe+5-%ww9t)RVg3CspZO0G|ik&I$*tBQ&+e8DziV#^8UC zPZrjr6PqF4d|Va;pCe=%-$B;8V|cSaV7Tp&%Q9hUycOtJp0jfTFU$_`T+sJV<&kKm z-7{&2m1Tgd8vs}COFgCA$vpnzInO6YJ)iv5d#*Gd2iA`PpCmH5B@0#|wFghO8@{)_ z{-#4nQ#9oAWgx!+d^lK7dU>G#~y%d^0S)IRtE&%rix3^hy)$ z1)DnaVS}H3%y+vz6EOUxX>~avNM@4ePS;cO&m#?rO1F2DLF@ldpkaeAg>g^qt3p zH8iR&gfrHAHY^`-3^{ zi$-G*jDyD6>DBsbnE%0SuBkiok}-%?X&s!vKJ@1Mu5>{$;N}L9qjx#0#dq{Np5<_D zioNfxNaxT-1tzjd19_YRTf(S)4s-36){6ICA^v^n8tFZ-cM4K*cVSrBBkUf@@1&g4 zF*0inJ%H=>uej8=KBZ1I{pIHVL2ty5!O)uGAg=26N)sDB(pxe|+g4jb<7|7MJG-2M z8~U6h8hadx4T&G%AI~?VVI6vBOMHS0+iwuJcHN6DQHUw!ipbG9;go8F1K_ncvg-`@ zAvbU_n>+BaHRK0fFv%sM(Gh$Ip+{QFfMw7~Jq9I-;>3 z)=({)$P0CfxCQkq4`X?oWenl8fI>+!%zrDTv}cgEU-f%>g~n-Q9i<*s&&@caI{>+r zyDKi5MLAp@!CLmr8Eh#&S@2iLJ>K(Vi7w+(_#-)7n?;1%OYD7bA72(&-v#r34`TUy zSPH)l@?e`0y^qufz%YmRi${)ZM%32sGMIl8{ea5IIHD;a?U4+x6Ds6|GR_$3AioOc zUzvBsoW;EI&k5zHp7p&DU&bv-aM3eGG)PDn-|$F=9Noi%_xb|Y>vQe=Z;;U8O=mDM z&7HV7{yp%$wP3fu#H9+UPE0<03$Wc_vqCshJYf0X(e^9vGb=UUFu{f)-vF6w<|&gC z)Lpg1{2TKxo6B?RhOB~fBI_OUPRBw1ey~HW9w}mnhvGl@VS~f~_JK`{ouZm(kcYg8 z#4(_=1|V*b5Hedf;xk$nBmNT-GsNtc=hOE0s-f0Q1-X#TDaQtfzd=%!<|CwTEa#%> zZr|cWuFiDd`SDKo$^Jr&w8(c`~w*^h6d0*pMaD$8K?C#L5H4OalzaN z^7B?)y!S9q{W7CKXj$p51^TYSqeM{?XhU$QP>b!$w}N|6(xP&OA|0QuD6 znm3D$15q?{IA6>Xe+5`XpI=^^KA=Eay*#%KOIX?QVf<&k+03Jwo{Sp3DxGT}BKd%n zW2))twfZv9p~k$+W(AvTMvv_8myv&Qg!EmIe;blJ^<=dSKpfML*wet(?HaI2UEv#~ zjZjBY63Gi07K?Mwqdf?O$o4|LGRY_35E`CV1+MVKxt*fGRjAu%mdf8voQIs>(nU-{&; z3IkTPcSDX<4>{HWhy#XVws)IJ#fmn-_z#fFby&4AIO=H_LpAxGaz=% zKH=v3j>)nah&~}?@)tmk+ApY}JwIf9M|BjZQxH?st${qwT-iTb1wJ>+_*ADJJh81F%GJXNtLRGHH;Zvhy&QDT$(+lkSR{UdI)PPi53-xoNWPxTvCjNIE&r2G zIFleoc^higeoZXxB>_Hi4(b*Z%%ke>Q%`90tZLmAsOMcCyiu9}b;d%-DgO|&Q~qsq zNv9vXMje`c*+fpRh5DA$I!@L(gAV4tH;a4M&0042NomKsNBptqa^-x;^|6wUsJ=}- zt}(Gr=}l1kbewuXy&U9U8&%Y$jM*jE#*`?^s1=IEuzoP9r}ZDFoz@(K`e8HZASwIY z5TaNAcMi~_h;KkU)Qe9(q1g}ecR)V78+^1}$m6+1ZtXY+waVt$QbkKlsXCanPd9@K zK2X{z?U!H!`XRr51KLigzT=e)6vmNoTxZLw(ye5lF$iF8zfL=&3l87b86C5;TSP(X z0?HLt(Yw_PlMbsFLXDb;eMUEsS!XzxUf1`~Pkbe^Juguh15(qF3iaI}|5b=l4pJ+% z&LDpU_;Pn4uig{4M@~zu=m}1()^=so=nUC)`t7XRjxjN$=Vfc%b`~2F_&Mx40}blG zs-OnS52@DyJhUCb6Tu&Qp~rsMeL_WvC%O0#Z+rf5n&# z7@Rw)M1Fo){sX-xkYA8er5l7=qlT1Yh9`Z*mtkwgz7`wMqtmj%wjz3mTr5w4{|u-> zSN$F2|Bsa8n(tW<3sb-swIWmNl7e~;q;?L$f5_W=94TdTKG^HOGY%>L3AIlqsI$z3 z+LuT8wRO2$lbQ{&R6mnarlDask Diagnostic UI + {% block resources %} {% block css_resources %} {{ bokeh_css | indent(8) if bokeh_css }} diff --git a/distributed/dashboard/templates/main.html b/distributed/dashboard/templates/main.html index 8d0d8264d52..88f7453a42a 100644 --- a/distributed/dashboard/templates/main.html +++ b/distributed/dashboard/templates/main.html @@ -3,6 +3,7 @@ {{title}} + From 6fe62774aa7ad585cf2231ca6475f70fdc1cec24 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 1 Oct 2019 00:16:57 +0100 Subject: [PATCH 0486/1550] Retry scheduler connect multiple times (#3104) --- distributed/comm/core.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 1bbc043f52d..256a17de3a5 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -7,7 +7,7 @@ 
from tornado import gen from ..metrics import time -from ..utils import parse_timedelta +from ..utils import parse_timedelta, ignoring from . import registry from .addressing import parse_address @@ -188,6 +188,7 @@ async def connect(addr, timeout=None, deserialize=True, connection_args=None): scheme, loc = parse_address(addr) backend = registry.get_backend(scheme) connector = backend.get_connector() + comm = None start = time() deadline = start + timeout @@ -205,14 +206,19 @@ def _raise(error): # This starts a thread while True: try: - future = connector.connect( - loc, deserialize=deserialize, **(connection_args or {}) - ) - comm = await gen.with_timeout( - timedelta(seconds=deadline - time()), - future, - quiet_exceptions=EnvironmentError, - ) + while deadline - time() > 0: + future = connector.connect( + loc, deserialize=deserialize, **(connection_args or {}) + ) + with ignoring(gen.TimeoutError): + comm = await gen.with_timeout( + timedelta(seconds=min(deadline - time(), 1)), + future, + quiet_exceptions=EnvironmentError, + ) + break + if not comm: + _raise(error) except FatalCommClosedError: raise except EnvironmentError as e: @@ -222,8 +228,6 @@ def _raise(error): logger.debug("sleeping on connect") else: _raise(error) - except gen.TimeoutError: - _raise(error) else: break From ad37905521c8c133707b44a4c7676ee4735e20f0 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 1 Oct 2019 14:35:16 -0400 Subject: [PATCH 0487/1550] Send noise over the wire to keep connection alive (#3105) This change keeps the `dask-ssh` connection alive, preventing ssh timeouts from closing a pipe and killing workers/schedulers without user action. Note that for the `asyncssh` version, this won't be required as `asyncssh` has a specific `keepalive` parameter that can be passed through, but before that is the default, this should ease some pain. --- distributed/deploy/ssh.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 9390d00a2ab..30f6f819224 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -182,11 +182,16 @@ def communicate(): ) return True + # Get transport to current SSH client + transport = ssh.get_transport() + # Wait for a message on the input_queue. Any message received signals this # thread to shut itself down. while cmd_dict["input_queue"].empty(): # Kill some time so that this thread does not hog the CPU. time.sleep(1.0) + # Send noise down the pipe to keep connection active + transport.send_ignore() if communicate(): break From 52d3e057721802ee725d71448b6f7ef17a8e515a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 1 Oct 2019 16:32:45 -0500 Subject: [PATCH 0488/1550] Collect worker-worker and type bandwidth information (#3094) This collects the bandwidth that we observe both by type, and by worker-worker pair. We use this for visual diagnostics, and maybe scheduling decisions in the future. 
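For reference, a rough sketch of how the new measurements could be inspected from a client once they have accumulated (the scheduler address below is a placeholder; `bandwidth_workers` and `bandwidth_types` are the scheduler attributes introduced in this patch):

    from distributed import Client

    client = Client("tcp://127.0.0.1:8786")  # placeholder scheduler address

    def peek_bandwidth(dask_scheduler):
        # dask_scheduler is injected by Client.run_on_scheduler.
        # bandwidth_workers is keyed by (source, destination) worker addresses,
        # bandwidth_types by type name; values are smoothed bytes/second estimates.
        return {
            "total": dask_scheduler.bandwidth,
            "workers": {
                "%s -> %s" % (a, b): bw
                for (a, b), bw in dask_scheduler.bandwidth_workers.items()
            },
            "types": dict(dask_scheduler.bandwidth_types),
        }

    print(client.run_on_scheduler(peek_bandwidth))
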
--- distributed/dashboard/nvml.py | 11 +- distributed/dashboard/scheduler.py | 183 +++++++++++++++++- .../dashboard/tests/test_scheduler_bokeh.py | 13 +- distributed/scheduler.py | 30 ++- distributed/tests/test_scheduler.py | 34 +++- distributed/worker.py | 27 ++- 6 files changed, 277 insertions(+), 21 deletions(-) diff --git a/distributed/dashboard/nvml.py b/distributed/dashboard/nvml.py index 7fd628dd469..131a02a8397 100644 --- a/distributed/dashboard/nvml.py +++ b/distributed/dashboard/nvml.py @@ -183,5 +183,12 @@ def gpu_utilization_doc(scheduler, extra, doc): doc.theme = BOKEH_THEME -applications["/individual-gpu-memory"] = gpu_memory_doc -applications["/individual-gpu-utilization"] = gpu_utilization_doc +try: + import pynvml + + pynvml.nvmlInit() +except Exception: + pass +else: + applications["/individual-gpu-memory"] = gpu_memory_doc + applications["/individual-gpu-utilization"] = gpu_utilization_doc diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 7f172c879d1..27a49b4fd68 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -8,6 +8,7 @@ from bokeh.layouts import column, row from bokeh.models import ( ColumnDataSource, + ColorBar, DataRange1d, HoverTool, ResetTool, @@ -32,7 +33,7 @@ from bokeh.plotting import figure from bokeh.palettes import Viridis11 from bokeh.themes import Theme -from bokeh.transform import factor_cmap +from bokeh.transform import factor_cmap, linear_cmap from bokeh.io import curdoc import dask from dask.utils import format_bytes @@ -295,6 +296,166 @@ def update(self): self.root.title.text = "Bytes stored (Histogram): " + format_bytes(nbytes.sum()) +class BandwidthTypes(DashboardComponent): + """ Bar chart showing bandwidth per type """ + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "bandwidth": [1, 2], + "bandwidth-half": [0.5, 1], + "type": ["a", "b"], + "bandwidth_text": ["1", "2"], + } + ) + + fig = figure( + title="Bandwidth by Type", + tools="", + id="bk-bandwidth-type-plot", + name="bandwidth_type_histogram", + y_range=["a", "b"], + **kwargs + ) + rect = fig.rect( + source=self.source, + x="bandwidth-half", + y="type", + width="bandwidth", + height=1, + color="blue", + ) + fig.x_range.start = 0 + fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + rect.nonselection_glyph = None + + fig.xaxis.minor_tick_line_alpha = 0 + fig.ygrid.visible = False + + fig.toolbar.logo = None + fig.toolbar_location = None + + hover = HoverTool() + hover.tooltips = "@type: @bandwidth_text / s" + hover.point_policy = "follow_mouse" + fig.add_tools(hover) + + self.fig = fig + + @without_property_validation + def update(self): + with log_errors(): + bw = self.scheduler.bandwidth_types + self.fig.y_range.factors = list(sorted(bw)) + result = { + "bandwidth": list(bw.values()), + "bandwidth-half": [b / 2 for b in bw.values()], + "type": list(bw.keys()), + "bandwidth_text": list(map(format_bytes, bw.values())), + } + self.fig.title.text = "Bandwidth: " + format_bytes(self.scheduler.bandwidth) + + update(self.source, result) + + +class BandwidthWorkers(DashboardComponent): + """ How many tasks are on each worker """ + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "bandwidth": [1, 2], + "source": ["a", "b"], + "destination": ["a", "b"], + "bandwidth_text": ["1", "2"], + } + ) + + values = 
[hex(x)[2:] for x in range(64, 256)][::-1] + mapper = linear_cmap( + field_name="bandwidth", + palette=["#" + x + x + "FF" for x in values], + low=0, + high=1, + ) + + fig = figure( + title="Bandwidth by Worker", + tools="", + id="bk-bandwidth-worker-plot", + name="bandwidth_worker_heatmap", + x_range=["a", "b"], + y_range=["a", "b"], + **kwargs + ) + fig.xaxis.major_label_orientation = -math.pi / 12 + rect = fig.rect( + source=self.source, + x="source", + y="destination", + color=mapper, + height=1, + width=1, + ) + + self.color_map = mapper["transform"] + color_bar = ColorBar( + color_mapper=self.color_map, + label_standoff=12, + border_line_color=None, + location=(0, 0), + ) + color_bar.formatter = NumeralTickFormatter(format="0 b") + fig.add_layout(color_bar, "right") + + fig.toolbar.logo = None + fig.toolbar_location = None + + hover = HoverTool() + hover.tooltips = """ +
+            <div>
+                <p><b>Source:</b> @source </p>
+                <p><b>Destination:</b> @destination </p>
+                <p><b>Bandwidth:</b> @bandwidth_text / s </p>
+            </div>
          + """ + hover.point_policy = "follow_mouse" + fig.add_tools(hover) + + self.fig = fig + + @without_property_validation + def update(self): + with log_errors(): + bw = self.scheduler.bandwidth_workers + if not bw: + return + x, y, value = zip(*[(a, b, c) for (a, b), c in bw.items()]) + + if self.color_map.high < max(value): + self.color_map.high = max(value) + + factors = list(sorted(set(x + y))) + self.fig.x_range.factors = factors + self.fig.y_range.factors = factors + + result = { + "source": x, + "destination": y, + "bandwidth": value, + "bandwidth_text": list(map(format_bytes, value)), + } + self.fig.title.text = "Bandwidth: " + format_bytes(self.scheduler.bandwidth) + + update(self.source, result) + + class CurrentLoad(DashboardComponent): """ How many tasks are on each worker """ @@ -1596,6 +1757,24 @@ def individual_workers_doc(scheduler, extra, doc): doc.theme = BOKEH_THEME +def individual_bandwidth_types(scheduler, extra, doc): + with log_errors(): + bw = BandwidthTypes(scheduler, sizing_mode="stretch_both") + bw.update() + add_periodic_callback(doc, bw, 500) + doc.add_root(bw.fig) + doc.theme = BOKEH_THEME + + +def individual_bandwidth_workers(scheduler, extra, doc): + with log_errors(): + bw = BandwidthWorkers(scheduler, sizing_mode="stretch_both") + bw.update() + add_periodic_callback(doc, bw, 500) + doc.add_root(bw.fig) + doc.theme = BOKEH_THEME + + def profile_doc(scheduler, extra, doc): with log_errors(): doc.title = "Dask: Profile" @@ -1703,6 +1882,8 @@ def listen(self, *args, **kwargs): "/individual-cpu": individual_cpu_doc, "/individual-nprocessing": individual_nprocessing_doc, "/individual-workers": individual_workers_doc, + "/individual-bandwidth-types": individual_bandwidth_types, + "/individual-bandwidth-workers": individual_bandwidth_workers, } try: diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 2dc29572ea8..e68d7935583 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -664,6 +664,11 @@ def test_https_support(c, s, a, b): ctx.load_verify_locations(get_cert("tls-ca-cert.pem")) http_client = AsyncHTTPClient() + response = yield http_client.fetch( + "https://localhost:%d/individual-plots.json" % port, ssl_options=ctx + ) + response = json.loads(response.body.decode()) + for suffix in [ "system", "counters", @@ -672,13 +677,7 @@ def test_https_support(c, s, a, b): "tasks", "stealing", "graph", - "individual-task-stream", - "individual-progress", - "individual-graph", - "individual-nbytes", - "individual-nprocessing", - "individual-profile", - ]: + ] + [url.strip("/") for url in response.values()]: req = HTTPRequest( url="https://localhost:%d/%s" % (port, suffix), ssl_options=ctx ) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bdee7c7bdac..f3d3fc92ea8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -210,6 +210,7 @@ class WorkerState(object): __slots__ = ( "actors", "address", + "bandwidth", "extra", "has_what", "last_seen", @@ -257,6 +258,7 @@ def __init__( self.metrics = {} self.last_seen = 0 self.time_delay = 0 + self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) self.actors = set() self.has_what = set() @@ -881,6 +883,8 @@ def __init__( self.idle_timeout = None self.time_started = time() self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) + self.bandwidth_workers = defaultdict(float) + 
self.bandwidth_types = defaultdict(float) if not preload: preload = dask.config.get("distributed.scheduler.preload") @@ -1346,9 +1350,27 @@ def heartbeat_worker( host_info = host_info or {} self.host_info[host]["last-seen"] = local_now - frac = 1 / 20 / len(self.workers) + frac = 1 / len(self.workers) try: - self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"] * frac + self.bandwidth = ( + self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + ) + for other, (bw, count) in metrics["bandwidth"]["workers"].items(): + if (address, other) not in self.bandwidth_workers: + self.bandwidth_workers[address, other] = bw / count + else: + alpha = (1 - frac) ** count + self.bandwidth_workers[address, other] = self.bandwidth_workers[ + address, other + ] * alpha + bw * (1 - alpha) + for typ, (bw, count) in metrics["bandwidth"]["types"].items(): + if typ not in self.bandwidth_types: + self.bandwidth_types[typ] = bw / count + else: + alpha = (1 - frac) ** count + self.bandwidth_types[typ] = self.bandwidth_types[ + typ + ] * alpha + bw * (1 - alpha) except KeyError: pass @@ -1948,6 +1970,10 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): if not self.workers: logger.info("Lost all workers") + for w in self.workers: + self.bandwidth_workers.pop((address, w), None) + self.bandwidth_workers.pop((w, address), None) + def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events if address not in self.workers and address in self.events: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 71a19eba54f..4b56a4d084f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -22,7 +22,7 @@ from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.worker import dumps_function, dumps_task -from distributed.utils import tmpfile +from distributed.utils import tmpfile, typename from distributed.utils_test import ( # noqa: F401 captured_logger, cleanup, @@ -1515,14 +1515,38 @@ def test_idle_timeout(c, s, a, b): @gen_cluster(client=True, config={"distributed.scheduler.bandwidth": "100 GB"}) -def test_bandwidth(c, s, a, b): +async def test_bandwidth(c, s, a, b): start = s.bandwidth - x = c.submit(operator.mul, b"0", 20000, workers=a.address) + x = c.submit(operator.mul, b"0", 1000000, workers=a.address) y = c.submit(lambda x: x, x, workers=b.address) - yield y - yield b.heartbeat() + await y + await b.heartbeat() assert s.bandwidth < start # we've learned that we're slower assert b.latency + assert typename(bytes) in s.bandwidth_types + assert (b.address, a.address) in s.bandwidth_workers + + await a.close() + assert not s.bandwidth_workers + + +@gen_cluster(client=True, Worker=Nanny) +async def test_bandwidth_clear(c, s, a, b): + np = pytest.importorskip("numpy") + x = c.submit(np.arange, 1000000, workers=[a.worker_address], pure=False) + y = c.submit(np.arange, 1000000, workers=[b.worker_address], pure=False) + z = c.submit(operator.add, x, y) # force communication + await z + + async def f(dask_worker): + await dask_worker.heartbeat() + + await c.run(f) + + assert s.bandwidth_workers + + await s.restart() + assert not s.bandwidth_workers @gen_cluster() diff --git a/distributed/worker.py b/distributed/worker.py index 2bd345ea0c5..fba4eed57b3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -20,9 +20,9 @@ from dask.utils import format_bytes, funcname try: - from cytoolz import pluck, 
partial, merge, first + from cytoolz import pluck, partial, merge, first, keymap except ImportError: - from toolz import pluck, partial, merge, first + from toolz import pluck, partial, merge, first, keymap from tornado import gen from tornado.ioloop import IOLoop @@ -416,6 +416,10 @@ def __init__( self.outgoing_current_count = 0 self.repetitively_busy = 0 self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) + self.bandwidth_workers = defaultdict( + lambda: (0, 0) + ) # bw/count recent transfers + self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers self.latency = 0.001 self._client = None @@ -728,7 +732,11 @@ async def get_metrics(self): in_memory=len(self.data), ready=len(self.ready), in_flight=len(self.in_flight_tasks), - bandwidth=self.bandwidth, + bandwidth={ + "total": self.bandwidth, + "workers": dict(self.bandwidth_workers), + "types": keymap(typename, self.bandwidth_types), + }, ) custom = {} for k, metric in self.metrics.items(): @@ -881,6 +889,8 @@ async def heartbeat(self): self.periodic_callbacks["heartbeat"].callback_time = ( response["heartbeat-interval"] * 1000 ) + self.bandwidth_workers.clear() + self.bandwidth_types.clear() except CommClosedError: logger.warning("Heartbeat to scheduler failed") finally: @@ -1920,8 +1930,17 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): "who": worker, } ) - if total_bytes > 10000: + if total_bytes > 1000000: self.bandwidth = self.bandwidth * 0.95 + bandwidth * 0.05 + bw, cnt = self.bandwidth_workers[worker] + self.bandwidth_workers[worker] = (bw + bandwidth, cnt + 1) + + types = set(map(type, response["data"].values())) + if len(types) == 1: + [typ] = types + bw, cnt = self.bandwidth_types[typ] + self.bandwidth_types[typ] = (bw + bandwidth, cnt + 1) + if self.digests is not None: self.digests["transfer-bandwidth"].add(total_bytes / duration) self.digests["transfer-duration"].add(duration) From 00ac6f0dbae52aa085834b2e4f3a513e02463008 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 1 Oct 2019 17:08:47 -0500 Subject: [PATCH 0489/1550] Add Client.shutdown method (#3106) This lets the client shut down the scheduler and workers Asked for in https://stackoverflow.com/questions/50919227/is-it-possible-to-shutdown-a-dask-distributed-cluster-given-a-client-instance --- distributed/client.py | 23 +++++++++++++++++------ distributed/tests/test_client.py | 20 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index d00f545b161..f918fffbf78 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1335,14 +1335,25 @@ async def _(): if self._should_close_loop and not shutting_down(): self._loop_runner.stop() - def shutdown(self, *args, **kwargs): - """ Deprecated, see close instead + async def _shutdown(self): + logger.info("Shutting down scheduler from Client") + if self.cluster: + await self.cluster.close() + else: + with ignoring(CommClosedError): + await self.scheduler.terminate(close_workers=True) + + def shutdown(self): + """ Shut down the connected scheduler and workers + + Note, this may disrupt other clients that may be using the same + scheudler and workers. - This was deprecated because "shutdown" was sometimes confusingly - thought to refer to the cluster rather than the client + See also + -------- + Client.close: close only this client """ - warnings.warn("Shutdown is deprecated. 
Please use close instead") - return self.close(*args, **kwargs) + return self.sync(self._shutdown) def get_executor(self, **kwargs): """ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index a95548689d1..d01088502f0 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5636,5 +5636,25 @@ def dashboard_link(self): assert "http://foo.com" in client._repr_html_() +@pytest.mark.asyncio +async def test_shutdown(cleanup): + async with Scheduler(port=0) as s: + async with Worker(s.address) as w: + async with Client(s.address, asynchronous=True) as c: + await c.shutdown() + + assert s.status == "closed" + assert w.status == "closed" + + +@pytest.mark.asyncio +async def test_shutdown_localcluster(cleanup): + async with LocalCluster(n_workers=1, asynchronous=True, processes=False) as lc: + async with Client(lc, asynchronous=True) as c: + await c.shutdown() + + assert lc.scheduler.status == "closed" + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 From 78d8cad9e8934a3b863999832b7c46dfe8162683 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 3 Oct 2019 09:36:37 -0500 Subject: [PATCH 0490/1550] Identify lost workers in SpecCluster based on address not name (#3088) Fixes #3062 --- distributed/deploy/spec.py | 9 ++++++++- distributed/deploy/tests/test_spec_cluster.py | 20 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 5cc0722ef72..2dc910de2cc 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -309,7 +309,14 @@ def _update_worker_status(self, op, msg): name = self.scheduler_info["workers"][msg]["name"] def f(): - if name in self.workers and msg not in self.scheduler_info: + if ( + name in self.workers + and msg not in self.scheduler_info["workers"] + and not any( + d["name"] == name + for d in self.scheduler_info["workers"].values() + ) + ): self._futures.add(asyncio.ensure_future(self.workers[name].close())) del self.workers[name] diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 485ae1989ea..e0ea735f8ce 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -157,6 +157,26 @@ async def test_unexpected_closed_worker(cleanup): assert len(cluster.workers) == 2 +@pytest.mark.slow +@pytest.mark.asyncio +async def test_restart(cleanup): + # Regression test for https://github.com/dask/distributed/issues/3062 + worker = {"cls": Nanny, "options": {"nthreads": 1}} + with dask.config.set({"distributed.deploy.lost-worker-timeout": "2s"}): + async with SpecCluster( + asynchronous=True, scheduler=scheduler, worker=worker + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + cluster.scale(2) + await cluster + assert len(cluster.workers) == 2 + + await client.restart() + await asyncio.sleep(3) + + assert len(cluster.workers) == 2 + + @pytest.mark.asyncio async def test_broken_worker(): with pytest.raises(Exception) as info: From 4f11e7c6ecc3422b26d31124b8300ac4fbb5d1d2 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Fri, 4 Oct 2019 05:16:48 -0500 Subject: [PATCH 0491/1550] Support calling `cluster.scale` as async method (#3110) This adds optional support for calling `cluster.scale` as an async method (i.e. `await cluster.scale(...)`). 
This is currently optional and backwards compatible - perhaps in the future we may want to deprecate calling in the non-async context. --- distributed/deploy/spec.py | 25 ++++++++++++++----- distributed/deploy/tests/test_spec_cluster.py | 5 ++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 2dc910de2cc..11ff3b44322 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -77,6 +77,19 @@ async def __aexit__(self, *args, **kwargs): await self.close() +class NoOpAwaitable(object): + """An awaitable object that always returns None. + + Useful to return from a method that can be called in both asynchronous and + synchronous contexts""" + + def __await__(self): + async def f(): + return None + + return f().__await__() + + class SpecCluster(Cluster): """ Cluster that requires a full specification of workers @@ -418,15 +431,15 @@ def scale(self, n=0, memory=None, cores=None): while len(self.worker_spec) > n: self.worker_spec.popitem() - if self.status in ("closing", "closed"): - self.loop.add_callback(self._correct_state) - return - - while len(self.worker_spec) < n: - self.worker_spec.update(self.new_worker_spec()) + if self.status not in ("closing", "closed"): + while len(self.worker_spec) < n: + self.worker_spec.update(self.new_worker_spec()) self.loop.add_callback(self._correct_state) + if self.asynchronous: + return NoOpAwaitable() + def new_worker_spec(self): """ Return name and spec for the next worker diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index e0ea735f8ce..1c8a01e98ce 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -125,6 +125,11 @@ async def test_scale(cleanup): await cluster assert len(cluster.workers) == 1 + # Can use with await + await cluster.scale(2) + await cluster + assert len(cluster.workers) == 2 + @pytest.mark.asyncio async def test_unexpected_closed_worker(cleanup): From f1c7bfdb7f6041d566b56ca1ffb439e8e09b39fb Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 4 Oct 2019 12:53:28 +0200 Subject: [PATCH 0492/1550] UCX: trying to allocate CUDA arrays using RMM and Numba (#3109) * Removed duplicate code * Added tcp to the default UCX_TLS * Trying to allocate CUDA arrays using RMM and Numba --- distributed/comm/ucx.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index bbf2451e323..77f65c661e8 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -16,13 +16,12 @@ from tornado.ioloop import IOLoop import ucp import numpy as np -import numba.cuda import os os.environ.setdefault("UCX_RNDV_SCHEME", "put_zcopy") os.environ.setdefault("UCX_MEMTYPE_CACHE", "n") -os.environ.setdefault("UCX_TLS", "rc,cuda_copy,cuda_ipc") +os.environ.setdefault("UCX_TLS", "tcp,rc,cuda_copy,cuda_ipc") logger = logging.getLogger(__name__) MAX_MSG_LOG = 23 @@ -32,6 +31,23 @@ # Comm Interface # ---------------------------------------------------------------------------- +# Let's find the function, `cuda_array`, to use when allocating new CUDA arrays +try: + import rmm + + cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) +except ImportError: + try: + import numba.cuda + + cuda_array = lambda n: numba.cuda.device_array((n,), dtype=np.uint8) + except ImportError: + + def cuda_array(n): + raise RuntimeError( + "In order to send/recv CUDA arrays, Numba or RMM is required" + ) + class UCX(Comm): """Comm object using UCP. @@ -116,12 +132,7 @@ async def write( # Send frames for frame in frames: if nbytes(frame) > 0: - if hasattr(frame, "__array_interface__") or hasattr( - frame, "__cuda_array_interface__" - ): - await self.ep.send(frame) - else: - await self.ep.send(frame) + await self.ep.send(frame) return sum(map(nbytes, frames)) async def read(self, deserializers=("cuda", "dask", "pickle", "error")): @@ -152,17 +163,14 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()): if size > 0: if is_cuda: - frame = numba.cuda.device_array((size,), dtype=np.uint8) + frame = cuda_array(size) else: frame = np.empty(size, dtype=np.uint8) await self.ep.recv(frame) - if is_cuda: - frames.append(frame) - else: - frames.append(frame.data) + frames.append(frame) else: if is_cuda: - frames.append(numba.cuda.device_array((0,), dtype=np.uint8)) + frames.append(cuda_array(size)) else: frames.append(b"") msg = await from_frames( From 670f2e84a7607b685cbc2f8fbdbfb29aa8267c84 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 4 Oct 2019 09:33:03 -0500 Subject: [PATCH 0493/1550] Replace use of tornado.gen with asyncio in dask-worker (#3114) --- distributed/cli/dask_worker.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 35ca11da34c..31c17fed0d0 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -1,3 +1,4 @@ +import asyncio import atexit import logging import gc @@ -21,7 +22,6 @@ from toolz import valmap from tornado.ioloop import IOLoop, TimeoutError -from tornado import gen logger = logging.getLogger("distributed.dask_worker") @@ -379,20 +379,18 @@ def del_pid_file(): for i in range(nprocs) ] - @gen.coroutine - def close_all(): + async def close_all(): # Unregister all workers from scheduler if nanny: - yield [n.close(timeout=2) for n in nannies] + await asyncio.gather(*[n.close(timeout=2) for n in nannies]) def 
on_signal(signum): logger.info("Exiting on signal %d", signum) - close_all() + asyncio.ensure_future(close_all()) - @gen.coroutine - def run(): - yield nannies - yield [n.finished() for n in nannies] + async def run(): + await asyncio.gather(*nannies) + await asyncio.gather(*[n.finished() for n in nannies]) install_signal_handlers(loop, cleanup=on_signal) From 1630f4de2579ed0d99e56df5ae778cb35cceab77 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 4 Oct 2019 14:23:24 -0500 Subject: [PATCH 0494/1550] Make dask-worker close quietly when given sigint signal (#3116) --- distributed/cli/dask_worker.py | 26 +++++++++++++++-------- distributed/cli/tests/test_dask_worker.py | 1 - distributed/nanny.py | 2 +- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 31c17fed0d0..fb32fc2e882 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -3,7 +3,8 @@ import logging import gc import os -from sys import exit +import signal +import sys import warnings import click @@ -267,34 +268,34 @@ def main( logger.error( "Failed to launch worker. You cannot use the --port argument when nprocs > 1." ) - exit(1) + sys.exit(1) if nprocs > 1 and not nanny: logger.error( "Failed to launch worker. You cannot use the --no-nanny argument when nprocs > 1." ) - exit(1) + sys.exit(1) if contact_address and not listen_address: logger.error( "Failed to launch worker. " "Must specify --listen-address when --contact-address is given" ) - exit(1) + sys.exit(1) if nprocs > 1 and listen_address: logger.error( "Failed to launch worker. " "You cannot specify --listen-address when nprocs > 1." ) - exit(1) + sys.exit(1) if (worker_port or host) and listen_address: logger.error( "Failed to launch worker. " "You cannot specify --listen-address when --worker-port or --host is given." ) - exit(1) + sys.exit(1) try: if listen_address: @@ -308,7 +309,7 @@ def main( contact_address = listen_address except ValueError as e: logger.error("Failed to launch worker. " + str(e)) - exit(1) + sys.exit(1) if nanny: port = nanny_port @@ -384,8 +385,13 @@ async def close_all(): if nanny: await asyncio.gather(*[n.close(timeout=2) for n in nannies]) + signal_fired = False + def on_signal(signum): - logger.info("Exiting on signal %d", signum) + nonlocal signal_fired + signal_fired = True + if signum != signal.SIGINT: + logger.info("Exiting on signal %d", signum) asyncio.ensure_future(close_all()) async def run(): @@ -398,7 +404,9 @@ async def run(): loop.run_sync(run) except TimeoutError: # We already log the exception in nanny / worker. Don't do it again. 
- raise TimeoutError("Timed out starting worker.") from None + if not signal_fired: + logger.info("Timed out starting worker") + sys.exit(1) except KeyboardInterrupt: pass finally: diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 01327d64291..c509772d113 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -311,7 +311,6 @@ def test_worker_timeout(no_nanny): args.append("--no-nanny") result = runner.invoke(distributed.cli.dask_worker.main, args) assert result.exit_code != 0 - assert str(result.exception).startswith("Timed out") def test_bokeh_deprecation(): diff --git a/distributed/nanny.py b/distributed/nanny.py index 06b5a27dc79..8fbbf761368 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -314,7 +314,7 @@ async def instantiate(self, comm=None): ) except gen.TimeoutError: await self.close(timeout=self.death_timeout) - logger.exception( + logger.error( "Timed out connecting Nanny '%s' to scheduler '%s'", self, self.scheduler_addr, From 8d300008e1c93525309198cc2e09405f6e2a8c04 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 4 Oct 2019 15:25:30 -0500 Subject: [PATCH 0495/1550] bump version to 2.5.2 --- docs/source/changelog.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index aff09d2c3c7..5b6288885fc 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,23 @@ Changelog ========= +2.5.2 - 2019-10-04 +------------------ + +- Make dask-worker close quietly when given sigint signal (:pr:`3116`) `Matthew Rocklin`_ +- Replace use of tornado.gen with asyncio in dask-worker (:pr:`3114`) `Matthew Rocklin`_ +- UCX: allocate CUDA arrays using RMM and Numba (:pr:`3109`) `Mads R. B. Kristensen`_ +- Support calling `cluster.scale` as async method (:pr:`3110`) `Jim Crist`_ +- Identify lost workers in SpecCluster based on address not name (:pr:`3088`) `James Bourbeau`_ +- Add Client.shutdown method (:pr:`3106`) `Matthew Rocklin`_ +- Collect worker-worker and type bandwidth information (:pr:`3094`) `Matthew Rocklin`_ +- Send noise over the wire to keep dask-ssh connection alive (:pr:`3105`) `Gil Forsyth`_ +- Retry scheduler connect multiple times (:pr:`3104`) `Jacob Tomlinson`_ +- Add favicon of logo to the dashboard (:pr:`3095`) `James Bourbeau`_ +- Remove utils.py functions for their dask/utils.py equivalents (:pr:`3042`) `Matthew Rocklin`_ +- Lower default bokeh log level (:pr:`3087`) `Philipp Rudiger`_ +- Check if self.cluster.scheduler is a local scheduler (:pr:`3099`) `Jacob Tomlinson`_ + 2.5.1 - 2019-09-27 ------------------ @@ -1296,3 +1313,5 @@ significantly without many new features. .. _`Mohammad Noor`: https://github.com/MdSalih .. _`Richard J Zamora`: https://github.com/rjzamora .. _`Arpit Solanki`: https://github.com/arpit1997 +.. _`Gil Forsyth`: https://github.com/gforsyth +.. _`Philipp Rudiger`: https://github.com/philippjfr From 8dd912c1e71105f714292e2ecabb14804837fbe9 Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Tue, 8 Oct 2019 11:53:48 -0500 Subject: [PATCH 0496/1550] Bump dask dependency (#3124) Due to moving some things around, distributed 2.5 now relies on dask 2.5. We bump the required version to account for this. 
--- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e599ab0de93..d1335d0b3b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ click >= 6.6 cloudpickle >= 0.2.2 -dask >= 2.3 +dask >= 2.5.2 msgpack psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 From dd0d60b9505952951736109f0b6e2aadea36744e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 9 Oct 2019 08:24:16 -0500 Subject: [PATCH 0497/1550] Move new SSHCluster to top level (#3128) We have two SSHCluster implementations, one that has been around for a long while, and is still used in the dask-ssh CLI, and a newer implementation that is based off of SpecCluster. The old solution is battle hardened, and probably nicer for the CLI, but not very modifyable. The new solution is much simpler, and defers most of the configuration to other systems, notably asyncssh and the worker/scheduler classes. This PR keeps both, but moves the new system to the main import, while gracefully deferring to the old solution if the user seems to be using those options (it is very obvious which the user intended to use). We also move SSHCluster to a top-level import so users can do the following: from dask.distributed import SSHCluster, Client --- continuous_integration/travis/install.sh | 2 +- distributed/__init__.py | 2 +- distributed/cli/dask_ssh.py | 2 +- distributed/deploy/__init__.py | 1 + distributed/deploy/old_ssh.py | 472 +++++++++++++++ distributed/deploy/ssh.py | 724 +++++++++-------------- distributed/deploy/ssh2.py | 236 -------- distributed/deploy/tests/test_old_ssh.py | 31 + distributed/deploy/tests/test_ssh.py | 72 ++- distributed/deploy/tests/test_ssh2.py | 43 -- 10 files changed, 825 insertions(+), 760 deletions(-) create mode 100644 distributed/deploy/old_ssh.py delete mode 100644 distributed/deploy/ssh2.py create mode 100644 distributed/deploy/tests/test_old_ssh.py delete mode 100644 distributed/deploy/tests/test_ssh2.py diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index f0c4a07be67..8c34f38d276 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -55,7 +55,7 @@ conda install -q \ # For low-level profiler, install libunwind and stacktrace from conda-forge # For stacktrace we use --no-deps to avoid upgrade of python -conda install -c defaults -c conda-forge libunwind zstandard +conda install -c defaults -c conda-forge libunwind zstandard asyncssh conda install --no-deps -c defaults -c numba -c conda-forge stacktrace pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio diff --git a/distributed/__init__.py b/distributed/__init__.py index d79993dfef7..07015ff44af 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -2,7 +2,7 @@ from dask.config import config from .actor import Actor, ActorFuture from .core import connect, rpc -from .deploy import LocalCluster, Adaptive, SpecCluster +from .deploy import LocalCluster, Adaptive, SpecCluster, SSHCluster from .diagnostics.progressbar import progress from .client import ( Client, diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index 97cf91f3519..07cbb57bf01 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -1,4 +1,4 @@ -from distributed.deploy.ssh import SSHCluster +from distributed.deploy.old_ssh import SSHCluster import click from distributed.cli.utils import check_python_3 diff --git 
a/distributed/deploy/__init__.py b/distributed/deploy/__init__.py index 5a5a9106005..35fc86fe393 100644 --- a/distributed/deploy/__init__.py +++ b/distributed/deploy/__init__.py @@ -2,6 +2,7 @@ from .cluster import Cluster from .local import LocalCluster +from .ssh import SSHCluster from .spec import SpecCluster, ProcessInterface from .adaptive import Adaptive diff --git a/distributed/deploy/old_ssh.py b/distributed/deploy/old_ssh.py new file mode 100644 index 00000000000..30f6f819224 --- /dev/null +++ b/distributed/deploy/old_ssh.py @@ -0,0 +1,472 @@ +import logging +import socket +import os +import sys +import time +import traceback + +try: + from queue import Queue +except ImportError: # Python 2.7 fix + from Queue import Queue + +from threading import Thread + +from toolz import merge + +from tornado import gen + + +logger = logging.getLogger(__name__) + + +# These are handy for creating colorful terminal output to enhance readability +# of the output generated by dask-ssh. +class bcolors: + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + + +def async_ssh(cmd_dict): + import paramiko + from paramiko.buffered_pipe import PipeTimeout + from paramiko.ssh_exception import SSHException, PasswordRequiredException + + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + retries = 0 + while True: # Be robust to transient SSH failures. + try: + # Set paramiko logging to WARN or higher to squelch INFO messages. + logging.getLogger("paramiko").setLevel(logging.WARN) + + ssh.connect( + hostname=cmd_dict["address"], + username=cmd_dict["ssh_username"], + port=cmd_dict["ssh_port"], + key_filename=cmd_dict["ssh_private_key"], + compress=True, + timeout=20, + banner_timeout=20, + ) # Helps prevent timeouts when many concurrent ssh connections are opened. + # Connection successful, break out of while loop + break + + except (SSHException, PasswordRequiredException) as e: + + print( + "[ dask-ssh ] : " + + bcolors.FAIL + + "SSH connection error when connecting to {addr}:{port}" + "to run '{cmd}'".format( + addr=cmd_dict["address"], + port=cmd_dict["ssh_port"], + cmd=cmd_dict["cmd"], + ) + + bcolors.ENDC + ) + + print( + bcolors.FAIL + + " SSH reported this exception: " + + str(e) + + bcolors.ENDC + ) + + # Print an exception traceback + traceback.print_exc() + + # Transient SSH errors can occur when many SSH connections are + # simultaneously opened to the same server. This makes a few + # attempts to retry. + retries += 1 + if retries >= 3: + print( + "[ dask-ssh ] : " + + bcolors.FAIL + + "SSH connection failed after 3 retries. Exiting." + + bcolors.ENDC + ) + + # Connection failed after multiple attempts. Terminate this thread. + os._exit(1) + + # Wait a moment before retrying + print( + " " + + bcolors.FAIL + + "Retrying... (attempt {n}/{total})".format(n=retries, total=3) + + bcolors.ENDC + ) + + time.sleep(1) + + # Execute the command, and grab file handles for stdout and stderr. Note + # that we run the command using the user's default shell, but force it to + # run in an interactive login shell, which hopefully ensures that all of the + # user's normal environment variables (via the dot files) have been loaded + # before the command is run. This should help to ensure that important + # aspects of the environment like PATH and PYTHONPATH are configured. 
+ + print("[ {label} ] : {cmd}".format(label=cmd_dict["label"], cmd=cmd_dict["cmd"])) + stdin, stdout, stderr = ssh.exec_command( + "$SHELL -i -c '" + cmd_dict["cmd"] + "'", get_pty=True + ) + + # Set up channel timeout (which we rely on below to make readline() non-blocking) + channel = stdout.channel + channel.settimeout(0.1) + + def read_from_stdout(): + """ + Read stdout stream, time out if necessary. + """ + try: + line = stdout.readline() + while len(line) > 0: # Loops until a timeout exception occurs + line = line.rstrip() + logger.debug("stdout from ssh channel: %s", line) + cmd_dict["output_queue"].put( + "[ {label} ] : {output}".format( + label=cmd_dict["label"], output=line + ) + ) + line = stdout.readline() + except (PipeTimeout, socket.timeout): + pass + + def read_from_stderr(): + """ + Read stderr stream, time out if necessary. + """ + try: + line = stderr.readline() + while len(line) > 0: + line = line.rstrip() + logger.debug("stderr from ssh channel: %s", line) + cmd_dict["output_queue"].put( + "[ {label} ] : ".format(label=cmd_dict["label"]) + + bcolors.FAIL + + "{output}".format(output=line) + + bcolors.ENDC + ) + line = stderr.readline() + except (PipeTimeout, socket.timeout): + pass + + def communicate(): + """ + Communicate a little bit, without blocking too long. + Return True if the command ended. + """ + read_from_stdout() + read_from_stderr() + + # Check to see if the process has exited. If it has, we let this thread + # terminate. + if channel.exit_status_ready(): + exit_status = channel.recv_exit_status() + cmd_dict["output_queue"].put( + "[ {label} ] : ".format(label=cmd_dict["label"]) + + bcolors.FAIL + + "remote process exited with exit status " + + str(exit_status) + + bcolors.ENDC + ) + return True + + # Get transport to current SSH client + transport = ssh.get_transport() + + # Wait for a message on the input_queue. Any message received signals this + # thread to shut itself down. + while cmd_dict["input_queue"].empty(): + # Kill some time so that this thread does not hog the CPU. + time.sleep(1.0) + # Send noise down the pipe to keep connection active + transport.send_ignore() + if communicate(): + break + + # Ctrl-C the executing command and wait a bit for command to end cleanly + start = time.time() + while time.time() < start + 5.0: + channel.send(b"\x03") # Ctrl-C + if communicate(): + break + time.sleep(1.0) + + # Shutdown the channel, and close the SSH connection + channel.close() + ssh.close() + + +def start_scheduler( + logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None +): + cmd = "{python} -m distributed.cli.dask_scheduler --port {port}".format( + python=remote_python or sys.executable, port=port, logdir=logdir + ) + + # Optionally re-direct stdout and stderr to a logfile + if logdir is not None: + cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd + cmd += "&> {logdir}/dask_scheduler_{addr}:{port}.log".format( + addr=addr, port=port, logdir=logdir + ) + + # Format output labels we can prepend to each line of output, and create + # a 'status' key to keep track of jobs that terminate prematurely. + label = ( + bcolors.BOLD + + "scheduler {addr}:{port}".format(addr=addr, port=port) + + bcolors.ENDC + ) + + # Create a command dictionary, which contains everything we need to run and + # interact with this command. 
+ input_queue = Queue() + output_queue = Queue() + cmd_dict = { + "cmd": cmd, + "label": label, + "address": addr, + "port": port, + "input_queue": input_queue, + "output_queue": output_queue, + "ssh_username": ssh_username, + "ssh_port": ssh_port, + "ssh_private_key": ssh_private_key, + } + + # Start the thread + thread = Thread(target=async_ssh, args=[cmd_dict]) + thread.daemon = True + thread.start() + + return merge(cmd_dict, {"thread": thread}) + + +def start_worker( + logdir, + scheduler_addr, + scheduler_port, + worker_addr, + nthreads, + nprocs, + ssh_username, + ssh_port, + ssh_private_key, + nohost, + memory_limit, + worker_port, + nanny_port, + remote_python=None, + remote_dask_worker="distributed.cli.dask_worker", +): + + cmd = ( + "{python} -m {remote_dask_worker} " + "{scheduler_addr}:{scheduler_port} " + "--nthreads {nthreads}" + (" --nprocs {nprocs}" if nprocs != 1 else "") + ) + + if not nohost: + cmd += " --host {worker_addr}" + + if memory_limit: + cmd += " --memory-limit {memory_limit}" + + if worker_port: + cmd += " --worker-port {worker_port}" + + if nanny_port: + cmd += " --nanny-port {nanny_port}" + + cmd = cmd.format( + python=remote_python or sys.executable, + remote_dask_worker=remote_dask_worker, + scheduler_addr=scheduler_addr, + scheduler_port=scheduler_port, + worker_addr=worker_addr, + nthreads=nthreads, + nprocs=nprocs, + memory_limit=memory_limit, + worker_port=worker_port, + nanny_port=nanny_port, + ) + + # Optionally redirect stdout and stderr to a logfile + if logdir is not None: + cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd + cmd += "&> {logdir}/dask_scheduler_{addr}.log".format( + addr=worker_addr, logdir=logdir + ) + + label = "worker {addr}".format(addr=worker_addr) + + # Create a command dictionary, which contains everything we need to run and + # interact with this command. 
+ input_queue = Queue() + output_queue = Queue() + cmd_dict = { + "cmd": cmd, + "label": label, + "address": worker_addr, + "input_queue": input_queue, + "output_queue": output_queue, + "ssh_username": ssh_username, + "ssh_port": ssh_port, + "ssh_private_key": ssh_private_key, + } + + # Start the thread + thread = Thread(target=async_ssh, args=[cmd_dict]) + thread.daemon = True + thread.start() + + return merge(cmd_dict, {"thread": thread}) + + +class SSHCluster(object): + def __init__( + self, + scheduler_addr, + scheduler_port, + worker_addrs, + nthreads=0, + nprocs=1, + ssh_username=None, + ssh_port=22, + ssh_private_key=None, + nohost=False, + logdir=None, + remote_python=None, + memory_limit=None, + worker_port=None, + nanny_port=None, + remote_dask_worker="distributed.cli.dask_worker", + ): + + self.scheduler_addr = scheduler_addr + self.scheduler_port = scheduler_port + self.nthreads = nthreads + self.nprocs = nprocs + + self.ssh_username = ssh_username + self.ssh_port = ssh_port + self.ssh_private_key = ssh_private_key + + self.nohost = nohost + + self.remote_python = remote_python + + self.memory_limit = memory_limit + self.worker_port = worker_port + self.nanny_port = nanny_port + self.remote_dask_worker = remote_dask_worker + + # Generate a universal timestamp to use for log files + import datetime + + if logdir is not None: + logdir = os.path.join( + logdir, + "dask-ssh_" + datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), + ) + print( + bcolors.WARNING + "Output will be redirected to logfiles " + 'stored locally on individual worker nodes under "{logdir}".'.format( + logdir=logdir + ) + + bcolors.ENDC + ) + self.logdir = logdir + + # Keep track of all running threads + self.threads = [] + + # Start the scheduler node + self.scheduler = start_scheduler( + logdir, + scheduler_addr, + scheduler_port, + ssh_username, + ssh_port, + ssh_private_key, + remote_python, + ) + + # Start worker nodes + self.workers = [] + for i, addr in enumerate(worker_addrs): + self.add_worker(addr) + + @gen.coroutine + def _start(self): + pass + + @property + def scheduler_address(self): + return "%s:%d" % (self.scheduler_addr, self.scheduler_port) + + def monitor_remote_processes(self): + + # Form a list containing all processes, since we treat them equally from here on out. + all_processes = [self.scheduler] + self.workers + + try: + while True: + for process in all_processes: + while not process["output_queue"].empty(): + print(process["output_queue"].get()) + + # Kill some time and free up CPU before starting the next sweep + # through the processes. 
+ time.sleep(0.1) + + # end while true + + except KeyboardInterrupt: + pass # Return execution to the calling process + + def add_worker(self, address): + self.workers.append( + start_worker( + self.logdir, + self.scheduler_addr, + self.scheduler_port, + address, + self.nthreads, + self.nprocs, + self.ssh_username, + self.ssh_port, + self.ssh_private_key, + self.nohost, + self.memory_limit, + self.worker_port, + self.nanny_port, + self.remote_python, + self.remote_dask_worker, + ) + ) + + def shutdown(self): + all_processes = [self.scheduler] + self.workers + + for process in all_processes: + process["input_queue"].put("shutdown") + process["thread"].join() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.shutdown() diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 30f6f819224..8aa3cc17d97 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -1,472 +1,286 @@ import logging -import socket -import os import sys -import time -import traceback - -try: - from queue import Queue -except ImportError: # Python 2.7 fix - from Queue import Queue - -from threading import Thread - -from toolz import merge - -from tornado import gen +from typing import List +import warnings +import weakref +from .spec import SpecCluster, ProcessInterface +from ..utils import cli_keywords +from ..scheduler import Scheduler as _Scheduler +from ..worker import Worker as _Worker logger = logging.getLogger(__name__) -# These are handy for creating colorful terminal output to enhance readability -# of the output generated by dask-ssh. -class bcolors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - -def async_ssh(cmd_dict): - import paramiko - from paramiko.buffered_pipe import PipeTimeout - from paramiko.ssh_exception import SSHException, PasswordRequiredException - - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - retries = 0 - while True: # Be robust to transient SSH failures. - try: - # Set paramiko logging to WARN or higher to squelch INFO messages. - logging.getLogger("paramiko").setLevel(logging.WARN) - - ssh.connect( - hostname=cmd_dict["address"], - username=cmd_dict["ssh_username"], - port=cmd_dict["ssh_port"], - key_filename=cmd_dict["ssh_private_key"], - compress=True, - timeout=20, - banner_timeout=20, - ) # Helps prevent timeouts when many concurrent ssh connections are opened. - # Connection successful, break out of while loop - break - - except (SSHException, PasswordRequiredException) as e: - - print( - "[ dask-ssh ] : " - + bcolors.FAIL - + "SSH connection error when connecting to {addr}:{port}" - "to run '{cmd}'".format( - addr=cmd_dict["address"], - port=cmd_dict["ssh_port"], - cmd=cmd_dict["cmd"], - ) - + bcolors.ENDC - ) - - print( - bcolors.FAIL - + " SSH reported this exception: " - + str(e) - + bcolors.ENDC - ) - - # Print an exception traceback - traceback.print_exc() - - # Transient SSH errors can occur when many SSH connections are - # simultaneously opened to the same server. This makes a few - # attempts to retry. - retries += 1 - if retries >= 3: - print( - "[ dask-ssh ] : " - + bcolors.FAIL - + "SSH connection failed after 3 retries. Exiting." - + bcolors.ENDC - ) - - # Connection failed after multiple attempts. Terminate this thread. - os._exit(1) - - # Wait a moment before retrying - print( - " " - + bcolors.FAIL - + "Retrying... 
(attempt {n}/{total})".format(n=retries, total=3) - + bcolors.ENDC - ) - - time.sleep(1) - - # Execute the command, and grab file handles for stdout and stderr. Note - # that we run the command using the user's default shell, but force it to - # run in an interactive login shell, which hopefully ensures that all of the - # user's normal environment variables (via the dot files) have been loaded - # before the command is run. This should help to ensure that important - # aspects of the environment like PATH and PYTHONPATH are configured. - - print("[ {label} ] : {cmd}".format(label=cmd_dict["label"], cmd=cmd_dict["cmd"])) - stdin, stdout, stderr = ssh.exec_command( - "$SHELL -i -c '" + cmd_dict["cmd"] + "'", get_pty=True - ) - - # Set up channel timeout (which we rely on below to make readline() non-blocking) - channel = stdout.channel - channel.settimeout(0.1) - - def read_from_stdout(): - """ - Read stdout stream, time out if necessary. - """ - try: - line = stdout.readline() - while len(line) > 0: # Loops until a timeout exception occurs - line = line.rstrip() - logger.debug("stdout from ssh channel: %s", line) - cmd_dict["output_queue"].put( - "[ {label} ] : {output}".format( - label=cmd_dict["label"], output=line - ) - ) - line = stdout.readline() - except (PipeTimeout, socket.timeout): - pass - - def read_from_stderr(): - """ - Read stderr stream, time out if necessary. - """ - try: - line = stderr.readline() - while len(line) > 0: - line = line.rstrip() - logger.debug("stderr from ssh channel: %s", line) - cmd_dict["output_queue"].put( - "[ {label} ] : ".format(label=cmd_dict["label"]) - + bcolors.FAIL - + "{output}".format(output=line) - + bcolors.ENDC - ) - line = stderr.readline() - except (PipeTimeout, socket.timeout): - pass - - def communicate(): - """ - Communicate a little bit, without blocking too long. - Return True if the command ended. - """ - read_from_stdout() - read_from_stderr() - - # Check to see if the process has exited. If it has, we let this thread - # terminate. - if channel.exit_status_ready(): - exit_status = channel.recv_exit_status() - cmd_dict["output_queue"].put( - "[ {label} ] : ".format(label=cmd_dict["label"]) - + bcolors.FAIL - + "remote process exited with exit status " - + str(exit_status) - + bcolors.ENDC - ) - return True - - # Get transport to current SSH client - transport = ssh.get_transport() - - # Wait for a message on the input_queue. Any message received signals this - # thread to shut itself down. - while cmd_dict["input_queue"].empty(): - # Kill some time so that this thread does not hog the CPU. 
- time.sleep(1.0) - # Send noise down the pipe to keep connection active - transport.send_ignore() - if communicate(): - break - - # Ctrl-C the executing command and wait a bit for command to end cleanly - start = time.time() - while time.time() < start + 5.0: - channel.send(b"\x03") # Ctrl-C - if communicate(): - break - time.sleep(1.0) - - # Shutdown the channel, and close the SSH connection - channel.close() - ssh.close() - - -def start_scheduler( - logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None -): - cmd = "{python} -m distributed.cli.dask_scheduler --port {port}".format( - python=remote_python or sys.executable, port=port, logdir=logdir - ) - - # Optionally re-direct stdout and stderr to a logfile - if logdir is not None: - cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd - cmd += "&> {logdir}/dask_scheduler_{addr}:{port}.log".format( - addr=addr, port=port, logdir=logdir - ) - - # Format output labels we can prepend to each line of output, and create - # a 'status' key to keep track of jobs that terminate prematurely. - label = ( - bcolors.BOLD - + "scheduler {addr}:{port}".format(addr=addr, port=port) - + bcolors.ENDC - ) - - # Create a command dictionary, which contains everything we need to run and - # interact with this command. - input_queue = Queue() - output_queue = Queue() - cmd_dict = { - "cmd": cmd, - "label": label, - "address": addr, - "port": port, - "input_queue": input_queue, - "output_queue": output_queue, - "ssh_username": ssh_username, - "ssh_port": ssh_port, - "ssh_private_key": ssh_private_key, - } - - # Start the thread - thread = Thread(target=async_ssh, args=[cmd_dict]) - thread.daemon = True - thread.start() - - return merge(cmd_dict, {"thread": thread}) - - -def start_worker( - logdir, - scheduler_addr, - scheduler_port, - worker_addr, - nthreads, - nprocs, - ssh_username, - ssh_port, - ssh_private_key, - nohost, - memory_limit, - worker_port, - nanny_port, - remote_python=None, - remote_dask_worker="distributed.cli.dask_worker", -): +class Process(ProcessInterface): + """ A superclass for SSH Workers and Nannies + + See Also + -------- + Worker + Scheduler + """ + + def __init__(self, **kwargs): + self.connection = None + self.proc = None + super().__init__(**kwargs) + + async def start(self): + assert self.connection + weakref.finalize( + self, self.proc.kill + ) # https://github.com/ronf/asyncssh/issues/112 + await super().start() + + async def close(self): + self.proc.kill() # https://github.com/ronf/asyncssh/issues/112 + self.connection.close() + await super().close() + + def __repr__(self): + return "" % (type(self).__name__, self.status) + + +class Worker(Process): + """ A Remote Dask Worker controled by SSH + + Parameters + ---------- + scheduler: str + The address of the scheduler + address: str + The hostname where we should run this worker + worker_module: str + The python module to run to start the worker. 
+ connect_options: dict + kwargs to be passed to asyncssh connections + kwargs: dict + These will be passed through the dask-worker CLI to the + dask.distributed.Worker class + """ - cmd = ( - "{python} -m {remote_dask_worker} " - "{scheduler_addr}:{scheduler_port} " - "--nthreads {nthreads}" + (" --nprocs {nprocs}" if nprocs != 1 else "") - ) - - if not nohost: - cmd += " --host {worker_addr}" - - if memory_limit: - cmd += " --memory-limit {memory_limit}" - - if worker_port: - cmd += " --worker-port {worker_port}" - - if nanny_port: - cmd += " --nanny-port {nanny_port}" - - cmd = cmd.format( - python=remote_python or sys.executable, - remote_dask_worker=remote_dask_worker, - scheduler_addr=scheduler_addr, - scheduler_port=scheduler_port, - worker_addr=worker_addr, - nthreads=nthreads, - nprocs=nprocs, - memory_limit=memory_limit, - worker_port=worker_port, - nanny_port=nanny_port, - ) - - # Optionally redirect stdout and stderr to a logfile - if logdir is not None: - cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd - cmd += "&> {logdir}/dask_scheduler_{addr}.log".format( - addr=worker_addr, logdir=logdir - ) - - label = "worker {addr}".format(addr=worker_addr) - - # Create a command dictionary, which contains everything we need to run and - # interact with this command. - input_queue = Queue() - output_queue = Queue() - cmd_dict = { - "cmd": cmd, - "label": label, - "address": worker_addr, - "input_queue": input_queue, - "output_queue": output_queue, - "ssh_username": ssh_username, - "ssh_port": ssh_port, - "ssh_private_key": ssh_private_key, - } - - # Start the thread - thread = Thread(target=async_ssh, args=[cmd_dict]) - thread.daemon = True - thread.start() - - return merge(cmd_dict, {"thread": thread}) - - -class SSHCluster(object): def __init__( self, - scheduler_addr, - scheduler_port, - worker_addrs, - nthreads=0, - nprocs=1, - ssh_username=None, - ssh_port=22, - ssh_private_key=None, - nohost=False, - logdir=None, - remote_python=None, - memory_limit=None, - worker_port=None, - nanny_port=None, - remote_dask_worker="distributed.cli.dask_worker", + scheduler: str, + address: str, + connect_options: dict, + kwargs: dict, + worker_module="distributed.cli.dask_worker", + loop=None, + name=None, ): - - self.scheduler_addr = scheduler_addr - self.scheduler_port = scheduler_port - self.nthreads = nthreads - self.nprocs = nprocs - - self.ssh_username = ssh_username - self.ssh_port = ssh_port - self.ssh_private_key = ssh_private_key - - self.nohost = nohost - - self.remote_python = remote_python - - self.memory_limit = memory_limit - self.worker_port = worker_port - self.nanny_port = nanny_port - self.remote_dask_worker = remote_dask_worker - - # Generate a universal timestamp to use for log files - import datetime - - if logdir is not None: - logdir = os.path.join( - logdir, - "dask-ssh_" + datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), - ) - print( - bcolors.WARNING + "Output will be redirected to logfiles " - 'stored locally on individual worker nodes under "{logdir}".'.format( - logdir=logdir - ) - + bcolors.ENDC + self.address = address + self.scheduler = scheduler + self.worker_module = worker_module + self.connect_options = connect_options + self.kwargs = kwargs + self.name = name + + super().__init__() + + async def start(self): + import asyncssh # import now to avoid adding to module startup time + + self.connection = await asyncssh.connect(self.address, **self.connect_options) + self.proc = await self.connection.create_process( + " ".join( + [ + sys.executable, + 
"-m", + self.worker_module, + self.scheduler, + "--name", + str(self.name), + ] + + cli_keywords(self.kwargs, cls=_Worker) ) - self.logdir = logdir - - # Keep track of all running threads - self.threads = [] - - # Start the scheduler node - self.scheduler = start_scheduler( - logdir, - scheduler_addr, - scheduler_port, - ssh_username, - ssh_port, - ssh_private_key, - remote_python, ) - # Start worker nodes - self.workers = [] - for i, addr in enumerate(worker_addrs): - self.add_worker(addr) - - @gen.coroutine - def _start(self): - pass - - @property - def scheduler_address(self): - return "%s:%d" % (self.scheduler_addr, self.scheduler_port) - - def monitor_remote_processes(self): - - # Form a list containing all processes, since we treat them equally from here on out. - all_processes = [self.scheduler] + self.workers - - try: - while True: - for process in all_processes: - while not process["output_queue"].empty(): - print(process["output_queue"].get()) - - # Kill some time and free up CPU before starting the next sweep - # through the processes. - time.sleep(0.1) - - # end while true - - except KeyboardInterrupt: - pass # Return execution to the calling process - - def add_worker(self, address): - self.workers.append( - start_worker( - self.logdir, - self.scheduler_addr, - self.scheduler_port, - address, - self.nthreads, - self.nprocs, - self.ssh_username, - self.ssh_port, - self.ssh_private_key, - self.nohost, - self.memory_limit, - self.worker_port, - self.nanny_port, - self.remote_python, - self.remote_dask_worker, + # We watch stderr in order to get the address, then we return + while True: + line = await self.proc.stderr.readline() + if not line.strip(): + raise Exception("Worker failed to start") + logger.info(line.strip()) + if "worker at" in line: + self.address = line.split("worker at:")[1].strip() + self.status = "running" + break + logger.debug("%s", line) + await super().start() + + +class Scheduler(Process): + """ A Remote Dask Scheduler controled by SSH + + Parameters + ---------- + address: str + The hostname where we should run this worker + connect_options: dict + kwargs to be passed to asyncssh connections + kwargs: dict + These will be passed through the dask-scheduler CLI to the + dask.distributed.Scheduler class + """ + + def __init__(self, address: str, connect_options: dict, kwargs: dict): + self.address = address + self.kwargs = kwargs + self.connect_options = connect_options + + super().__init__() + + async def start(self): + import asyncssh # import now to avoid adding to module startup time + + logger.debug("Created Scheduler Connection") + + self.connection = await asyncssh.connect(self.address, **self.connect_options) + + self.proc = await self.connection.create_process( + " ".join( + [sys.executable, "-m", "distributed.cli.dask_scheduler"] + + cli_keywords(self.kwargs, cls=_Scheduler) ) ) - def shutdown(self): - all_processes = [self.scheduler] + self.workers - - for process in all_processes: - process["input_queue"].put("shutdown") - process["thread"].join() - - def __enter__(self): - return self - - def __exit__(self, *args): - self.shutdown() + # We watch stderr in order to get the address, then we return + while True: + line = await self.proc.stderr.readline() + if not line.strip(): + raise Exception("Worker failed to start") + logger.info(line.strip()) + if "Scheduler at" in line: + self.address = line.split("Scheduler at:")[1].strip() + break + logger.debug("%s", line) + await super().start() + + +old_cluster_kwargs = { + "scheduler_addr", + 
"scheduler_port", + "worker_addrs", + "nthreads", + "nprocs", + "ssh_username", + "ssh_port", + "ssh_private_key", + "nohost", + "logdir", + "remote_python", + "memory_limit", + "worker_port", + "nanny_port", + "remote_dask_worker", +} + + +def SSHCluster( + hosts: List[str] = None, + connect_options: dict = {}, + worker_options: dict = {}, + scheduler_options: dict = {}, + worker_module: str = "distributed.cli.dask_worker", + **kwargs +): + """ Deploy a Dask cluster using SSH + + The SSHCluster function deploys a Dask Scheduler and Workers for you on a + set of machine addresses that you provide. The first address will be used + for the scheduler while the rest will be used for the workers (feel free to + repeat the first hostname if you want to have the scheudler and worker + co-habitate one machine.) + + You may configure the scheduler and workers by passing + ``scheduler_options`` and ``worker_options`` dictionary keywords. See the + ``dask.distributed.Scheduler`` and ``dask.distributed.Worker`` classes for + details on the available options, but the defaults should work in most + situations. + + You may configure your use of SSH itself using the ``connect_options`` + keyword, which passes values to the ``asyncssh.connect`` function. For + more information on these see the documentation for the ``asyncssh`` + library https://asyncssh.readthedocs.io . + + Parameters + ---------- + hosts: List[str] + List of hostnames or addresses on which to launch our cluster + The first will be used for the scheduler and the rest for workers + connect_options: + Keywords to pass through to asyncssh.connect + known_hosts: List[str] or None + The list of keys which will be used to validate the server host + key presented during the SSH handshake. If this is not specified, + the keys will be looked up in the file .ssh/known_hosts. If this + is explicitly set to None, server host key validation will be disabled. + worker_options: + Keywords to pass on to dask-worker + scheduler_options: + Keywords to pass on to dask-scheduler + worker_module: + Python module to call to start the worker + + Examples + -------- + >>> from dask.distributed import Client, SSHCluster + >>> cluster = SSHCluster( + ... ["localhost", "localhost", "localhost", "localhost"], + ... connect_options={"known_hosts": None}, + ... worker_options={"nthreads": 2}, + ... scheduler_options={"port": 0, "dashboard_address": ":8797"} + ... ) + >>> client = Client(cluster) + + An example using a different worker module, in particular the + ``dask-cuda-worker`` command from the ``dask-cuda`` project. + + >>> from dask.distributed import Client, SSHCluster + >>> cluster = SSHCluster( + ... ["localhost", "hostwithgpus", "anothergpuhost"], + ... connect_options={"known_hosts": None}, + ... scheduler_options={"port": 0, "dashboard_address": ":8797"}, + ... worker_module='dask_cuda.dask_cuda_worker') + >>> client = Client(cluster) + + See Also + -------- + dask.distributed.Scheduler + dask.distributed.Worker + asyncssh.connect + """ + if set(kwargs) & old_cluster_kwargs: + from .old_ssh import SSHCluster as OldSSHCluster + + warnings.warn( + "Note that the SSHCluster API has been replaced. " + "We're routing you to the older implementation. 
" + "This will be removed in the future" + ) + kwargs.setdefault("worker_addrs", hosts) + return OldSSHCluster(**kwargs) + + scheduler = { + "cls": Scheduler, + "options": { + "address": hosts[0], + "connect_options": connect_options, + "kwargs": scheduler_options, + }, + } + workers = { + i: { + "cls": Worker, + "options": { + "address": host, + "connect_options": connect_options, + "kwargs": worker_options, + "worker_module": worker_module, + }, + } + for i, host in enumerate(hosts[1:]) + } + return SpecCluster(workers, scheduler, name="SSHCluster", **kwargs) diff --git a/distributed/deploy/ssh2.py b/distributed/deploy/ssh2.py deleted file mode 100644 index cb6b967d544..00000000000 --- a/distributed/deploy/ssh2.py +++ /dev/null @@ -1,236 +0,0 @@ -import logging -import sys -import warnings -import weakref - -import asyncssh - -from .spec import SpecCluster, ProcessInterface -from ..utils import cli_keywords -from ..scheduler import Scheduler as _Scheduler -from ..worker import Worker as _Worker - -logger = logging.getLogger(__name__) - -warnings.warn( - "the distributed.deploy.ssh2 module is experimental " - "and will move/change in the future without notice" -) - - -class Process(ProcessInterface): - """ A superclass for SSH Workers and Nannies - - See Also - -------- - Worker - Scheduler - """ - - def __init__(self, **kwargs): - self.connection = None - self.proc = None - super().__init__(**kwargs) - - async def start(self): - assert self.connection - weakref.finalize( - self, self.proc.kill - ) # https://github.com/ronf/asyncssh/issues/112 - await super().start() - - async def close(self): - self.proc.kill() # https://github.com/ronf/asyncssh/issues/112 - self.connection.close() - await super().close() - - def __repr__(self): - return "" % (type(self).__name__, self.status) - - -class Worker(Process): - """ A Remote Dask Worker controled by SSH - - Parameters - ---------- - scheduler: str - The address of the scheduler - address: str - The hostname where we should run this worker - worker_module: str - The python module to run to start the worker. 
- connect_kwargs: dict - kwargs to be passed to asyncssh connections - kwargs: dict - These will be passed through the dask-worker CLI to the - dask.distributed.Worker class - """ - - def __init__( - self, - scheduler: str, - address: str, - connect_kwargs: dict, - kwargs: dict, - worker_module="distributed.cli.dask_worker", - loop=None, - name=None, - ): - self.address = address - self.scheduler = scheduler - self.worker_module = worker_module - self.connect_kwargs = connect_kwargs - self.kwargs = kwargs - self.name = name - - super().__init__() - - async def start(self): - self.connection = await asyncssh.connect(self.address, **self.connect_kwargs) - self.proc = await self.connection.create_process( - " ".join( - [ - sys.executable, - "-m", - self.worker_module, - self.scheduler, - "--name", - str(self.name), - ] - + cli_keywords(self.kwargs, cls=_Worker) - ) - ) - - # We watch stderr in order to get the address, then we return - while True: - line = await self.proc.stderr.readline() - if not line.strip(): - raise Exception("Worker failed to start") - logger.info(line.strip()) - if "worker at" in line: - self.address = line.split("worker at:")[1].strip() - self.status = "running" - break - logger.debug("%s", line) - await super().start() - - -class Scheduler(Process): - """ A Remote Dask Scheduler controled by SSH - - Parameters - ---------- - address: str - The hostname where we should run this worker - connect_kwargs: dict - kwargs to be passed to asyncssh connections - kwargs: dict - These will be passed through the dask-scheduler CLI to the - dask.distributed.Scheduler class - """ - - def __init__(self, address: str, connect_kwargs: dict, kwargs: dict): - self.address = address - self.kwargs = kwargs - self.connect_kwargs = connect_kwargs - - super().__init__() - - async def start(self): - logger.debug("Created Scheduler Connection") - - self.connection = await asyncssh.connect(self.address, **self.connect_kwargs) - - self.proc = await self.connection.create_process( - " ".join( - [sys.executable, "-m", "distributed.cli.dask_scheduler"] - + cli_keywords(self.kwargs, cls=_Scheduler) - ) - ) - - # We watch stderr in order to get the address, then we return - while True: - line = await self.proc.stderr.readline() - if not line.strip(): - raise Exception("Worker failed to start") - logger.info(line.strip()) - if "Scheduler at" in line: - self.address = line.split("Scheduler at:")[1].strip() - break - logger.debug("%s", line) - await super().start() - - -def SSHCluster( - hosts, - connect_kwargs={}, - worker_kwargs={}, - scheduler_kwargs={}, - worker_module="distributed.cli.dask_worker", - **kwargs -): - """ Deploy a Dask cluster using SSH - - Parameters - ---------- - hosts: List[str] - List of hostnames or addresses on which to launch our cluster - The first will be used for the scheduler and the rest for workers - connect_kwargs: - Keywords to pass through to asyncssh.connect - known_hosts: List[str] or None - The list of keys which will be used to validate the server host - key presented during the SSH handshake. If this is not specified, - the keys will be looked up in the file .ssh/known_hosts. If this - is explicitly set to None, server host key validation will be disabled. 
- worker_kwargs: - Keywords to pass on to dask-worker - scheduler_kwargs: - Keywords to pass on to dask-scheduler - worker_module: - Python module to call to start the worker - - Examples - -------- - >>> from dask.distributed import Client - >>> from distributed.deploy.ssh2 import SSHCluster # experimental for now - >>> cluster = SSHCluster( - ... ["localhost"] * 4, - ... connect_kwargs={"known_hosts": None}, - ... worker_kwargs={"nthreads": 2}, - ... scheduler_kwargs={"port": 0, "dashboard_address": ":8797"}) - >>> client = Client(cluster) - - Running GPU workers (requires ``dask_cuda`` to be installed on all hosts) - - >>> from dask.distributed import Client - >>> from distributed.deploy.ssh2 import SSHCluster # experimental for now - >>> cluster = SSHCluster( - ... ["localhost", "hostwithgpus", "anothergpuhost"], - ... connect_kwargs={"known_hosts": None}, - ... scheduler_kwargs={"port": 0, "dashboard_address": ":8797"}, - ... worker_module='dask_cuda.dask_cuda_worker') - >>> client = Client(cluster) - - """ - scheduler = { - "cls": Scheduler, - "options": { - "address": hosts[0], - "connect_kwargs": connect_kwargs, - "kwargs": scheduler_kwargs, - }, - } - workers = { - i: { - "cls": Worker, - "options": { - "address": host, - "connect_kwargs": connect_kwargs, - "kwargs": worker_kwargs, - "worker_module": worker_module, - }, - } - for i, host in enumerate(hosts[1:]) - } - return SpecCluster(workers, scheduler, name="SSHCluster", **kwargs) diff --git a/distributed/deploy/tests/test_old_ssh.py b/distributed/deploy/tests/test_old_ssh.py new file mode 100644 index 00000000000..e6960b3392d --- /dev/null +++ b/distributed/deploy/tests/test_old_ssh.py @@ -0,0 +1,31 @@ +from time import sleep + +import pytest + +pytest.importorskip("paramiko") + +from distributed import Client +from distributed.deploy.old_ssh import SSHCluster +from distributed.metrics import time +from distributed.utils_test import loop # noqa: F401 + + +@pytest.mark.avoid_travis +def test_cluster(loop): + with SSHCluster( + scheduler_addr="127.0.0.1", + scheduler_port=7437, + worker_addrs=["127.0.0.1", "127.0.0.1"], + ) as c: + with Client(c, loop=loop) as e: + start = time() + while len(e.ncores()) != 2: + sleep(0.01) + assert time() < start + 5 + + c.add_worker("127.0.0.1") + + start = time() + while len(e.ncores()) != 3: + sleep(0.01) + assert time() < start + 5 diff --git a/distributed/deploy/tests/test_ssh.py b/distributed/deploy/tests/test_ssh.py index 492ee2c792d..3124af4f177 100644 --- a/distributed/deploy/tests/test_ssh.py +++ b/distributed/deploy/tests/test_ssh.py @@ -1,31 +1,57 @@ -from time import sleep - import pytest -pytest.importorskip("paramiko") +pytest.importorskip("asyncssh") -from distributed import Client +from dask.distributed import Client from distributed.deploy.ssh import SSHCluster -from distributed.metrics import time from distributed.utils_test import loop # noqa: F401 +@pytest.mark.asyncio +async def test_basic(): + async with SSHCluster( + ["127.0.0.1"] * 3, + connect_options=dict(known_hosts=None), + asynchronous=True, + scheduler_options={"port": 0, "idle_timeout": "5s"}, + worker_options={"death_timeout": "5s"}, + ) as cluster: + assert len(cluster.workers) == 2 + async with Client(cluster, asynchronous=True) as client: + result = await client.submit(lambda x: x + 1, 10) + assert result == 11 + assert not cluster._supports_scaling + + assert "SSH" in repr(cluster) + + +@pytest.mark.asyncio +async def test_keywords(): + async with SSHCluster( + ["127.0.0.1"] * 3, + 
connect_options=dict(known_hosts=None), + asynchronous=True, + worker_options={"nthreads": 2, "memory_limit": "2 GiB", "death_timeout": "5s"}, + scheduler_options={"idle_timeout": "5s", "port": 0}, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + assert ( + await client.run_on_scheduler( + lambda dask_scheduler: dask_scheduler.idle_timeout + ) + ) == 5 + d = client.scheduler_info()["workers"] + assert all(v["nthreads"] == 2 for v in d.values()) + + @pytest.mark.avoid_travis -def test_cluster(loop): - with SSHCluster( - scheduler_addr="127.0.0.1", - scheduler_port=7437, - worker_addrs=["127.0.0.1", "127.0.0.1"], - ) as c: - with Client(c, loop=loop) as e: - start = time() - while len(e.ncores()) != 2: - sleep(0.01) - assert time() < start + 5 - - c.add_worker("127.0.0.1") - - start = time() - while len(e.ncores()) != 3: - sleep(0.01) - assert time() < start + 5 +def test_defer_to_old(loop): + with pytest.warns(Warning): + with SSHCluster( + scheduler_addr="127.0.0.1", + scheduler_port=7437, + worker_addrs=["127.0.0.1", "127.0.0.1"], + ) as c: + from distributed.deploy.old_ssh import SSHCluster as OldSSHCluster + + assert isinstance(c, OldSSHCluster) diff --git a/distributed/deploy/tests/test_ssh2.py b/distributed/deploy/tests/test_ssh2.py deleted file mode 100644 index 076711bb841..00000000000 --- a/distributed/deploy/tests/test_ssh2.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest - -pytest.importorskip("asyncssh") - -from dask.distributed import Client -from distributed.deploy.ssh2 import SSHCluster - - -@pytest.mark.asyncio -async def test_basic(): - async with SSHCluster( - ["127.0.0.1"] * 3, - connect_kwargs=dict(known_hosts=None), - asynchronous=True, - scheduler_kwargs={"port": 0, "idle_timeout": "5s"}, - worker_kwargs={"death_timeout": "5s"}, - ) as cluster: - assert len(cluster.workers) == 2 - async with Client(cluster, asynchronous=True) as client: - result = await client.submit(lambda x: x + 1, 10) - assert result == 11 - assert not cluster._supports_scaling - - assert "SSH" in repr(cluster) - - -@pytest.mark.asyncio -async def test_keywords(): - async with SSHCluster( - ["127.0.0.1"] * 3, - connect_kwargs=dict(known_hosts=None), - asynchronous=True, - worker_kwargs={"nthreads": 2, "memory_limit": "2 GiB", "death_timeout": "5s"}, - scheduler_kwargs={"idle_timeout": "5s", "port": 0}, - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - assert ( - await client.run_on_scheduler( - lambda dask_scheduler: dask_scheduler.idle_timeout - ) - ) == 5 - d = client.scheduler_info()["workers"] - assert all(v["nthreads"] == 2 for v in d.values()) From 856bf29c0c45cc4411e47575fb8f637dc0cf9b77 Mon Sep 17 00:00:00 2001 From: Jonathan De Troye Date: Wed, 9 Oct 2019 14:57:02 -0400 Subject: [PATCH 0498/1550] Raise exception if the user passes in unused keywords to Client (#3117) Fixes #3014 --- distributed/client.py | 5 +++++ distributed/deploy/tests/test_local.py | 21 +++++++++++++++++++++ distributed/tests/test_client.py | 6 ++---- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index f918fffbf78..a14929ba326 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -643,6 +643,11 @@ def __init__( if address: logger.info("Config value `scheduler-address` found: %s", address) + if address is not None and kwargs: + raise ValueError( + "Unexpected keyword arguments: {}".format(str(sorted(kwargs))) + ) + if isinstance(address, (rpc, PooledRPCCall)): self.scheduler = address elif 
hasattr(address, "scheduler_address"): diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 520c99eb268..7a340a9c6f8 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -219,6 +219,27 @@ def test_Client_kwargs(loop): assert c.cluster.status == "closed" +def test_Client_unused_kwargs_with_cluster(loop): + with LocalCluster() as cluster: + with pytest.raises(Exception) as argexcept: + c = Client(cluster, n_workers=2, dashboard_port=8000, silence_logs=None) + assert ( + str(argexcept.value) + == "Unexpected keyword arguments: ['dashboard_port', 'n_workers', 'silence_logs']" + ) + + +def test_Client_unused_kwargs_with_address(loop): + with pytest.raises(Exception) as argexcept: + c = Client( + "127.0.0.1:8786", n_workers=2, dashboard_port=8000, silence_logs=None + ) + assert ( + str(argexcept.value) + == "Unexpected keyword arguments: ['dashboard_port', 'n_workers', 'silence_logs']" + ) + + def test_Client_twice(loop): with Client(loop=loop, silence_logs=False, dashboard_address=None) as c: with Client(loop=loop, silence_logs=False, dashboard_address=None) as f: diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index d01088502f0..c4ebe92b4f1 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5010,9 +5010,7 @@ def test_profile_keys(c, s, a, b): @gen_cluster() def test_client_with_name(s, a, b): with captured_logger("distributed.scheduler") as sio: - client = yield Client( - s.address, asynchronous=True, name="foo", silence_logs=False - ) + client = yield Client(s.address, asynchronous=True, name="foo") assert "foo" in client.id yield client.close() @@ -5356,7 +5354,7 @@ def test_de_serialization_none(s, a, b): @gen_cluster() def test_client_repr_closed(s, a, b): - c = yield Client(s.address, asynchronous=True, dashboard_address=None) + c = yield Client(s.address, asynchronous=True) yield c.close() c._repr_html_() From 935ec35eb7e5c84551ae50bd0bf267fb45a68c04 Mon Sep 17 00:00:00 2001 From: matthieubulte Date: Wed, 9 Oct 2019 22:10:59 +0200 Subject: [PATCH 0499/1550] Extend Worker plugin API with transition method (#2994) --- distributed/__init__.py | 1 + distributed/client.py | 27 ++++-- distributed/diagnostics/plugin.py | 67 ++++++++++++- ...est_plugin.py => test_scheduler_plugin.py} | 3 +- .../diagnostics/tests/test_worker_plugin.py | 93 +++++++++++++++++++ distributed/tests/test_worker_plugins.py | 67 ------------- distributed/worker.py | 30 ++++-- docs/source/plugins.rst | 6 ++ 8 files changed, 205 insertions(+), 89 deletions(-) rename distributed/diagnostics/tests/{test_plugin.py => test_scheduler_plugin.py} (94%) create mode 100644 distributed/diagnostics/tests/test_worker_plugin.py delete mode 100644 distributed/tests/test_worker_plugins.py diff --git a/distributed/__init__.py b/distributed/__init__.py index 07015ff44af..1eadee32307 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -4,6 +4,7 @@ from .core import connect, rpc from .deploy import LocalCluster, Adaptive, SpecCluster, SSHCluster from .diagnostics.progressbar import progress +from .diagnostics.plugin import WorkerPlugin, SchedulerPlugin from .client import ( Client, Executor, diff --git a/distributed/client.py b/distributed/client.py index a14929ba326..08e808acd15 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -66,6 +66,7 @@ from .sizeof import sizeof from .threadpoolexecutor import rejoin from .worker import dumps_task, 
get_client, get_worker, secede +from .diagnostics.plugin import WorkerPlugin from .utils import ( All, sync, @@ -3908,13 +3909,15 @@ def register_worker_plugin(self, plugin=None, name=None): """ Registers a lifecycle worker plugin for all current and future workers. - This registers a new object to handle setup and teardown for workers in - this cluster. The plugin will instantiate itself on all currently - connected workers. It will also be run on any worker that connects in - the future. + This registers a new object to handle setup, task state transitions and + teardown for workers in this cluster. The plugin will instantiate itself + on all currently connected workers. It will also be run on any worker + that connects in the future. - The plugin should be an object with ``setup`` and ``teardown`` methods. - It must be serializable with the pickle or cloudpickle modules. + The plugin may include methods ``setup``, ``teardown``, and + ``transition``. See the ``dask.distributed.WorkerPlugin`` class or the + examples below for the interface and docstrings. It must be + serializable with the pickle or cloudpickle modules. If the plugin has a ``name`` attribute, or if the ``name=`` keyword is used then that will control idempotency. A a plugin with that name has @@ -3925,7 +3928,7 @@ def register_worker_plugin(self, plugin=None, name=None): Parameters ---------- - plugin: object + plugin: WorkerPlugin The plugin object to pass to the workers name: str, optional A name for the plugin. @@ -3933,13 +3936,15 @@ def register_worker_plugin(self, plugin=None, name=None): Examples -------- - >>> class MyPlugin: + >>> class MyPlugin(WorkerPlugin): ... def __init__(self, *args, **kwargs): ... pass # the constructor is up to you ... def setup(self, worker: dask.distributed.Worker): ... pass ... def teardown(self, worker: dask.distributed.Worker): ... pass + ... def transition(self, key: str, start: str, finish: str, **kwargs): + ... pass >>> plugin = MyPlugin(1, 2, 3) >>> client.register_worker_plugin(plugin) @@ -3953,11 +3958,15 @@ def register_worker_plugin(self, plugin=None, name=None): ... return plugin.my_state >>> future = client.run(f) + + See Also + -------- + distributed.WorkerPlugin """ return self.sync(self._register_worker_plugin, plugin=plugin, name=name) -class _WorkerSetupPlugin(object): +class _WorkerSetupPlugin(WorkerPlugin): """ This is used to support older setup functions as callbacks """ def __init__(self, setup): diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index cfe5fa42b49..8d56679e9a9 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -33,8 +33,8 @@ class SchedulerPlugin(object): ... def restart(self, scheduler): ... self.counter = 0 - >>> c = Counter() - >>> scheduler.add_plugin(c) # doctest: +SKIP + >>> plugin = Counter() + >>> scheduler.add_plugin(plugin) # doctest: +SKIP """ def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs): @@ -63,3 +63,66 @@ def add_worker(self, scheduler=None, worker=None, **kwargs): def remove_worker(self, scheduler=None, worker=None, **kwargs): """ Run when a worker leaves the cluster""" + + +class WorkerPlugin(object): + """ Interface to extend the Worker + + A worker plugin enables custom code to run at different stages of the Workers' + lifecycle: at setup, during task state transitions and at teardown. + + A plugin enables custom code to run at each of step of a Workers's life. 
Whenever such + an event happens, the corresponding method on this class will be called. Note that the + user code always runs within the Worker's main thread. + + To implement a plugin implement some of the methods of this class and register + the plugin to your client in order to have it attached to every existing and + future workers with ``Client.register_worker_plugin``. + + Examples + -------- + >>> class ErrorLogger(WorkerPlugin): + ... def __init__(self, logger): + ... self.logger = logger + ... + ... def setup(self, worker): + ... self.worker = worker + ... + ... def transition(self, key, start, finish, *args, **kwargs): + ... if finish == 'error': + ... exc = self.worker.exceptions[key] + ... self.logger.error("Task '%s' has failed with exception: %s" % (key, str(exc))) + + >>> plugin = ErrorLogger() + >>> client.register_worker_plugin(plugin) # doctest: +SKIP + """ + + def setup(self, worker): + """ + Run when the plugin is attached to a worker. This happens when the plugin is registered + and attached to existing workers, or when a worker is created after the plugin has been + registered. + """ + + def teardown(self, worker): + """ Run when the worker to which the plugin is attached to is closed """ + + def transition(self, key, start, finish, **kwargs): + """ + Throughout the lifecycle of a task (see :doc:`Worker `), Workers are + instructed by the scheduler to compute certain tasks, resulting in transitions + in the state of each task. The Worker owning the task is then notified of this + state transition. + + Whenever a task changes its state, this method will be called. + + Parameters + ---------- + key: string + start: string + Start state of the transition. + One of waiting, ready, executing, long-running, memory, error. + finish: string + Final state of the transition. 
+ kwargs: More options passed when transitioning + """ diff --git a/distributed/diagnostics/tests/test_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py similarity index 94% rename from distributed/diagnostics/tests/test_plugin.py rename to distributed/diagnostics/tests/test_scheduler_plugin.py index af29e81674d..2903214ba32 100644 --- a/distributed/diagnostics/tests/test_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -1,6 +1,5 @@ -from distributed import Worker +from distributed import Worker, SchedulerPlugin from distributed.utils_test import inc, gen_cluster -from distributed.diagnostics.plugin import SchedulerPlugin @gen_cluster(client=True) diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py new file mode 100644 index 00000000000..b3b919d7fe2 --- /dev/null +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -0,0 +1,93 @@ +import pytest + +from distributed import Worker, WorkerPlugin +from distributed.utils_test import gen_cluster + + +class MyPlugin(WorkerPlugin): + name = "MyPlugin" + + def __init__(self, data, expected_transitions=None): + self.data = data + self.expected_transitions = expected_transitions + + def setup(self, worker): + assert isinstance(worker, Worker) + self.worker = worker + self.worker._my_plugin_status = "setup" + self.worker._my_plugin_data = self.data + + self.observed_transitions = [] + + def teardown(self, worker): + self.worker._my_plugin_status = "teardown" + + if self.expected_transitions is not None: + assert len(self.observed_transitions) == len(self.expected_transitions) + for expected, real in zip( + self.expected_transitions, self.observed_transitions + ): + assert expected == real + + def transition(self, key, start, finish, **kwargs): + self.observed_transitions.append((key, start, finish)) + + +@gen_cluster(client=True, nthreads=[]) +def test_create_with_client(c, s): + yield c.register_worker_plugin(MyPlugin(123)) + + worker = yield Worker(s.address, loop=s.loop) + assert worker._my_plugin_status == "setup" + assert worker._my_plugin_data == 123 + + yield worker.close() + assert worker._my_plugin_status == "teardown" + + +@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) +def test_create_on_construction(c, s, a, b): + assert len(a.plugins) == len(b.plugins) == 1 + assert a._my_plugin_status == "setup" + assert a._my_plugin_data == 5 + + +@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) +def test_normal_task_transitions_called(c, s, w): + expected_transitions = [ + ("task", "waiting", "ready"), + ("task", "ready", "executing"), + ("task", "executing", "memory"), + ] + + plugin = MyPlugin(1, expected_transitions=expected_transitions) + + yield c.register_worker_plugin(plugin) + yield c.submit(lambda x: x, 1, key="task") + + +@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) +def test_failing_task_transitions_called(c, s, w): + def failing(x): + raise Exception() + + expected_transitions = [ + ("task", "waiting", "ready"), + ("task", "ready", "executing"), + ("task", "executing", "error"), + ] + + plugin = MyPlugin(1, expected_transitions=expected_transitions) + + yield c.register_worker_plugin(plugin) + + with pytest.raises(Exception): + yield c.submit(failing, 1, key="task") + + +@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) +async def test_empty_plugin(c, s, w): + class EmptyPlugin: + pass + + await c.register_worker_plugin(EmptyPlugin()) diff --git 
a/distributed/tests/test_worker_plugins.py b/distributed/tests/test_worker_plugins.py deleted file mode 100644 index 02db9419d4e..00000000000 --- a/distributed/tests/test_worker_plugins.py +++ /dev/null @@ -1,67 +0,0 @@ -from distributed.utils_test import gen_cluster -from distributed import Worker - - -class MyPlugin: - name = "MyPlugin" - - def __init__(self, data): - self.data = data - - def setup(self, worker): - assert isinstance(worker, Worker) - self.worker = worker - self.worker._my_plugin_status = "setup" - self.worker._my_plugin_data = self.data - - def teardown(self, worker): - assert isinstance(worker, Worker) - self.worker._my_plugin_status = "teardown" - - -@gen_cluster(client=True, nthreads=[]) -def test_create_with_client(c, s): - yield c.register_worker_plugin(MyPlugin(123)) - - worker = yield Worker(s.address, loop=s.loop) - assert worker._my_plugin_status == "setup" - assert worker._my_plugin_data == 123 - - yield worker.close() - assert worker._my_plugin_status == "teardown" - - -@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) -def test_create_on_construction(c, s, a, b): - assert len(a.plugins) == len(b.plugins) == 1 - assert a._my_plugin_status == "setup" - assert a._my_plugin_data == 5 - - -@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) -def test_idempotence_with_name(c, s, a, b): - a._my_plugin_data = 100 - - yield c.register_worker_plugin(MyPlugin(5)) - - assert a._my_plugin_data == 100 # call above has no effect - - -@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) -def test_duplicate_with_no_name(c, s, a, b): - assert len(a.plugins) == len(b.plugins) == 1 - - plugin = MyPlugin(10) - plugin.name = "other-name" - - yield c.register_worker_plugin(plugin) - - assert len(a.plugins) == len(b.plugins) == 2 - - assert a._my_plugin_data == 10 - - yield c.register_worker_plugin(plugin) - assert len(a.plugins) == len(b.plugins) == 2 - - yield c.register_worker_plugin(plugin, name="foo") - assert len(a.plugins) == len(b.plugins) == 3 diff --git a/distributed/worker.py b/distributed/worker.py index fba4eed57b3..e06f224da68 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1515,6 +1515,7 @@ def transition(self, key, finish, **kwargs): self.task_state[key] = state or finish if self.validate: self.validate_key(key) + self._notify_transition(key, start, finish, **kwargs) def transition_waiting_ready(self, key): try: @@ -2293,15 +2294,16 @@ async def plugin_add(self, comm=None, plugin=None, name=None): self.plugins[name] = plugin logger.info("Starting Worker plugin %s" % name) - try: - result = plugin.setup(worker=self) - if hasattr(result, "__await__"): - result = await result - except Exception as e: - msg = error_message(e) - return msg - else: - return {"status": "OK"} + if hasattr(plugin, "setup"): + try: + result = plugin.setup(worker=self) + if hasattr(result, "__await__"): + result = await result + except Exception as e: + msg = error_message(e) + return msg + + return {"status": "OK"} async def actor_execute( self, comm=None, actor=None, function=None, args=(), kwargs={} @@ -2712,6 +2714,16 @@ def get_call_stack(self, comm=None, keys=None): result = {k: profile.call_stack(frame) for k, frame in frames.items()} return result + def _notify_transition(self, key, start, finish, **kwargs): + for name, plugin in self.plugins.items(): + if hasattr(plugin, "transition"): + try: + plugin.transition(key, start, finish, **kwargs) + except Exception: + logger.info( + "Plugin '%s' failed with exception" % 
name, exc_info=True + ) + ############## # Validation # ############## diff --git a/docs/source/plugins.rst b/docs/source/plugins.rst index b5f52f8843e..5c831fc167e 100644 --- a/docs/source/plugins.rst +++ b/docs/source/plugins.rst @@ -73,3 +73,9 @@ the scheduler as so: def dask_setup(scheduler): plugin = MyPlugin(scheduler) scheduler.add_plugin(plugin) + +Worker Plugins +================= + +.. autoclass:: distributed.diagnostics.plugin.WorkerPlugin + :members: From 7ca0c6b8bc5a6d54d93033e4f483cceafca75b9b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 9 Oct 2019 15:38:43 -0500 Subject: [PATCH 0500/1550] Xfail test_worksapce_concurrency on Python 3.6 (#3132) --- distributed/tests/test_diskutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index e12fb324341..0057f96fb36 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -275,8 +275,8 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): def test_workspace_concurrency(tmpdir): if WINDOWS: raise pytest.xfail.Exception("TODO: unknown failure on windows") - if sys.version_info < (3, 6): - raise pytest.xfail.Exception("TODO: unknown failure on Python 3.5") + if sys.version_info <= (3, 6): + raise pytest.xfail.Exception("TODO: unknown failure on Python 3.6") _test_workspace_concurrency(tmpdir, 2.0, 6) From c2cc1a98cbaadc3e7952ea66e5096f0126492539 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 10 Oct 2019 09:51:12 -0700 Subject: [PATCH 0501/1550] Add Nanny(config={...}) keyword (#3134) --- distributed/nanny.py | 16 +++++++++++++++- distributed/tests/test_nanny.py | 9 +++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 8fbbf761368..83ca2ebbf80 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -86,6 +86,7 @@ def __init__( host=None, port=None, protocol=None, + config=None, **worker_kwargs ): self._setup_logging(logger) @@ -123,6 +124,7 @@ def __init__( self.preload_argv = dask.config.get("distributed.worker.preload-argv") self.Worker = Worker if worker_class is None else worker_class self.env = env or {} + self.config = config or {} worker_kwargs.update( { "port": worker_port, @@ -304,6 +306,7 @@ async def instantiate(self, comm=None): on_exit=self._on_exit_sync, worker=self.Worker, env=self.env, + config=self.config, ) self.auto_restart = True @@ -437,7 +440,14 @@ async def close(self, comm=None, timeout=5, report=None): class WorkerProcess(object): def __init__( - self, worker_kwargs, worker_start_args, silence_logs, on_exit, worker, env + self, + worker_kwargs, + worker_start_args, + silence_logs, + on_exit, + worker, + env, + config, ): self.status = "init" self.silence_logs = silence_logs @@ -447,6 +457,7 @@ def __init__( self.process = None self.Worker = worker self.env = env + self.config = config # Initialized when worker is ready self.worker_dir = None @@ -479,6 +490,7 @@ async def start(self): uid=uid, Worker=self.Worker, env=self.env, + config=self.config, ), ) self.process.daemon = dask.config.get("distributed.worker.daemon", default=True) @@ -621,9 +633,11 @@ def _run( child_stop_q, uid, env, + config, Worker, ): # pragma: no cover os.environ.update(env) + dask.config.set(config) try: from dask.multiprocessing import initialize_worker_process except ImportError: # old Dask version diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 
dec6bd91b20..d54cf4e3b14 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -436,3 +436,12 @@ async def test_nanny_closes_cleanly(cleanup): assert time() < start + 5 assert n.status == "closed" + + +@pytest.mark.asyncio +async def test_config(cleanup): + async with Scheduler() as s: + async with Nanny(s.address, config={"foo": "bar"}) as n: + async with Client(s.address, asynchronous=True) as client: + config = await client.run(dask.config.get, "foo") + assert config[n.worker_address] == "bar" From e39959e9b0e2e9746da0d1a53edeeea1931f9e47 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 11 Oct 2019 13:48:24 -0700 Subject: [PATCH 0502/1550] Only include metric in WorkerTable if it is a scalar (#3140) --- distributed/dashboard/scheduler.py | 7 ++++++- distributed/dashboard/tests/test_scheduler_bokeh.py | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 27a49b4fd68..79484cd4196 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1380,7 +1380,12 @@ def __init__(self, scheduler, width=800, **kwargs): ] workers = self.scheduler.workers.values() self.extra_names = sorted( - {m for ws in workers for m in ws.metrics if m not in self.names} + { + m + for ws in workers + for m, v in ws.metrics.items() + if m not in self.names and isinstance(v, (str, int, float)) + } - self.excluded_names ) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index e68d7935583..1e48a3addec 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -305,6 +305,11 @@ def test_WorkerTable(c, s, a, b): wt = WorkerTable(s) wt.update() assert all(wt.source.data.values()) + assert all( + not v or isinstance(v, (str, int, float)) + for L in wt.source.data.values() + for v in L + ), {type(v).__name__ for L in wt.source.data.values() for v in L} assert all(len(v) == 2 for v in wt.source.data.values()) nthreads = wt.source.data["nthreads"] From 00d7f7dfff47cf2aa67dd220c15d3ccaebbabcc8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 14 Oct 2019 07:55:24 -0700 Subject: [PATCH 0503/1550] Move death timeout logic up to Node.start (#3115) Previously there were some cases where the death-timeout logic wouldn't reliably be triggered. Now we handle it higher up in the call chain, where hopefully it will be more consistent. 
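The gist of the change: rather than having each connection loop check the death deadline itself, ``Node.__await__`` now wraps the ``start()`` coroutine once with ``asyncio.wait_for``. A minimal sketch of that pattern (illustrative only; ``node.start``, ``node.close`` and the error type here stand in for the real code in the diff below):

    import asyncio

    async def start_with_deadline(node, timeout):
        # Give start() a single deadline; if it expires, try to close the
        # node and surface a clear error instead of hanging forever.
        try:
            await asyncio.wait_for(node.start(), timeout=timeout)
        except asyncio.TimeoutError:
            await node.close(timeout=1)
            raise TimeoutError(
                "%s failed to start in %s seconds" % (type(node).__name__, timeout)
            )
        return node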
--- distributed/node.py | 19 ++++++++++++++++++- distributed/tests/test_worker.py | 5 ++++- distributed/worker.py | 16 ---------------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/distributed/node.py b/distributed/node.py index 8ef610a8481..2d7447b1a06 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -1,3 +1,4 @@ +import asyncio import logging import warnings import weakref @@ -163,7 +164,23 @@ def __await__(self): if self.status == "running": return gen.sleep(0).__await__() else: - return self.start().__await__() + future = self.start() + timeout = getattr(self, "death_timeout", 0) + if timeout: + + async def wait_for(future, timeout=None): + try: + await asyncio.wait_for(future, timeout=timeout) + except Exception: + await self.close(timeout=1) + raise gen.TimeoutError( + "{} failed to start in {} seconds".format( + type(self).__name__, timeout + ) + ) + + future = wait_for(future, timeout=timeout) + return future.__await__() async def start(self): # subclasses should implement this return self diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index eb4c0f86c7b..53aac46216a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -762,9 +762,12 @@ def test_worker_death_timeout(s): yield s.close() w = Worker(s.address, death_timeout=1) - with pytest.raises(gen.TimeoutError): + with pytest.raises(gen.TimeoutError) as info: yield w + assert "Worker" in str(info.value) + assert "timed out" in str(info.value) or "failed to start" in str(info.value) + assert w.status == "closed" diff --git a/distributed/worker.py b/distributed/worker.py index e06f224da68..12dfe3fe178 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -784,17 +784,6 @@ async def _register_with_scheduler(self): self.contact_address = self.address logger.info("-" * 49) while True: - if self.death_timeout and time() > start + self.death_timeout: - logger.exception( - "Timed out when connecting to scheduler '%s'", - self.scheduler.address, - ) - await self.close(timeout=1) - raise gen.TimeoutError( - "Timed out connecting to scheduler '%s'" % self.scheduler.address - ) - if self.status in ("closed", "closing"): - return try: _start = time() types = {k: typename(v) for k, v in self.data.items()} @@ -826,11 +815,6 @@ async def _register_with_scheduler(self): serializers=["msgpack"], ) future = comm.read(deserializers=["msgpack"]) - if self.death_timeout: - diff = self.death_timeout - (time() - start) - if diff < 0: - continue - future = gen.with_timeout(timedelta(seconds=diff), future) response = await future _end = time() middle = (_start + _end) / 2 From eb7bccc7e3a202f11c17455740f3af4d913332c3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 15 Oct 2019 12:06:59 -0700 Subject: [PATCH 0504/1550] Use setuptools.find_packages in setup.py (#3150) --- setup.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 84054d199e0..5d900199256 100755 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import os -from setuptools import setup +from setuptools import setup, find_packages import versioneer requires = open("requirements.txt").read().strip().split("\n") @@ -35,15 +35,7 @@ include_package_data=True, install_requires=install_requires, extras_require=extras_require, - packages=[ - "distributed", - "distributed.dashboard", - "distributed.cli", - "distributed.comm", - "distributed.deploy", - "distributed.diagnostics", - "distributed.protocol", - ], + 
packages=find_packages(exclude=["*tests*"]), long_description=( open("README.rst").read() if os.path.exists("README.rst") else "" ), From c970ec0835648ea7ecced35b1909eddbddd925b4 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 15 Oct 2019 20:08:20 +0100 Subject: [PATCH 0505/1550] Refactor dashboard module (#3138) Closes #3048. In an attempt to tidy up the dashboard code I've done some major shuffling of things within the dashboard submodule. I think overall this puts things in more obvious places and should make it more accessible to new contributors. Notable changes: - Components have been broken out in a new `distributed.dashboard.components` submodule and placed into logical groupings "scheduler specific", "worker specific", "shared" and "nvml/gpu". This could be broken down further but there is a balance to be struck between indirection and giant scary files. - The additional server routes from `scheduler_html.py` and `worker_html.py` have been moved into the `scheduler.py` and `worker.py` files to keep all server things together. - Shared functions and utilities have been moved around into more appropriate places. Other things I'd like to do: - [x] Make use of the [update source function](https://github.com/dask/distributed/pull/3138/files#diff-0db0a1f6d00335e7cc6e5e94eae3f8a1R89-R116) outside of the scheduler components - [x] ~Re-order components to group `_doc` functions with their `DashboardComponent` classes to reduce indirection (or at least distance).~ _On second thoughts as some docs use multiple components having the docs together at the bottom actually makes more sense._ - [x] Rename functions for consistency (there are a few which should probably follow the `_doc` names) - [x] Identify and remove unused components. (Could use some pointers on this) --- distributed/client.py | 6 +- distributed/dashboard/components/__init__.py | 93 + .../dashboard/{ => components}/nvml.py | 27 +- distributed/dashboard/components/scheduler.py | 1873 ++++++++++++++++ .../{components.py => components/shared.py} | 295 +-- distributed/dashboard/components/worker.py | 661 ++++++ distributed/dashboard/scheduler.py | 1954 ++--------------- distributed/dashboard/scheduler_html.py | 269 --- .../dashboard/tests/test_components.py | 6 +- .../dashboard/tests/test_scheduler_bokeh.py | 28 +- .../dashboard/tests/test_worker_bokeh.py | 4 +- distributed/dashboard/utils.py | 46 + distributed/dashboard/worker.py | 818 +------ distributed/dashboard/worker_html.py | 108 - distributed/diagnostics/graph_layout.py | 4 +- 15 files changed, 3132 insertions(+), 3060 deletions(-) create mode 100644 distributed/dashboard/components/__init__.py rename distributed/dashboard/{ => components}/nvml.py (94%) create mode 100644 distributed/dashboard/components/scheduler.py rename distributed/dashboard/{components.py => components/shared.py} (75%) create mode 100644 distributed/dashboard/components/worker.py delete mode 100644 distributed/dashboard/scheduler_html.py delete mode 100644 distributed/dashboard/worker_html.py diff --git a/distributed/client.py b/distributed/client.py index 08e808acd15..ca7ca431c90 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3861,7 +3861,7 @@ async def _get_task_stream( from .diagnostics.task_stream import rectangles rects = rectangles(msgs) - from .dashboard.components import task_stream_figure + from .dashboard.components.scheduler import task_stream_figure source, figure = task_stream_figure(sizing_mode="stretch_both") source.data.update(rects) @@ -4424,7 +4424,7 @@ 
class get_task_stream(object): To share this file with others you may wish to upload and serve it online. A common way to do this is to upload the file as a gist, and then serve it - on https://rawgit.com :: + on https://raw.githack.com :: $ pip install gist $ gist task-stream.html @@ -4432,7 +4432,7 @@ class get_task_stream(object): You can then navigate to that site, click the "Raw" button to the right of the ``task-stream.html`` file, and then provide that URL to - https://rawgit.com . This process should provide a sharable link that + https://raw.githack.com . This process should provide a sharable link that others can use to see your task stream plot. See Also diff --git a/distributed/dashboard/components/__init__.py b/distributed/dashboard/components/__init__.py new file mode 100644 index 00000000000..12f57b352b1 --- /dev/null +++ b/distributed/dashboard/components/__init__.py @@ -0,0 +1,93 @@ +import asyncio +from bisect import bisect +from operator import add +from time import time +import weakref + +from bokeh.layouts import row, column +from bokeh.models import ( + ColumnDataSource, + Plot, + DataRange1d, + LinearAxis, + HoverTool, + BoxZoomTool, + ResetTool, + PanTool, + WheelZoomTool, + Range1d, + Quad, + TapTool, + OpenURL, + Button, + Select, +) +from bokeh.palettes import Spectral9 +from bokeh.plotting import figure +import dask +from tornado import gen +import toolz + +from distributed.dashboard.utils import without_property_validation, BOKEH_VERSION +from distributed.diagnostics.progress_stream import nbytes_bar +from distributed import profile +from distributed.utils import log_errors, parse_timedelta + +if dask.config.get("distributed.dashboard.export-tool"): + from distributed.dashboard.export_tool import ExportTool +else: + ExportTool = None + + +profile_interval = dask.config.get("distributed.worker.profile.interval") +profile_interval = parse_timedelta(profile_interval, default="ms") + + +class DashboardComponent(object): + """ Base class for Dask.distributed UI dashboard components. + + This class must have two attributes, ``root`` and ``source``, and one + method ``update``: + + * source: a Bokeh ColumnDataSource + * root: a Bokeh Model + * update: a method that consumes the messages dictionary found in + distributed.bokeh.messages + """ + + def __init__(self): + self.source = None + self.root = None + + def update(self, messages): + """ Reads from bokeh.distributed.messages and updates self.source """ + + +def add_periodic_callback(doc, component, interval): + """ Add periodic callback to doc in a way that avoids reference cycles + + If we instead use ``doc.add_periodic_callback(component.update, 100)`` then + the component stays in memory as a reference cycle because its method is + still around. This way we avoid that and let things clean up a bit more + nicely. + + TODO: we still have reference cycles. Docs seem to be referred to by their + add_periodic_callback methods. 
+ """ + ref = weakref.ref(component) + + doc.add_periodic_callback(lambda: update(ref), interval) + _attach(doc, component) + + +def update(ref): + comp = ref() + if comp is not None: + comp.update() + + +def _attach(doc, component): + if not hasattr(doc, "components"): + doc.components = set() + + doc.components.add(component) diff --git a/distributed/dashboard/nvml.py b/distributed/dashboard/components/nvml.py similarity index 94% rename from distributed/dashboard/nvml.py rename to distributed/dashboard/components/nvml.py index 131a02a8397..b0c56c4ef47 100644 --- a/distributed/dashboard/nvml.py +++ b/distributed/dashboard/components/nvml.py @@ -1,6 +1,6 @@ import math -from .components import DashboardComponent, add_periodic_callback +from distributed.dashboard.components import DashboardComponent, add_periodic_callback from bokeh.plotting import figure from bokeh.models import ( @@ -13,9 +13,17 @@ ) from tornado import escape from dask.utils import format_bytes -from ..utils import log_errors -from .scheduler import update, applications, BOKEH_THEME -from .utils import without_property_validation +from distributed.utils import log_errors +from distributed.dashboard.components.scheduler import BOKEH_THEME +from distributed.dashboard.utils import without_property_validation, update + + +try: + import pynvml + + pynvml.nvmlInit() +except Exception: + pass class GPUCurrentLoad(DashboardComponent): @@ -181,14 +189,3 @@ def gpu_utilization_doc(scheduler, extra, doc): add_periodic_callback(doc, gpu_load, 100) doc.add_root(gpu_load.utilization_figure) doc.theme = BOKEH_THEME - - -try: - import pynvml - - pynvml.nvmlInit() -except Exception: - pass -else: - applications["/individual-gpu-memory"] = gpu_memory_doc - applications["/individual-gpu-utilization"] = gpu_utilization_doc diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py new file mode 100644 index 00000000000..5e94a034cb0 --- /dev/null +++ b/distributed/dashboard/components/scheduler.py @@ -0,0 +1,1873 @@ +import logging +import math +from numbers import Number +from operator import add +import os + +from bokeh.layouts import column, row +from bokeh.models import ( + ColumnDataSource, + ColorBar, + DataRange1d, + HoverTool, + ResetTool, + PanTool, + WheelZoomTool, + TapTool, + OpenURL, + Range1d, + Plot, + Quad, + Span, + value, + LinearAxis, + NumeralTickFormatter, + BoxZoomTool, + BasicTicker, + NumberFormatter, + BoxSelectTool, + GroupFilter, + CDSView, +) +from bokeh.models.widgets import DataTable, TableColumn +from bokeh.plotting import figure +from bokeh.palettes import Viridis11 +from bokeh.themes import Theme +from bokeh.transform import factor_cmap, linear_cmap +from bokeh.io import curdoc +import dask +from dask.utils import format_bytes +from toolz import pipe +from tornado import escape + +try: + import numpy as np +except ImportError: + np = False + +from distributed.dashboard.components import add_periodic_callback +from distributed.dashboard.components.shared import ( + DashboardComponent, + ProfileTimePlot, + ProfileServer, + SystemMonitor, +) +from distributed.dashboard.utils import ( + transpose, + BOKEH_VERSION, + PROFILING, + without_property_validation, + update, +) +from distributed.metrics import time +from distributed.utils import log_errors, format_time, parse_timedelta +from distributed.diagnostics.progress_stream import color_of, progress_quads, nbytes_bar +from distributed.diagnostics.progress import AllProgress +from distributed.diagnostics.graph_layout 
import GraphLayout +from distributed.diagnostics.task_stream import TaskStreamPlugin + +try: + from cytoolz.curried import map, concat, groupby, valmap +except ImportError: + from toolz.curried import map, concat, groupby, valmap + +if dask.config.get("distributed.dashboard.export-tool"): + from distributed.dashboard.export_tool import ExportTool +else: + ExportTool = None + +logger = logging.getLogger(__name__) + +from jinja2 import Environment, FileSystemLoader + +env = Environment( + loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "..", "templates")) +) + +BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "..", "theme.yaml")) + +nan = float("nan") +inf = float("inf") + + +class Occupancy(DashboardComponent): + """ Occupancy (in time) per worker """ + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "occupancy": [0, 0], + "worker": ["a", "b"], + "x": [0.0, 0.1], + "y": [1, 2], + "ms": [1, 2], + "color": ["red", "blue"], + "escaped_worker": ["a", "b"], + } + ) + + fig = figure( + title="Occupancy", + tools="", + id="bk-occupancy-plot", + x_axis_type="datetime", + **kwargs + ) + rect = fig.rect( + source=self.source, x="x", width="ms", y="y", height=1, color="color" + ) + rect.nonselection_glyph = None + + fig.xaxis.minor_tick_line_alpha = 0 + fig.yaxis.visible = False + fig.ygrid.visible = False + # fig.xaxis[0].formatter = NumeralTickFormatter(format='0.0s') + fig.x_range.start = 0 + + tap = TapTool(callback=OpenURL(url="./info/worker/@escaped_worker.html")) + + hover = HoverTool() + hover.tooltips = "@worker : @occupancy s." + hover.point_policy = "follow_mouse" + fig.add_tools(hover, tap) + + self.root = fig + + @without_property_validation + def update(self): + with log_errors(): + workers = list(self.scheduler.workers.values()) + + y = list(range(len(workers))) + occupancy = [ws.occupancy for ws in workers] + ms = [occ * 1000 for occ in occupancy] + x = [occ / 500 for occ in occupancy] + total = sum(occupancy) + color = [] + for ws in workers: + if ws in self.scheduler.idle: + color.append("red") + elif ws in self.scheduler.saturated: + color.append("green") + else: + color.append("blue") + + if total: + self.root.title.text = "Occupancy -- total time: %s wall time: %s" % ( + format_time(total), + format_time(total / self.scheduler.total_nthreads), + ) + else: + self.root.title.text = "Occupancy" + + if occupancy: + result = { + "occupancy": occupancy, + "worker": [ws.address for ws in workers], + "ms": ms, + "color": color, + "escaped_worker": [escape.url_escape(ws.address) for ws in workers], + "x": x, + "y": y, + } + + update(self.source, result) + + +class ProcessingHistogram(DashboardComponent): + """ How many tasks are on each worker """ + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + {"left": [1, 2], "right": [10, 10], "top": [0, 0]} + ) + + self.root = figure( + title="Tasks Processing (Histogram)", + id="bk-nprocessing-histogram-plot", + name="processing_hist", + y_axis_label="frequency", + tools="", + **kwargs + ) + + self.root.xaxis.minor_tick_line_alpha = 0 + self.root.ygrid.visible = False + + self.root.toolbar.logo = None + self.root.toolbar_location = None + + self.root.quad( + source=self.source, + left="left", + right="right", + bottom=0, + top="top", + color="deepskyblue", + fill_alpha=0.5, + ) + + @without_property_validation + def update(self): + L = 
[len(ws.processing) for ws in self.scheduler.workers.values()] + counts, x = np.histogram(L, bins=40) + self.source.data.update({"left": x[:-1], "right": x[1:], "top": counts}) + + +class NBytesHistogram(DashboardComponent): + """ How many tasks are on each worker """ + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + {"left": [1, 2], "right": [10, 10], "top": [0, 0]} + ) + + self.root = figure( + title="Bytes Stored (Histogram)", + name="nbytes_hist", + id="bk-nbytes-histogram-plot", + y_axis_label="frequency", + tools="", + **kwargs + ) + + self.root.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + self.root.xaxis.major_label_orientation = -math.pi / 12 + + self.root.xaxis.minor_tick_line_alpha = 0 + self.root.ygrid.visible = False + + self.root.toolbar.logo = None + self.root.toolbar_location = None + + self.root.quad( + source=self.source, + left="left", + right="right", + bottom=0, + top="top", + color="deepskyblue", + fill_alpha=0.5, + ) + + @without_property_validation + def update(self): + nbytes = np.asarray([ws.nbytes for ws in self.scheduler.workers.values()]) + counts, x = np.histogram(nbytes, bins=40) + d = {"left": x[:-1], "right": x[1:], "top": counts} + self.source.data.update(d) + + self.root.title.text = "Bytes stored (Histogram): " + format_bytes(nbytes.sum()) + + +class BandwidthTypes(DashboardComponent): + """ Bar chart showing bandwidth per type """ + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "bandwidth": [1, 2], + "bandwidth-half": [0.5, 1], + "type": ["a", "b"], + "bandwidth_text": ["1", "2"], + } + ) + + fig = figure( + title="Bandwidth by Type", + tools="", + id="bk-bandwidth-type-plot", + name="bandwidth_type_histogram", + y_range=["a", "b"], + **kwargs + ) + rect = fig.rect( + source=self.source, + x="bandwidth-half", + y="type", + width="bandwidth", + height=1, + color="blue", + ) + fig.x_range.start = 0 + fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + rect.nonselection_glyph = None + + fig.xaxis.minor_tick_line_alpha = 0 + fig.ygrid.visible = False + + fig.toolbar.logo = None + fig.toolbar_location = None + + hover = HoverTool() + hover.tooltips = "@type: @bandwidth_text / s" + hover.point_policy = "follow_mouse" + fig.add_tools(hover) + + self.fig = fig + + @without_property_validation + def update(self): + with log_errors(): + bw = self.scheduler.bandwidth_types + self.fig.y_range.factors = list(sorted(bw)) + result = { + "bandwidth": list(bw.values()), + "bandwidth-half": [b / 2 for b in bw.values()], + "type": list(bw.keys()), + "bandwidth_text": list(map(format_bytes, bw.values())), + } + self.fig.title.text = "Bandwidth: " + format_bytes(self.scheduler.bandwidth) + + update(self.source, result) + + +class BandwidthWorkers(DashboardComponent): + """ How many tasks are on each worker """ + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "bandwidth": [1, 2], + "source": ["a", "b"], + "destination": ["a", "b"], + "bandwidth_text": ["1", "2"], + } + ) + + values = [hex(x)[2:] for x in range(64, 256)][::-1] + mapper = linear_cmap( + field_name="bandwidth", + palette=["#" + x + x + "FF" for x in values], + low=0, + high=1, + ) + + fig = figure( + title="Bandwidth by Worker", + tools="", + id="bk-bandwidth-worker-plot", + 
name="bandwidth_worker_heatmap", + x_range=["a", "b"], + y_range=["a", "b"], + **kwargs + ) + fig.xaxis.major_label_orientation = -math.pi / 12 + rect = fig.rect( + source=self.source, + x="source", + y="destination", + color=mapper, + height=1, + width=1, + ) + + self.color_map = mapper["transform"] + color_bar = ColorBar( + color_mapper=self.color_map, + label_standoff=12, + border_line_color=None, + location=(0, 0), + ) + color_bar.formatter = NumeralTickFormatter(format="0 b") + fig.add_layout(color_bar, "right") + + fig.toolbar.logo = None + fig.toolbar_location = None + + hover = HoverTool() + hover.tooltips = """ +
<div>
+ <p><b>Source:</b> @source</p>
+ <p><b>Destination:</b> @destination</p>
+ <p><b>Bandwidth:</b> @bandwidth_text / s</p>
+ </div>
          + """ + hover.point_policy = "follow_mouse" + fig.add_tools(hover) + + self.fig = fig + + @without_property_validation + def update(self): + with log_errors(): + bw = self.scheduler.bandwidth_workers + if not bw: + return + x, y, value = zip(*[(a, b, c) for (a, b), c in bw.items()]) + + if self.color_map.high < max(value): + self.color_map.high = max(value) + + factors = list(sorted(set(x + y))) + self.fig.x_range.factors = factors + self.fig.y_range.factors = factors + + result = { + "source": x, + "destination": y, + "bandwidth": value, + "bandwidth_text": list(map(format_bytes, value)), + } + self.fig.title.text = "Bandwidth: " + format_bytes(self.scheduler.bandwidth) + + update(self.source, result) + + +class CurrentLoad(DashboardComponent): + """ How many tasks are on each worker """ + + def __init__(self, scheduler, width=600, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "nprocessing": [1, 2], + "nprocessing-half": [0.5, 1], + "nprocessing-color": ["red", "blue"], + "nbytes": [1, 2], + "nbytes-half": [0.5, 1], + "nbytes_text": ["1B", "2B"], + "cpu": [1, 2], + "cpu-half": [0.5, 1], + "worker": ["a", "b"], + "y": [1, 2], + "nbytes-color": ["blue", "blue"], + "escaped_worker": ["a", "b"], + } + ) + + processing = figure( + title="Tasks Processing", + tools="", + id="bk-nprocessing-plot", + name="processing_hist", + width=int(width / 2), + **kwargs + ) + rect = processing.rect( + source=self.source, + x="nprocessing-half", + y="y", + width="nprocessing", + height=1, + color="nprocessing-color", + ) + processing.x_range.start = 0 + rect.nonselection_glyph = None + + nbytes = figure( + title="Bytes stored", + tools="", + id="bk-nbytes-worker-plot", + width=int(width / 2), + name="nbytes_hist", + **kwargs + ) + rect = nbytes.rect( + source=self.source, + x="nbytes-half", + y="y", + width="nbytes", + height=1, + color="nbytes-color", + ) + rect.nonselection_glyph = None + + cpu = figure( + title="CPU Utilization", + tools="", + id="bk-cpu-worker-plot", + width=int(width / 2), + name="cpu_hist", + **kwargs + ) + rect = cpu.rect( + source=self.source, + x="cpu-half", + y="y", + width="cpu", + height=1, + color="blue", + ) + rect.nonselection_glyph = None + hundred_span = Span( + location=100, + dimension="height", + line_color="gray", + line_dash="dashed", + line_width=3, + ) + cpu.add_layout(hundred_span) + + nbytes.axis[0].ticker = BasicTicker(mantissas=[1, 256, 512], base=1024) + nbytes.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + nbytes.xaxis.major_label_orientation = -math.pi / 12 + nbytes.x_range.start = 0 + + for fig in [processing, nbytes]: + fig.xaxis.minor_tick_line_alpha = 0 + fig.yaxis.visible = False + fig.ygrid.visible = False + + tap = TapTool( + callback=OpenURL(url="./info/worker/@escaped_worker.html") + ) + fig.add_tools(tap) + + fig.toolbar.logo = None + fig.toolbar_location = None + fig.yaxis.visible = False + + hover = HoverTool() + hover.tooltips = "@worker : @nprocessing tasks" + hover.point_policy = "follow_mouse" + processing.add_tools(hover) + + hover = HoverTool() + hover.tooltips = "@worker : @nbytes_text" + hover.point_policy = "follow_mouse" + nbytes.add_tools(hover) + + hover = HoverTool() + hover.tooltips = "@worker : @cpu %" + hover.point_policy = "follow_mouse" + cpu.add_tools(hover) + + self.processing_figure = processing + self.nbytes_figure = nbytes + self.cpu_figure = cpu + + processing.y_range = nbytes.y_range + cpu.y_range = nbytes.y_range + + 
@without_property_validation + def update(self): + with log_errors(): + workers = list(self.scheduler.workers.values()) + + y = list(range(len(workers))) + + cpu = [int(ws.metrics["cpu"]) for ws in workers] + + nprocessing = [len(ws.processing) for ws in workers] + processing_color = [] + for ws in workers: + if ws in self.scheduler.idle: + processing_color.append("red") + elif ws in self.scheduler.saturated: + processing_color.append("green") + else: + processing_color.append("blue") + + nbytes = [ws.metrics["memory"] for ws in workers] + nbytes_text = [format_bytes(nb) for nb in nbytes] + nbytes_color = [] + max_limit = 0 + for ws, nb in zip(workers, nbytes): + limit = ( + getattr(self.scheduler.workers[ws.address], "memory_limit", inf) + or inf + ) + + if limit > max_limit: + max_limit = limit + + if nb > limit: + nbytes_color.append("red") + elif nb > limit / 2: + nbytes_color.append("orange") + else: + nbytes_color.append("blue") + + now = time() + if any(nprocessing) or self.last + 1 < now: + self.last = now + result = { + "cpu": cpu, + "cpu-half": [c / 2 for c in cpu], + "nprocessing": nprocessing, + "nprocessing-half": [np / 2 for np in nprocessing], + "nprocessing-color": processing_color, + "nbytes": nbytes, + "nbytes-half": [nb / 2 for nb in nbytes], + "nbytes-color": nbytes_color, + "nbytes_text": nbytes_text, + "worker": [ws.address for ws in workers], + "escaped_worker": [escape.url_escape(ws.address) for ws in workers], + "y": y, + } + + self.nbytes_figure.title.text = "Bytes stored: " + format_bytes( + sum(nbytes) + ) + self.nbytes_figure.x_range.end = max_limit + + update(self.source, result) + + +class StealingTimeSeries(DashboardComponent): + def __init__(self, scheduler, **kwargs): + self.scheduler = scheduler + self.source = ColumnDataSource( + {"time": [time(), time() + 1], "idle": [0, 0.1], "saturated": [0, 0.1]} + ) + + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + + fig = figure( + title="Idle and Saturated Workers Over Time", + x_axis_type="datetime", + y_range=[-0.1, len(scheduler.workers) + 0.1], + height=150, + tools="", + x_range=x_range, + **kwargs + ) + fig.line(source=self.source, x="time", y="idle", color="red") + fig.line(source=self.source, x="time", y="saturated", color="green") + fig.yaxis.minor_tick_line_color = None + + fig.add_tools( + ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") + ) + + self.root = fig + + @without_property_validation + def update(self): + with log_errors(): + result = { + "time": [time() * 1000], + "idle": [len(self.scheduler.idle)], + "saturated": [len(self.scheduler.saturated)], + } + if PROFILING: + curdoc().add_next_tick_callback( + lambda: self.source.stream(result, 10000) + ) + else: + self.source.stream(result, 10000) + + +class StealingEvents(DashboardComponent): + def __init__(self, scheduler, **kwargs): + self.scheduler = scheduler + self.steal = scheduler.extensions["stealing"] + self.last = 0 + self.source = ColumnDataSource( + { + "time": [time() - 20, time()], + "level": [0, 15], + "color": ["white", "white"], + "duration": [0, 0], + "radius": [1, 1], + "cost_factor": [0, 10], + "count": [1, 1], + } + ) + + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + + fig = figure( + title="Stealing Events", + x_axis_type="datetime", + y_axis_type="log", + height=250, + tools="", + x_range=x_range, + **kwargs + ) + + fig.circle( + source=self.source, + x="time", + y="cost_factor", + color="color", + size="radius", + alpha=0.5, + ) + 
fig.yaxis.axis_label = "Cost Multiplier" + + hover = HoverTool() + hover.tooltips = "Level: @level, Duration: @duration, Count: @count, Cost factor: @cost_factor" + hover.point_policy = "follow_mouse" + + fig.add_tools( + hover, + ResetTool(), + PanTool(dimensions="width"), + WheelZoomTool(dimensions="width"), + ) + + self.root = fig + + def convert(self, msgs): + """ Convert a log message to a glyph """ + total_duration = 0 + for msg in msgs: + time, level, key, duration, sat, occ_sat, idl, occ_idl = msg + total_duration += duration + + try: + color = Viridis11[level] + except (KeyError, IndexError): + color = "black" + + radius = math.sqrt(min(total_duration, 10)) * 30 + 2 + + d = { + "time": time * 1000, + "level": level, + "count": len(msgs), + "color": color, + "duration": total_duration, + "radius": radius, + "cost_factor": min(10, self.steal.cost_multipliers[level]), + } + + return d + + @without_property_validation + def update(self): + with log_errors(): + log = self.steal.log + n = self.steal.count - self.last + log = [log[-i] for i in range(1, n + 1) if isinstance(log[-i], list)] + self.last = self.steal.count + + if log: + new = pipe( + log, + map(groupby(1)), + map(dict.values), + concat, + map(self.convert), + list, + transpose, + ) + if PROFILING: + curdoc().add_next_tick_callback( + lambda: self.source.stream(new, 10000) + ) + else: + self.source.stream(new, 10000) + + +class Events(DashboardComponent): + def __init__(self, scheduler, name, height=150, **kwargs): + self.scheduler = scheduler + self.action_ys = dict() + self.last = 0 + self.name = name + self.source = ColumnDataSource( + {"time": [], "action": [], "hover": [], "y": [], "color": []} + ) + + x_range = DataRange1d(follow="end", follow_interval=200000) + + fig = figure( + title=name, + x_axis_type="datetime", + height=height, + tools="", + x_range=x_range, + **kwargs + ) + + fig.circle( + source=self.source, + x="time", + y="y", + color="color", + size=50, + alpha=0.5, + legend="action", + ) + fig.yaxis.axis_label = "Action" + fig.legend.location = "top_left" + + hover = HoverTool() + hover.tooltips = "@action
          @hover" + hover.point_policy = "follow_mouse" + + fig.add_tools( + hover, + ResetTool(), + PanTool(dimensions="width"), + WheelZoomTool(dimensions="width"), + ) + + self.root = fig + + @without_property_validation + def update(self): + with log_errors(): + log = self.scheduler.events[self.name] + n = self.scheduler.event_counts[self.name] - self.last + if log: + log = [log[-i] for i in range(1, n + 1)] + self.last = self.scheduler.event_counts[self.name] + + if log: + actions = [] + times = [] + hovers = [] + ys = [] + colors = [] + for msg in log: + times.append(msg["time"] * 1000) + action = msg["action"] + actions.append(action) + try: + ys.append(self.action_ys[action]) + except KeyError: + self.action_ys[action] = len(self.action_ys) + ys.append(self.action_ys[action]) + colors.append(color_of(action)) + hovers.append("TODO") + + new = { + "time": times, + "action": actions, + "hover": hovers, + "y": ys, + "color": colors, + } + + if PROFILING: + curdoc().add_next_tick_callback( + lambda: self.source.stream(new, 10000) + ) + else: + self.source.stream(new, 10000) + + +class TaskStream(DashboardComponent): + def __init__(self, scheduler, n_rectangles=1000, clear_interval="20s", **kwargs): + self.scheduler = scheduler + self.offset = 0 + es = [p for p in self.scheduler.plugins if isinstance(p, TaskStreamPlugin)] + if not es: + self.plugin = TaskStreamPlugin(self.scheduler) + else: + self.plugin = es[0] + self.index = max(0, self.plugin.index - n_rectangles) + self.workers = dict() + self.n_rectangles = n_rectangles + clear_interval = parse_timedelta(clear_interval, default="ms") + self.clear_interval = clear_interval + self.last = 0 + + self.source, self.root = task_stream_figure(clear_interval, **kwargs) + + # Required for update callback + self.task_stream_index = [0] + + @without_property_validation + def update(self): + if self.index == self.plugin.index: + return + with log_errors(): + if self.index and len(self.source.data["start"]): + start = min(self.source.data["start"]) + duration = max(self.source.data["duration"]) + boundary = (self.offset + start - duration) / 1000 + else: + boundary = self.offset + rectangles = self.plugin.rectangles( + istart=self.index, workers=self.workers, start_boundary=boundary + ) + n = len(rectangles["name"]) + self.index = self.plugin.index + + if not rectangles["start"]: + return + + # If there has been a significant delay then clear old rectangles + first_end = min(map(add, rectangles["start"], rectangles["duration"])) + if first_end > self.last: + last = self.last + self.last = first_end + if first_end > last + self.clear_interval * 1000: + self.offset = min(rectangles["start"]) + self.source.data.update({k: [] for k in rectangles}) + + rectangles["start"] = [x - self.offset for x in rectangles["start"]] + + # Convert to numpy for serialization speed + if n >= 10 and np: + for k, v in rectangles.items(): + if isinstance(v[0], Number): + rectangles[k] = np.array(v) + + if PROFILING: + curdoc().add_next_tick_callback( + lambda: self.source.stream(rectangles, self.n_rectangles) + ) + else: + self.source.stream(rectangles, self.n_rectangles) + + +def task_stream_figure(clear_interval="20s", **kwargs): + """ + kwargs are applied to the bokeh.models.plots.Plot constructor + """ + clear_interval = parse_timedelta(clear_interval, default="ms") + + source = ColumnDataSource( + data=dict( + start=[time() - clear_interval], + duration=[0.1], + key=["start"], + name=["start"], + color=["white"], + duration_text=["100 ms"], + worker=["foo"], + 
y=[0], + worker_thread=[1], + alpha=[0.0], + ) + ) + + x_range = DataRange1d(range_padding=0) + y_range = DataRange1d(range_padding=0) + + root = figure( + name="task_stream", + title="Task Stream", + id="bk-task-stream-plot", + x_range=x_range, + y_range=y_range, + toolbar_location="above", + x_axis_type="datetime", + min_border_right=35, + tools="", + **kwargs + ) + + rect = root.rect( + source=source, + x="start", + y="y", + width="duration", + height=0.4, + fill_color="color", + line_color="color", + line_alpha=0.6, + fill_alpha="alpha", + line_width=3, + ) + rect.nonselection_glyph = None + + root.yaxis.major_label_text_alpha = 0 + root.yaxis.minor_tick_line_alpha = 0 + root.yaxis.major_tick_line_alpha = 0 + root.xgrid.visible = False + + hover = HoverTool( + point_policy="follow_mouse", + tooltips=""" +
<div>
+ <b>@name:</b> @duration_text
+ </div>
          + """, + ) + + tap = TapTool(callback=OpenURL(url="/profile?key=@name")) + + root.add_tools( + hover, + tap, + BoxZoomTool(), + ResetTool(), + PanTool(dimensions="width"), + WheelZoomTool(dimensions="width"), + ) + if ExportTool: + export = ExportTool() + export.register_plot(root) + root.add_tools(export) + + return source, root + + +class TaskGraph(DashboardComponent): + """ + A dynamic node-link diagram for the task graph on the scheduler + + See also the GraphLayout diagnostic at + distributed/diagnostics/graph_layout.py + """ + + def __init__(self, scheduler, **kwargs): + self.scheduler = scheduler + self.layout = GraphLayout(scheduler) + self.invisible_count = 0 # number of invisible nodes + + self.node_source = ColumnDataSource( + {"x": [], "y": [], "name": [], "state": [], "visible": [], "key": []} + ) + self.edge_source = ColumnDataSource({"x": [], "y": [], "visible": []}) + + node_view = CDSView( + source=self.node_source, + filters=[GroupFilter(column_name="visible", group="True")], + ) + edge_view = CDSView( + source=self.edge_source, + filters=[GroupFilter(column_name="visible", group="True")], + ) + + node_colors = factor_cmap( + "state", + factors=["waiting", "processing", "memory", "released", "erred"], + palette=["gray", "green", "red", "blue", "black"], + ) + + self.root = figure(title="Task Graph", **kwargs) + self.root.multi_line( + xs="x", + ys="y", + source=self.edge_source, + line_width=1, + view=edge_view, + color="black", + alpha=0.3, + ) + rect = self.root.square( + x="x", + y="y", + size=10, + color=node_colors, + source=self.node_source, + view=node_view, + legend="state", + ) + self.root.xgrid.grid_line_color = None + self.root.ygrid.grid_line_color = None + + hover = HoverTool( + point_policy="follow_mouse", + tooltips="@name: @state", + renderers=[rect], + ) + tap = TapTool(callback=OpenURL(url="info/task/@key.html"), renderers=[rect]) + rect.nonselection_glyph = None + self.root.add_tools(hover, tap) + + @without_property_validation + def update(self): + with log_errors(): + # occasionally reset the column data source to remove old nodes + if self.invisible_count > len(self.node_source.data["x"]) / 2: + self.layout.reset_index() + self.invisible_count = 0 + update = True + else: + update = False + + new, self.layout.new = self.layout.new, [] + new_edges = self.layout.new_edges + self.layout.new_edges = [] + + self.add_new_nodes_edges(new, new_edges, update=update) + + self.patch_updates() + + @without_property_validation + def add_new_nodes_edges(self, new, new_edges, update=False): + if new or update: + node_key = [] + node_x = [] + node_y = [] + node_state = [] + node_name = [] + edge_x = [] + edge_y = [] + + x = self.layout.x + y = self.layout.y + + tasks = self.scheduler.tasks + for key in new: + try: + task = tasks[key] + except KeyError: + continue + xx = x[key] + yy = y[key] + node_key.append(escape.url_escape(key)) + node_x.append(xx) + node_y.append(yy) + node_state.append(task.state) + node_name.append(task.prefix) + + for a, b in new_edges: + try: + edge_x.append([x[a], x[b]]) + edge_y.append([y[a], y[b]]) + except KeyError: + pass + + node = { + "x": node_x, + "y": node_y, + "state": node_state, + "name": node_name, + "key": node_key, + "visible": ["True"] * len(node_x), + } + edge = {"x": edge_x, "y": edge_y, "visible": ["True"] * len(edge_x)} + + if update or not len(self.node_source.data["x"]): + # see https://github.com/bokeh/bokeh/issues/7523 + self.node_source.data.update(node) + self.edge_source.data.update(edge) + else: + 
self.node_source.stream(node) + self.edge_source.stream(edge) + + @without_property_validation + def patch_updates(self): + """ + Small updates like color changes or lost nodes from task transitions + """ + n = len(self.node_source.data["x"]) + m = len(self.edge_source.data["x"]) + + if self.layout.state_updates: + state_updates = self.layout.state_updates + self.layout.state_updates = [] + updates = [(i, c) for i, c in state_updates if i < n] + self.node_source.patch({"state": updates}) + + if self.layout.visible_updates: + updates = self.layout.visible_updates + updates = [(i, c) for i, c in updates if i < n] + self.visible_updates = [] + self.node_source.patch({"visible": updates}) + self.invisible_count += len(updates) + + if self.layout.visible_edge_updates: + updates = self.layout.visible_edge_updates + updates = [(i, c) for i, c in updates if i < m] + self.visible_updates = [] + self.edge_source.patch({"visible": updates}) + + def __del__(self): + self.scheduler.remove_plugin(self.layout) + + +class TaskProgress(DashboardComponent): + """ Progress bars per task type """ + + def __init__(self, scheduler, **kwargs): + self.scheduler = scheduler + ps = [p for p in scheduler.plugins if isinstance(p, AllProgress)] + if ps: + self.plugin = ps[0] + else: + self.plugin = AllProgress(scheduler) + + data = progress_quads( + dict(all={}, memory={}, erred={}, released={}, processing={}) + ) + self.source = ColumnDataSource(data=data) + + x_range = DataRange1d(range_padding=0) + y_range = Range1d(-8, 0) + + self.root = figure( + id="bk-task-progress-plot", + title="Progress", + name="task_progress", + x_range=x_range, + y_range=y_range, + toolbar_location=None, + tools="", + **kwargs + ) + self.root.line( # just to define early ranges + x=[0, 0.9], y=[-1, 0], line_color="#FFFFFF", alpha=0.0 + ) + self.root.quad( + source=self.source, + top="top", + bottom="bottom", + left="left", + right="right", + fill_color="#aaaaaa", + line_color="#aaaaaa", + fill_alpha=0.1, + line_alpha=0.3, + ) + self.root.quad( + source=self.source, + top="top", + bottom="bottom", + left="left", + right="released-loc", + fill_color="color", + line_color="color", + fill_alpha=0.6, + ) + self.root.quad( + source=self.source, + top="top", + bottom="bottom", + left="released-loc", + right="memory-loc", + fill_color="color", + line_color="color", + fill_alpha=1.0, + ) + self.root.quad( + source=self.source, + top="top", + bottom="bottom", + left="memory-loc", + right="erred-loc", + fill_color="black", + fill_alpha=0.5, + line_alpha=0, + ) + self.root.quad( + source=self.source, + top="top", + bottom="bottom", + left="erred-loc", + right="processing-loc", + fill_color="gray", + fill_alpha=0.35, + line_alpha=0, + ) + self.root.text( + source=self.source, + text="show-name", + y="bottom", + x="left", + x_offset=5, + text_font_size=value("10pt"), + ) + self.root.text( + source=self.source, + text="done", + y="bottom", + x="right", + x_offset=-5, + text_align="right", + text_font_size=value("10pt"), + ) + self.root.ygrid.visible = False + self.root.yaxis.minor_tick_line_alpha = 0 + self.root.yaxis.visible = False + self.root.xgrid.visible = False + self.root.xaxis.minor_tick_line_alpha = 0 + self.root.xaxis.visible = False + + hover = HoverTool( + point_policy="follow_mouse", + tooltips=""" +
+                Name: @name
+                All: @all
+                Memory: @memory
+                Erred: @erred
+                Ready: @processing
          + """, + ) + self.root.add_tools(hover) + + @without_property_validation + def update(self): + with log_errors(): + state = {"all": valmap(len, self.plugin.all), "nbytes": self.plugin.nbytes} + for k in ["memory", "erred", "released", "processing", "waiting"]: + state[k] = valmap(len, self.plugin.state[k]) + if not state["all"] and not len(self.source.data["all"]): + return + + d = progress_quads(state) + + update(self.source, d) + + totals = { + k: sum(state[k].values()) + for k in ["all", "memory", "erred", "released", "waiting"] + } + totals["processing"] = totals["all"] - sum( + v for k, v in totals.items() if k != "all" + ) + + self.root.title.text = ( + "Progress -- total: %(all)s, " + "in-memory: %(memory)s, processing: %(processing)s, " + "waiting: %(waiting)s, " + "erred: %(erred)s" % totals + ) + + +class MemoryUse(DashboardComponent): + """ The memory usage across the cluster, grouped by task type """ + + def __init__(self, scheduler, **kwargs): + self.scheduler = scheduler + ps = [p for p in scheduler.plugins if isinstance(p, AllProgress)] + if ps: + self.plugin = ps[0] + else: + self.plugin = AllProgress(scheduler) + + self.source = ColumnDataSource( + data=dict( + name=[], + left=[], + right=[], + center=[], + color=[], + percent=[], + MB=[], + text=[], + ) + ) + + self.root = Plot( + id="bk-nbytes-plot", + x_range=DataRange1d(), + y_range=DataRange1d(), + toolbar_location=None, + outline_line_color=None, + **kwargs + ) + + self.root.add_glyph( + self.source, + Quad( + top=1, + bottom=0, + left="left", + right="right", + fill_color="color", + fill_alpha=1, + ), + ) + + self.root.add_layout(LinearAxis(), "left") + self.root.add_layout(LinearAxis(), "below") + + hover = HoverTool( + point_policy="follow_mouse", + tooltips=""" +
+                Name: @name
+                Percent: @percent
+                MB: @MB
          + """, + ) + self.root.add_tools(hover) + + @without_property_validation + def update(self): + with log_errors(): + nb = nbytes_bar(self.plugin.nbytes) + update(self.source, nb) + self.root.title.text = "Memory Use: %0.2f MB" % ( + sum(self.plugin.nbytes.values()) / 1e6 + ) + + +class WorkerTable(DashboardComponent): + """ Status of the current workers + + This is two plots, a text-based table for each host and a thin horizontal + plot laying out hosts by their current memory use. + """ + + excluded_names = {"executing", "in_flight", "in_memory", "ready", "time"} + + def __init__(self, scheduler, width=800, **kwargs): + self.scheduler = scheduler + self.names = [ + "name", + "address", + "nthreads", + "cpu", + "memory", + "memory_limit", + "memory_percent", + "num_fds", + "read_bytes", + "write_bytes", + "cpu_fraction", + ] + workers = self.scheduler.workers.values() + self.extra_names = sorted( + { + m + for ws in workers + for m, v in ws.metrics.items() + if m not in self.names and isinstance(v, (str, int, float)) + } + - self.excluded_names + ) + + table_names = [ + "name", + "address", + "nthreads", + "cpu", + "memory", + "memory_limit", + "memory_percent", + "num_fds", + "read_bytes", + "write_bytes", + ] + + self.source = ColumnDataSource({k: [] for k in self.names}) + + columns = { + name: TableColumn(field=name, title=name.replace("_percent", " %")) + for name in table_names + } + + formatters = { + "cpu": NumberFormatter(format="0.0 %"), + "memory_percent": NumberFormatter(format="0.0 %"), + "memory": NumberFormatter(format="0 b"), + "memory_limit": NumberFormatter(format="0 b"), + "read_bytes": NumberFormatter(format="0 b"), + "write_bytes": NumberFormatter(format="0 b"), + "num_fds": NumberFormatter(format="0"), + "nthreads": NumberFormatter(format="0"), + } + + if BOKEH_VERSION < "0.12.15": + dt_kwargs = {"row_headers": False} + else: + dt_kwargs = {"index_position": None} + + table = DataTable( + source=self.source, + columns=[columns[n] for n in table_names], + reorderable=True, + sortable=True, + width=width, + **dt_kwargs + ) + + for name in table_names: + if name in formatters: + table.columns[table_names.index(name)].formatter = formatters[name] + + extra_names = ["name", "address"] + self.extra_names + extra_columns = { + name: TableColumn(field=name, title=name.replace("_percent", "%")) + for name in extra_names + } + + extra_table = DataTable( + source=self.source, + columns=[extra_columns[n] for n in extra_names], + reorderable=True, + sortable=True, + width=width, + **dt_kwargs + ) + + hover = HoverTool( + point_policy="follow_mouse", + tooltips=""" +
+                @worker: @memory_percent
          + """, + ) + + mem_plot = figure( + title="Memory Use (%)", + toolbar_location=None, + x_range=(0, 1), + y_range=(-0.1, 0.1), + height=60, + width=width, + tools="", + **kwargs + ) + mem_plot.circle( + source=self.source, x="memory_percent", y=0, size=10, fill_alpha=0.5 + ) + mem_plot.ygrid.visible = False + mem_plot.yaxis.minor_tick_line_alpha = 0 + mem_plot.xaxis.visible = False + mem_plot.yaxis.visible = False + mem_plot.add_tools(hover, BoxSelectTool()) + + hover = HoverTool( + point_policy="follow_mouse", + tooltips=""" +
+                @worker: @cpu
          + """, + ) + + cpu_plot = figure( + title="CPU Use (%)", + toolbar_location=None, + x_range=(0, 1), + y_range=(-0.1, 0.1), + height=60, + width=width, + tools="", + **kwargs + ) + cpu_plot.circle( + source=self.source, x="cpu_fraction", y=0, size=10, fill_alpha=0.5 + ) + cpu_plot.ygrid.visible = False + cpu_plot.yaxis.minor_tick_line_alpha = 0 + cpu_plot.xaxis.visible = False + cpu_plot.yaxis.visible = False + cpu_plot.add_tools(hover, BoxSelectTool()) + self.cpu_plot = cpu_plot + + if "sizing_mode" in kwargs: + sizing_mode = {"sizing_mode": kwargs["sizing_mode"]} + else: + sizing_mode = {} + + components = [cpu_plot, mem_plot, table] + if self.extra_names: + components.append(extra_table) + + self.root = column(*components, id="bk-worker-table", **sizing_mode) + + @without_property_validation + def update(self): + data = {name: [] for name in self.names + self.extra_names} + for i, (addr, ws) in enumerate( + sorted(self.scheduler.workers.items(), key=lambda kv: kv[1].name) + ): + for name in self.names + self.extra_names: + data[name].append(ws.metrics.get(name, None)) + data["name"][-1] = ws.name if ws.name is not None else i + data["address"][-1] = ws.address + if ws.memory_limit: + data["memory_percent"][-1] = ws.metrics["memory"] / ws.memory_limit + else: + data["memory_percent"][-1] = "" + data["memory_limit"][-1] = ws.memory_limit + data["cpu"][-1] = ws.metrics["cpu"] / 100.0 + data["cpu_fraction"][-1] = ws.metrics["cpu"] / 100.0 / ws.nthreads + data["nthreads"][-1] = ws.nthreads + + self.source.data.update(data) + + +def systemmonitor_doc(scheduler, extra, doc): + with log_errors(): + sysmon = SystemMonitor(scheduler, sizing_mode="stretch_both") + doc.title = "Dask: Scheduler System Monitor" + add_periodic_callback(doc, sysmon, 500) + + for subdoc in sysmon.root.children: + doc.add_root(subdoc) + doc.template = env.get_template("system.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def stealing_doc(scheduler, extra, doc): + with log_errors(): + occupancy = Occupancy(scheduler, height=200, sizing_mode="scale_width") + stealing_ts = StealingTimeSeries(scheduler, sizing_mode="scale_width") + stealing_events = StealingEvents(scheduler, sizing_mode="scale_width") + stealing_events.root.x_range = stealing_ts.root.x_range + doc.title = "Dask: Work Stealing" + add_periodic_callback(doc, occupancy, 500) + add_periodic_callback(doc, stealing_ts, 500) + add_periodic_callback(doc, stealing_events, 500) + + doc.add_root( + column( + occupancy.root, + stealing_ts.root, + stealing_events.root, + sizing_mode="scale_width", + ) + ) + + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def events_doc(scheduler, extra, doc): + with log_errors(): + events = Events(scheduler, "all", height=250) + events.update() + add_periodic_callback(doc, events, 500) + doc.title = "Dask: Scheduler Events" + doc.add_root(column(events.root, sizing_mode="scale_width")) + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def workers_doc(scheduler, extra, doc): + with log_errors(): + table = WorkerTable(scheduler) + table.update() + add_periodic_callback(doc, table, 500) + doc.title = "Dask: Workers" + doc.add_root(table.root) + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def tasks_doc(scheduler, extra, doc): + with log_errors(): + ts = TaskStream( + scheduler, + 
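The WorkerTable update above derives a few columns rather than copying metrics verbatim: memory becomes a fraction of the worker's limit, and the psutil-style CPU percentage is rescaled to the 0-1 range and divided by the thread count. A small stand-alone sketch of that arithmetic; the helper name and its arguments are illustrative, not part of the patch:

    def worker_row(name, address, memory, memory_limit, cpu_percent, nthreads):
        return {
            "name": name,
            "address": address,
            "memory": memory,
            "memory_limit": memory_limit,
            # fraction of the configured limit, or "" when the worker is unlimited
            "memory_percent": memory / memory_limit if memory_limit else "",
            # psutil reports CPU as a percentage and may exceed 100 on multi-core
            "cpu": cpu_percent / 100.0,
            # normalise by thread count so a fully busy 4-thread worker reads as 1.0
            "cpu_fraction": cpu_percent / 100.0 / nthreads,
        }

    row = worker_row("alice", "tcp://127.0.0.1:33221",
                     memory=2e9, memory_limit=4e9, cpu_percent=180.0, nthreads=4)
    assert row["memory_percent"] == 0.5 and row["cpu_fraction"] == 0.45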
n_rectangles=dask.config.get( + "distributed.scheduler.dashboard.tasks.task-stream-length" + ), + clear_interval="60s", + sizing_mode="stretch_both", + ) + ts.update() + add_periodic_callback(doc, ts, 5000) + doc.title = "Dask: Task Stream" + doc.add_root(ts.root) + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def graph_doc(scheduler, extra, doc): + with log_errors(): + graph = TaskGraph(scheduler, sizing_mode="stretch_both") + doc.title = "Dask: Task Graph" + graph.update() + add_periodic_callback(doc, graph, 200) + doc.add_root(graph.root) + + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def status_doc(scheduler, extra, doc): + with log_errors(): + task_stream = TaskStream( + scheduler, + n_rectangles=dask.config.get( + "distributed.scheduler.dashboard.status.task-stream-length" + ), + clear_interval="10s", + sizing_mode="stretch_both", + ) + task_stream.update() + add_periodic_callback(doc, task_stream, 100) + + task_progress = TaskProgress(scheduler, sizing_mode="stretch_both") + task_progress.update() + add_periodic_callback(doc, task_progress, 100) + + if len(scheduler.workers) < 50: + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") + current_load.update() + add_periodic_callback(doc, current_load, 100) + doc.add_root(current_load.nbytes_figure) + doc.add_root(current_load.processing_figure) + else: + nbytes_hist = NBytesHistogram(scheduler, sizing_mode="stretch_both") + nbytes_hist.update() + processing_hist = ProcessingHistogram(scheduler, sizing_mode="stretch_both") + processing_hist.update() + add_periodic_callback(doc, nbytes_hist, 100) + add_periodic_callback(doc, processing_hist, 100) + current_load_fig = row( + nbytes_hist.root, processing_hist.root, sizing_mode="stretch_both" + ) + + doc.add_root(nbytes_hist.root) + doc.add_root(processing_hist.root) + + doc.title = "Dask: Status" + doc.add_root(task_progress.root) + doc.add_root(task_stream.root) + doc.theme = BOKEH_THEME + doc.template = env.get_template("status.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def individual_task_stream_doc(scheduler, extra, doc): + task_stream = TaskStream( + scheduler, n_rectangles=1000, clear_interval="10s", sizing_mode="stretch_both" + ) + task_stream.update() + add_periodic_callback(doc, task_stream, 100) + doc.add_root(task_stream.root) + doc.theme = BOKEH_THEME + + +def individual_nbytes_doc(scheduler, extra, doc): + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") + current_load.update() + add_periodic_callback(doc, current_load, 100) + doc.add_root(current_load.nbytes_figure) + doc.theme = BOKEH_THEME + + +def individual_memory_use_doc(scheduler, extra, doc): + memory_use = MemoryUse(scheduler, sizing_mode="stretch_both") + memory_use.update() + add_periodic_callback(doc, memory_use, 100) + doc.add_root(memory_use.root) + doc.theme = BOKEH_THEME + + +def individual_cpu_doc(scheduler, extra, doc): + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") + current_load.update() + add_periodic_callback(doc, current_load, 100) + doc.add_root(current_load.cpu_figure) + doc.theme = BOKEH_THEME + + +def individual_nprocessing_doc(scheduler, extra, doc): + current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") + current_load.update() + add_periodic_callback(doc, current_load, 100) + doc.add_root(current_load.processing_figure) + doc.theme = BOKEH_THEME + + +def 
individual_progress_doc(scheduler, extra, doc): + task_progress = TaskProgress(scheduler, height=160, sizing_mode="stretch_both") + task_progress.update() + add_periodic_callback(doc, task_progress, 100) + doc.add_root(task_progress.root) + doc.theme = BOKEH_THEME + + +def individual_graph_doc(scheduler, extra, doc): + with log_errors(): + graph = TaskGraph(scheduler, sizing_mode="stretch_both") + graph.update() + + add_periodic_callback(doc, graph, 200) + doc.add_root(graph.root) + doc.theme = BOKEH_THEME + + +def individual_profile_doc(scheduler, extra, doc): + with log_errors(): + prof = ProfileTimePlot(scheduler, sizing_mode="scale_width", doc=doc) + doc.add_root(prof.root) + prof.trigger_update() + doc.theme = BOKEH_THEME + + +def individual_profile_server_doc(scheduler, extra, doc): + with log_errors(): + prof = ProfileServer(scheduler, sizing_mode="scale_width", doc=doc) + doc.add_root(prof.root) + prof.trigger_update() + doc.theme = BOKEH_THEME + + +def individual_workers_doc(scheduler, extra, doc): + with log_errors(): + table = WorkerTable(scheduler) + table.update() + add_periodic_callback(doc, table, 500) + doc.add_root(table.root) + doc.theme = BOKEH_THEME + + +def individual_bandwidth_types_doc(scheduler, extra, doc): + with log_errors(): + bw = BandwidthTypes(scheduler, sizing_mode="stretch_both") + bw.update() + add_periodic_callback(doc, bw, 500) + doc.add_root(bw.fig) + doc.theme = BOKEH_THEME + + +def individual_bandwidth_workers_doc(scheduler, extra, doc): + with log_errors(): + bw = BandwidthWorkers(scheduler, sizing_mode="stretch_both") + bw.update() + add_periodic_callback(doc, bw, 500) + doc.add_root(bw.fig) + doc.theme = BOKEH_THEME + + +def profile_doc(scheduler, extra, doc): + with log_errors(): + doc.title = "Dask: Profile" + prof = ProfileTimePlot(scheduler, sizing_mode="scale_width", doc=doc) + doc.add_root(prof.root) + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + prof.trigger_update() + + +def profile_server_doc(scheduler, extra, doc): + with log_errors(): + doc.title = "Dask: Profile of Event Loop" + prof = ProfileServer(scheduler, sizing_mode="scale_width", doc=doc) + doc.add_root(prof.root) + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + prof.trigger_update() diff --git a/distributed/dashboard/components.py b/distributed/dashboard/components/shared.py similarity index 75% rename from distributed/dashboard/components.py rename to distributed/dashboard/components/shared.py index 7fb8a6cb022..882db411434 100644 --- a/distributed/dashboard/components.py +++ b/distributed/dashboard/components/shared.py @@ -1,7 +1,4 @@ import asyncio -from bisect import bisect -from operator import add -from time import time import weakref from bokeh.layouts import row, column @@ -11,16 +8,11 @@ DataRange1d, LinearAxis, HoverTool, - BoxZoomTool, - ResetTool, - PanTool, - WheelZoomTool, Range1d, Quad, - TapTool, - OpenURL, Button, Select, + NumeralTickFormatter, ) from bokeh.palettes import Spectral9 from bokeh.plotting import figure @@ -28,13 +20,19 @@ from tornado import gen import toolz -from .utils import without_property_validation, BOKEH_VERSION -from ..diagnostics.progress_stream import nbytes_bar -from .. 
import profile -from ..utils import log_errors, parse_timedelta +from distributed.dashboard.components import DashboardComponent +from distributed.dashboard.utils import ( + without_property_validation, + BOKEH_VERSION, + update, +) +from distributed.diagnostics.progress_stream import nbytes_bar +from distributed import profile +from distributed.utils import log_errors, parse_timedelta +from distributed.compatibility import WINDOWS if dask.config.get("distributed.dashboard.export-tool"): - from .export_tool import ExportTool + from distributed.dashboard.export_tool import ExportTool else: ExportTool = None @@ -43,157 +41,6 @@ profile_interval = parse_timedelta(profile_interval, default="ms") -class DashboardComponent(object): - """ Base class for Dask.distributed UI dashboard components. - - This class must have two attributes, ``root`` and ``source``, and one - method ``update``: - - * source: a Bokeh ColumnDataSource - * root: a Bokeh Model - * update: a method that consumes the messages dictionary found in - distributed.bokeh.messages - """ - - def __init__(self): - self.source = None - self.root = None - - def update(self, messages): - """ Reads from bokeh.distributed.messages and updates self.source """ - - -class TaskStream(DashboardComponent): - """ Task Stream - - The start and stop time of tasks as they occur on each core of the cluster. - """ - - def __init__(self, n_rectangles=1000, clear_interval="20s", **kwargs): - """ - kwargs are applied to the bokeh.models.plots.Plot constructor - """ - self.n_rectangles = n_rectangles - clear_interval = parse_timedelta(clear_interval, default="ms") - self.clear_interval = clear_interval - self.last = 0 - - self.source, self.root = task_stream_figure(clear_interval, **kwargs) - - # Required for update callback - self.task_stream_index = [0] - - @without_property_validation - def update(self, messages): - with log_errors(): - index = messages["task-events"]["index"] - rectangles = messages["task-events"]["rectangles"] - - if not index or index[-1] == self.task_stream_index[0]: - return - - ind = bisect(index, self.task_stream_index[0]) - rectangles = { - k: [v[i] for i in range(ind, len(index))] for k, v in rectangles.items() - } - self.task_stream_index[0] = index[-1] - - # If there has been a significant delay then clear old rectangles - if rectangles["start"]: - m = min(map(add, rectangles["start"], rectangles["duration"])) - if m > self.last: - self.last, last = m, self.last - if m > last + self.clear_interval: - self.source.data.update(rectangles) - return - - self.source.stream(rectangles, self.n_rectangles) - - -def task_stream_figure(clear_interval="20s", **kwargs): - """ - kwargs are applied to the bokeh.models.plots.Plot constructor - """ - clear_interval = parse_timedelta(clear_interval, default="ms") - - source = ColumnDataSource( - data=dict( - start=[time() - clear_interval], - duration=[0.1], - key=["start"], - name=["start"], - color=["white"], - duration_text=["100 ms"], - worker=["foo"], - y=[0], - worker_thread=[1], - alpha=[0.0], - ) - ) - - x_range = DataRange1d(range_padding=0) - y_range = DataRange1d(range_padding=0) - - root = figure( - name="task_stream", - title="Task Stream", - id="bk-task-stream-plot", - x_range=x_range, - y_range=y_range, - toolbar_location="above", - x_axis_type="datetime", - min_border_right=35, - tools="", - **kwargs - ) - - rect = root.rect( - source=source, - x="start", - y="y", - width="duration", - height=0.4, - fill_color="color", - line_color="color", - line_alpha=0.6, - 
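The clear-interval rule being moved here is easy to lose in the rename: when the earliest end time in a new batch of rectangles lies more than clear_interval beyond everything drawn so far, the plot is wiped instead of appended to, so rectangles from a long-finished burst of work do not linger. A condensed sketch of that rule, with a plain dict standing in for the component's own state:

    from operator import add

    def push_rectangles(source, rectangles, state, clear_interval, n_rectangles=1000):
        if rectangles["start"]:
            m = min(map(add, rectangles["start"], rectangles["duration"]))
            if m > state["last"]:
                state["last"], last = m, state["last"]
                if m > last + clear_interval:
                    # Long gap since the previous batch: replace outright.
                    source.data.update(rectangles)
                    return
        # Otherwise append, keeping at most n_rectangles rows around.
        source.stream(rectangles, n_rectangles)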
fill_alpha="alpha", - line_width=3, - ) - rect.nonselection_glyph = None - - root.yaxis.major_label_text_alpha = 0 - root.yaxis.minor_tick_line_alpha = 0 - root.yaxis.major_tick_line_alpha = 0 - root.xgrid.visible = False - - hover = HoverTool( - point_policy="follow_mouse", - tooltips=""" -
-                @name: @duration_text
          - """, - ) - - tap = TapTool(callback=OpenURL(url="/profile?key=@name")) - - root.add_tools( - hover, - tap, - BoxZoomTool(), - ResetTool(), - PanTool(dimensions="width"), - WheelZoomTool(dimensions="width"), - ) - if ExportTool: - export = ExportTool() - export.register_plot(root) - root.add_tools(export) - - return source, root - - class MemoryUsage(DashboardComponent): """ The memory usage across the cluster, grouped by task type """ @@ -261,7 +108,7 @@ def update(self, messages): if not msg: return nb = nbytes_bar(msg["nbytes"]) - self.source.data.update(nb) + update(self.source, nb) self.root.title.text = "Memory Use: %0.2f MB" % ( sum(msg["nbytes"].values()) / 1e6 ) @@ -331,7 +178,7 @@ def update(self, messages): elif x_range.end > 2 * max_right + cores: # way out there, walk back x_range.end = x_range.end * 0.95 + max_right * 0.05 - self.source.data.update(data) + update(self.source, data) @staticmethod def processing_update(msg): @@ -383,7 +230,7 @@ def cb(attr, old, new): data = profile.plot_data(self.states[ind], profile_interval) del self.states[:] self.states.extend(data.pop("states")) - self.source.data.update(data) + update(self.source, data) self.source.selected = old if BOKEH_VERSION >= "1.0.0": @@ -397,7 +244,7 @@ def update(self, state): self.state = state data = profile.plot_data(self.state, profile_interval) self.states = data.pop("states") - self.source.data.update(data) + update(self.source, data) class ProfileTimePlot(DashboardComponent): @@ -450,7 +297,7 @@ def cb(attr, old, new): del self.states[:] self.states.extend(data.pop("states")) changing[0] = True # don't recursively trigger callback - self.source.data.update(data) + update(self.source, data) if isinstance(new, list): # bokeh >= 1.0 self.source.selected.indices = old else: @@ -532,7 +379,7 @@ def update(self, state, metadata=None): self.state = state data = profile.plot_data(self.state, profile_interval) self.states = data.pop("states") - self.source.data.update(data) + update(self.source, data) if metadata is not None and metadata["counts"]: self.task_names = ["All"] + sorted(metadata["keys"]) @@ -602,7 +449,7 @@ def cb(attr, old, new): del self.states[:] self.states.extend(data.pop("states")) changing[0] = True # don't recursively trigger callback - self.source.data.update(data) + update(self.source, data) if isinstance(new, list): # bokeh >= 1.0 self.source.selected.indices = old else: @@ -669,44 +516,104 @@ def update(self, state): self.state = state data = profile.plot_data(self.state, profile_interval) self.states = data.pop("states") - self.source.data.update(data) + update(self.source, data) @without_property_validation def trigger_update(self): self.state = profile.get_profile(self.log, start=self.start, stop=self.stop) data = profile.plot_data(self.state, profile_interval) self.states = data.pop("states") - self.source.data.update(data) + update(self.source, data) times = [t * 1000 for t, _ in self.log] counts = list(toolz.pluck("count", toolz.pluck(1, self.log))) self.ts_source.data.update({"time": times, "count": counts}) -def add_periodic_callback(doc, component, interval): - """ Add periodic callback to doc in a way that avoids reference cycles +class SystemMonitor(DashboardComponent): + def __init__(self, worker, height=150, **kwargs): + self.worker = worker - If we instead use ``doc.add_periodic_callback(component.update, 100)`` then - the component stays in memory as a reference cycle because its method is - still around. 
This way we avoid that and let things clean up a bit more - nicely. + names = worker.monitor.quantities + self.last = 0 + self.source = ColumnDataSource({name: [] for name in names}) + update(self.source, self.get_data()) - TODO: we still have reference cycles. Docs seem to be referred to by their - add_periodic_callback methods. - """ - ref = weakref.ref(component) + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) - doc.add_periodic_callback(lambda: update(ref), interval) - _attach(doc, component) + tools = "reset,xpan,xwheel_zoom" + self.cpu = figure( + title="CPU", + x_axis_type="datetime", + height=height, + tools=tools, + x_range=x_range, + **kwargs + ) + self.cpu.line(source=self.source, x="time", y="cpu") + self.cpu.yaxis.axis_label = "Percentage" + self.mem = figure( + title="Memory", + x_axis_type="datetime", + height=height, + tools=tools, + x_range=x_range, + **kwargs + ) + self.mem.line(source=self.source, x="time", y="memory") + self.mem.yaxis.axis_label = "Bytes" + self.bandwidth = figure( + title="Bandwidth", + x_axis_type="datetime", + height=height, + x_range=x_range, + tools=tools, + **kwargs + ) + self.bandwidth.line(source=self.source, x="time", y="read_bytes", color="red") + self.bandwidth.line(source=self.source, x="time", y="write_bytes", color="blue") + self.bandwidth.yaxis.axis_label = "Bytes / second" + + # self.cpu.yaxis[0].formatter = NumeralTickFormatter(format='0%') + self.bandwidth.yaxis[0].formatter = NumeralTickFormatter(format="0.0b") + self.mem.yaxis[0].formatter = NumeralTickFormatter(format="0.0b") + + plots = [self.cpu, self.mem, self.bandwidth] + + if not WINDOWS: + self.num_fds = figure( + title="Number of File Descriptors", + x_axis_type="datetime", + height=height, + x_range=x_range, + tools=tools, + **kwargs + ) + + self.num_fds.line(source=self.source, x="time", y="num_fds") + plots.append(self.num_fds) + + if "sizing_mode" in kwargs: + kw = {"sizing_mode": kwargs["sizing_mode"]} + else: + kw = {} -def update(ref): - comp = ref() - if comp is not None: - comp.update() + if not WINDOWS: + self.num_fds.y_range.start = 0 + self.mem.y_range.start = 0 + self.cpu.y_range.start = 0 + self.bandwidth.y_range.start = 0 + self.root = column(*plots, **kw) + self.worker.monitor.update() -def _attach(doc, component): - if not hasattr(doc, "components"): - doc.components = set() + def get_data(self): + d = self.worker.monitor.range_query(start=self.last) + d["time"] = [x * 1000 for x in d["time"]] + self.last = self.worker.monitor.count + return d - doc.components.add(component) + @without_property_validation + def update(self): + with log_errors(): + self.source.stream(self.get_data(), 1000) diff --git a/distributed/dashboard/components/worker.py b/distributed/dashboard/components/worker.py new file mode 100644 index 00000000000..9dc2b2ec82f --- /dev/null +++ b/distributed/dashboard/components/worker.py @@ -0,0 +1,661 @@ +import logging +import math +import os + +from bokeh.layouts import row, column, widgetbox +from bokeh.models import ( + ColumnDataSource, + DataRange1d, + HoverTool, + BoxZoomTool, + ResetTool, + PanTool, + WheelZoomTool, + NumeralTickFormatter, + Select, +) + +from bokeh.models.widgets import DataTable, TableColumn +from bokeh.plotting import figure +from bokeh.palettes import RdBu +from bokeh.themes import Theme +from dask.utils import format_bytes +from toolz import merge, partition_all + +from distributed.dashboard.components import add_periodic_callback +from distributed.dashboard.components.shared import ( + 
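The docstring above describes a real pitfall: registering the bound method directly keeps the whole component alive through the Bokeh document. The helper it documents amounts to the following condensed restatement (no new behaviour here):

    import weakref

    def add_periodic_callback(doc, component, interval):
        # Hold only a weak reference; a bound component.update callback would
        # otherwise keep the component and its data sources alive for as long
        # as the document exists.
        ref = weakref.ref(component)

        def _update():
            comp = ref()
            if comp is not None:
                comp.update()

        doc.add_periodic_callback(_update, interval)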
DashboardComponent, + ProfileTimePlot, + ProfileServer, + SystemMonitor, +) +from distributed.dashboard.utils import transpose, without_property_validation, update +from distributed.diagnostics.progress_stream import color_of +from distributed.metrics import time +from distributed.utils import log_errors, key_split, format_time + + +logger = logging.getLogger(__name__) + +with open(os.path.join(os.path.dirname(__file__), "..", "templates", "base.html")) as f: + template_source = f.read() + +from jinja2 import Environment, FileSystemLoader + +env = Environment( + loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "..", "templates")) +) + +BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "..", "theme.yaml")) + +template_variables = {"pages": ["status", "system", "profile", "crossfilter"]} + + +class StateTable(DashboardComponent): + """ Currently running tasks """ + + def __init__(self, worker): + self.worker = worker + + names = ["Stored", "Executing", "Ready", "Waiting", "Connections", "Serving"] + self.source = ColumnDataSource({name: [] for name in names}) + + columns = {name: TableColumn(field=name, title=name) for name in names} + + table = DataTable( + source=self.source, columns=[columns[n] for n in names], height=70 + ) + self.root = table + + @without_property_validation + def update(self): + with log_errors(): + w = self.worker + d = { + "Stored": [len(w.data)], + "Executing": ["%d / %d" % (len(w.executing), w.nthreads)], + "Ready": [len(w.ready)], + "Waiting": [len(w.waiting_for_data)], + "Connections": [len(w.in_flight_workers)], + "Serving": [len(w._comms)], + } + update(self.source, d) + + +class CommunicatingStream(DashboardComponent): + def __init__(self, worker, height=300, **kwargs): + with log_errors(): + self.worker = worker + names = [ + "start", + "stop", + "middle", + "duration", + "who", + "y", + "hover", + "alpha", + "bandwidth", + "total", + ] + + self.incoming = ColumnDataSource({name: [] for name in names}) + self.outgoing = ColumnDataSource({name: [] for name in names}) + + x_range = DataRange1d(range_padding=0) + y_range = DataRange1d(range_padding=0) + + fig = figure( + title="Peer Communications", + x_axis_type="datetime", + x_range=x_range, + y_range=y_range, + height=height, + tools="", + **kwargs + ) + + fig.rect( + source=self.incoming, + x="middle", + y="y", + width="duration", + height=0.9, + color="red", + alpha="alpha", + ) + fig.rect( + source=self.outgoing, + x="middle", + y="y", + width="duration", + height=0.9, + color="blue", + alpha="alpha", + ) + + hover = HoverTool(point_policy="follow_mouse", tooltips="""@hover""") + fig.add_tools( + hover, + ResetTool(), + PanTool(dimensions="width"), + WheelZoomTool(dimensions="width"), + ) + + self.root = fig + + self.last_incoming = 0 + self.last_outgoing = 0 + self.who = dict() + + @without_property_validation + def update(self): + with log_errors(): + outgoing = self.worker.outgoing_transfer_log + n = self.worker.outgoing_count - self.last_outgoing + outgoing = [outgoing[-i].copy() for i in range(1, n + 1)] + self.last_outgoing = self.worker.outgoing_count + + incoming = self.worker.incoming_transfer_log + n = self.worker.incoming_count - self.last_incoming + incoming = [incoming[-i].copy() for i in range(1, n + 1)] + self.last_incoming = self.worker.incoming_count + + for [msgs, source] in [ + [incoming, self.incoming], + [outgoing, self.outgoing], + ]: + + for msg in msgs: + if "compressed" in msg: + del msg["compressed"] + del msg["keys"] + + bandwidth = msg["total"] / 
(msg["duration"] or 0.5) + bw = max(min(bandwidth / 500e6, 1), 0.3) + msg["alpha"] = bw + try: + msg["y"] = self.who[msg["who"]] + except KeyError: + self.who[msg["who"]] = len(self.who) + msg["y"] = self.who[msg["who"]] + + msg["hover"] = "%s / %s = %s/s" % ( + format_bytes(msg["total"]), + format_time(msg["duration"]), + format_bytes(msg["total"] / msg["duration"]), + ) + + for k in ["middle", "duration", "start", "stop"]: + msg[k] = msg[k] * 1000 + + if msgs: + msgs = transpose(msgs) + if ( + len(source.data["stop"]) + and min(msgs["start"]) > source.data["stop"][-1] + 10000 + ): + source.data.update(msgs) + else: + source.stream(msgs, rollover=10000) + + +class CommunicatingTimeSeries(DashboardComponent): + def __init__(self, worker, **kwargs): + self.worker = worker + self.source = ColumnDataSource({"x": [], "in": [], "out": []}) + + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + + fig = figure( + title="Communication History", + x_axis_type="datetime", + y_range=[-0.1, worker.total_out_connections + 0.5], + height=150, + tools="", + x_range=x_range, + **kwargs + ) + fig.line(source=self.source, x="x", y="in", color="red") + fig.line(source=self.source, x="x", y="out", color="blue") + + fig.add_tools( + ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") + ) + + self.root = fig + + @without_property_validation + def update(self): + with log_errors(): + self.source.stream( + { + "x": [time() * 1000], + "out": [len(self.worker._comms)], + "in": [len(self.worker.in_flight_workers)], + }, + 10000, + ) + + +class ExecutingTimeSeries(DashboardComponent): + def __init__(self, worker, **kwargs): + self.worker = worker + self.source = ColumnDataSource({"x": [], "y": []}) + + x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + + fig = figure( + title="Executing History", + x_axis_type="datetime", + y_range=[-0.1, worker.nthreads + 0.1], + height=150, + tools="", + x_range=x_range, + **kwargs + ) + fig.line(source=self.source, x="x", y="y") + + fig.add_tools( + ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") + ) + + self.root = fig + + @without_property_validation + def update(self): + with log_errors(): + self.source.stream( + {"x": [time() * 1000], "y": [len(self.worker.executing)]}, 1000 + ) + + +class CrossFilter(DashboardComponent): + def __init__(self, worker, **kwargs): + with log_errors(): + self.worker = worker + + quantities = ["nbytes", "duration", "bandwidth", "count", "start", "stop"] + colors = ["inout-color", "type-color", "key-color"] + + # self.source = ColumnDataSource({name: [] for name in names}) + self.source = ColumnDataSource( + { + "nbytes": [1, 2], + "duration": [0.01, 0.02], + "bandwidth": [0.01, 0.02], + "count": [1, 2], + "type": ["int", "str"], + "inout-color": ["blue", "red"], + "type-color": ["blue", "red"], + "key": ["add", "inc"], + "start": [1, 2], + "stop": [1, 2], + } + ) + + self.x = Select(title="X-Axis", value="nbytes", options=quantities) + self.x.on_change("value", self.update_figure) + + self.y = Select(title="Y-Axis", value="bandwidth", options=quantities) + self.y.on_change("value", self.update_figure) + + self.size = Select( + title="Size", value="None", options=["None"] + quantities + ) + self.size.on_change("value", self.update_figure) + + self.color = Select( + title="Color", value="inout-color", options=["black"] + colors + ) + self.color.on_change("value", self.update_figure) + + if "sizing_mode" in kwargs: + kw = {"sizing_mode": 
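For reference, the per-message decoration above (effective bandwidth, the rectangle's alpha, and the hover text) reads in isolation roughly as follows; decorate_transfer is an illustrative name and plain string formatting stands in for format_time:

    from dask.utils import format_bytes

    def decorate_transfer(msg):
        msg = dict(msg)
        bandwidth = msg["total"] / (msg["duration"] or 0.5)
        # 500 MB/s or faster renders fully opaque; slower transfers fade out,
        # but never below an alpha of 0.3.
        msg["alpha"] = max(min(bandwidth / 500e6, 1), 0.3)
        msg["hover"] = "%s in %.3f s = %s/s" % (
            format_bytes(msg["total"]), msg["duration"], format_bytes(bandwidth)
        )
        # Bokeh datetime axes expect milliseconds, so scale the time columns.
        for k in ["middle", "duration", "start", "stop"]:
            msg[k] = msg[k] * 1000
        return msg

    print(decorate_transfer(
        {"total": 250e6, "duration": 0.5, "start": 0.0, "middle": 0.25, "stop": 0.5}
    ))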
kwargs["sizing_mode"]} + else: + kw = {} + + self.control = widgetbox( + [self.x, self.y, self.size, self.color], width=200, **kw + ) + + self.last_outgoing = 0 + self.last_incoming = 0 + self.kwargs = kwargs + + self.layout = row(self.control, self.create_figure(**self.kwargs), **kw) + + self.root = self.layout + + @without_property_validation + def update(self): + with log_errors(): + outgoing = self.worker.outgoing_transfer_log + n = self.worker.outgoing_count - self.last_outgoing + n = min(n, 1000) + outgoing = [outgoing[-i].copy() for i in range(1, n)] + self.last_outgoing = self.worker.outgoing_count + + incoming = self.worker.incoming_transfer_log + n = self.worker.incoming_count - self.last_incoming + n = min(n, 1000) + incoming = [incoming[-i].copy() for i in range(1, n)] + self.last_incoming = self.worker.incoming_count + + out = [] + + for msg in incoming: + if msg["keys"]: + d = self.process_msg(msg) + d["inout-color"] = "red" + out.append(d) + + for msg in outgoing: + if msg["keys"]: + d = self.process_msg(msg) + d["inout-color"] = "blue" + out.append(d) + + if out: + out = transpose(out) + if ( + len(self.source.data["stop"]) + and min(out["start"]) > self.source.data["stop"][-1] + 10 + ): + update(self.source, out) + else: + self.source.stream(out, rollover=1000) + + def create_figure(self, **kwargs): + with log_errors(): + fig = figure(title="", tools="", **kwargs) + + size = self.size.value + if size == "None": + size = 1 + + fig.circle( + source=self.source, + x=self.x.value, + y=self.y.value, + color=self.color.value, + size=10, + alpha=0.5, + hover_alpha=1, + ) + fig.xaxis.axis_label = self.x.value + fig.yaxis.axis_label = self.y.value + + fig.add_tools( + # self.hover, + ResetTool(), + PanTool(), + WheelZoomTool(), + BoxZoomTool(), + ) + return fig + + @without_property_validation + def update_figure(self, attr, old, new): + with log_errors(): + fig = self.create_figure(**self.kwargs) + self.layout.children[1] = fig + + def process_msg(self, msg): + try: + + def func(k): + return msg["keys"].get(k, 0) + + status_key = max(msg["keys"], key=func) + typ = self.worker.types.get(status_key, object).__name__ + keyname = key_split(status_key) + d = { + "nbytes": msg["total"], + "duration": msg["duration"], + "bandwidth": msg["bandwidth"], + "count": len(msg["keys"]), + "type": typ, + "type-color": color_of(typ), + "key": keyname, + "key-color": color_of(keyname), + "start": msg["start"], + "stop": msg["stop"], + } + return d + except Exception as e: + logger.exception(e) + raise + + +class Counters(DashboardComponent): + def __init__(self, server, sizing_mode="stretch_both", **kwargs): + self.server = server + self.counter_figures = {} + self.counter_sources = {} + self.digest_figures = {} + self.digest_sources = {} + self.sizing_mode = sizing_mode + + if self.server.digests: + for name in self.server.digests: + self.add_digest_figure(name) + for name in self.server.counters: + self.add_counter_figure(name) + + figures = merge(self.digest_figures, self.counter_figures) + figures = [figures[k] for k in sorted(figures)] + + if len(figures) <= 5: + self.root = column(figures, sizing_mode=sizing_mode) + else: + self.root = column( + *[ + row(*pair, sizing_mode=sizing_mode) + for pair in partition_all(2, figures) + ], + sizing_mode=sizing_mode + ) + + def add_digest_figure(self, name): + with log_errors(): + n = len(self.server.digests[name].intervals) + sources = {i: ColumnDataSource({"x": [], "y": []}) for i in range(n)} + + kwargs = {} + if name.endswith("duration"): + 
kwargs["x_axis_type"] = "datetime" + + fig = figure( + title=name, tools="", height=150, sizing_mode=self.sizing_mode, **kwargs + ) + fig.yaxis.visible = False + fig.ygrid.visible = False + if name.endswith("bandwidth") or name.endswith("bytes"): + fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0b") + + for i in range(n): + alpha = 0.3 + 0.3 * (n - i) / n + fig.line( + source=sources[i], + x="x", + y="y", + alpha=alpha, + color=RdBu[max(n, 3)][-i], + ) + + fig.xaxis.major_label_orientation = math.pi / 12 + fig.toolbar.logo = None + self.digest_sources[name] = sources + self.digest_figures[name] = fig + return fig + + def add_counter_figure(self, name): + with log_errors(): + n = len(self.server.counters[name].intervals) + sources = { + i: ColumnDataSource({"x": [], "y": [], "y-center": [], "counts": []}) + for i in range(n) + } + + fig = figure( + title=name, + tools="", + height=150, + sizing_mode=self.sizing_mode, + x_range=sorted(map(str, self.server.counters[name].components[0])), + ) + fig.ygrid.visible = False + + for i in range(n): + width = 0.5 + 0.4 * i / n + fig.rect( + source=sources[i], + x="x", + y="y-center", + width=width, + height="y", + alpha=0.3, + color=RdBu[max(n, 3)][-i], + ) + hover = HoverTool( + point_policy="follow_mouse", tooltips="""@x : @counts""" + ) + fig.add_tools(hover) + fig.xaxis.major_label_orientation = math.pi / 12 + + fig.toolbar.logo = None + + self.counter_sources[name] = sources + self.counter_figures[name] = fig + return fig + + @without_property_validation + def update(self): + with log_errors(): + for name, fig in self.digest_figures.items(): + digest = self.server.digests[name] + d = {} + for i, d in enumerate(digest.components): + if d.size(): + ys, xs = d.histogram(100) + xs = xs[1:] + if name.endswith("duration"): + xs *= 1000 + self.digest_sources[name][i].data.update({"x": xs, "y": ys}) + fig.title.text = "%s: %d" % (name, digest.size()) + + for name, fig in self.counter_figures.items(): + counter = self.server.counters[name] + d = {} + for i, d in enumerate(counter.components): + if d: + xs = sorted(d) + factor = counter.intervals[0] / counter.intervals[i] + counts = [d[x] for x in xs] + ys = [factor * c for c in counts] + y_centers = [y / 2 for y in ys] + xs = list(map(str, xs)) + d = {"x": xs, "y": ys, "y-center": y_centers, "counts": counts} + self.counter_sources[name][i].data.update(d) + fig.title.text = "%s: %d" % (name, counter.size()) + fig.x_range.factors = list(map(str, xs)) + + +def status_doc(worker, extra, doc): + with log_errors(): + statetable = StateTable(worker) + executing_ts = ExecutingTimeSeries(worker, sizing_mode="scale_width") + communicating_ts = CommunicatingTimeSeries(worker, sizing_mode="scale_width") + communicating_stream = CommunicatingStream(worker, sizing_mode="scale_width") + + xr = executing_ts.root.x_range + communicating_ts.root.x_range = xr + communicating_stream.root.x_range = xr + + doc.title = "Dask Worker Internal Monitor" + add_periodic_callback(doc, statetable, 200) + add_periodic_callback(doc, executing_ts, 200) + add_periodic_callback(doc, communicating_ts, 200) + add_periodic_callback(doc, communicating_stream, 200) + doc.add_root( + column( + statetable.root, + executing_ts.root, + communicating_ts.root, + communicating_stream.root, + sizing_mode="scale_width", + ) + ) + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "status" + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def crossfilter_doc(worker, extra, doc): + with 
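The rescaling inside Counters.update above is the part worth calling out: counts gathered over longer windows are multiplied by intervals[0] / intervals[i] so every bar reads as a rate over the shortest window, and a y-center column is needed because the rect glyph is anchored at its middle. A toy version with made-up interval lengths and counts:

    def counter_columns(intervals, i, component):
        xs = sorted(component)
        # Express the i-th (longer) window as a rate over the shortest one.
        factor = intervals[0] / intervals[i]
        counts = [component[x] for x in xs]
        ys = [factor * c for c in counts]
        return {
            "x": list(map(str, xs)),          # categorical axis wants strings
            "y": ys,                          # rect heights
            "y-center": [y / 2 for y in ys],  # rects are anchored at their centre
            "counts": counts,
        }

    print(counter_columns(intervals=[1, 5, 30], i=1, component={"inc": 10, "add": 5}))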
log_errors(): + statetable = StateTable(worker) + crossfilter = CrossFilter(worker) + + doc.title = "Dask Worker Cross-filter" + add_periodic_callback(doc, statetable, 500) + add_periodic_callback(doc, crossfilter, 500) + + doc.add_root(column(statetable.root, crossfilter.root)) + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "crossfilter" + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def systemmonitor_doc(worker, extra, doc): + with log_errors(): + sysmon = SystemMonitor(worker, sizing_mode="scale_width") + doc.title = "Dask Worker Monitor" + add_periodic_callback(doc, sysmon, 500) + + doc.add_root(sysmon.root) + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "system" + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def counters_doc(server, extra, doc): + with log_errors(): + doc.title = "Dask Worker Counters" + counter = Counters(server, sizing_mode="stretch_both") + add_periodic_callback(doc, counter, 500) + + doc.add_root(counter.root) + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "counters" + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def profile_doc(server, extra, doc): + with log_errors(): + doc.title = "Dask Worker Profile" + profile = ProfileTimePlot(server, sizing_mode="scale_width", doc=doc) + profile.trigger_update() + + doc.add_root(profile.root) + doc.template = env.get_template("simple.html") + doc.template_variables["active_page"] = "profile" + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + +def profile_server_doc(server, extra, doc): + with log_errors(): + doc.title = "Dask: Profile of Event Loop" + prof = ProfileServer(server, sizing_mode="scale_width", doc=doc) + doc.add_root(prof.root) + doc.template = env.get_template("simple.html") + # doc.template_variables['active_page'] = '' + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + prof.trigger_update() diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 79484cd4196..1117fe7bd72 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1,43 +1,11 @@ +from datetime import datetime from functools import partial import logging -import math -from numbers import Number -from operator import add -import os - -from bokeh.layouts import column, row -from bokeh.models import ( - ColumnDataSource, - ColorBar, - DataRange1d, - HoverTool, - ResetTool, - PanTool, - WheelZoomTool, - TapTool, - OpenURL, - Range1d, - Plot, - Quad, - Span, - value, - LinearAxis, - NumeralTickFormatter, - BasicTicker, - NumberFormatter, - BoxSelectTool, - GroupFilter, - CDSView, -) -from bokeh.models.widgets import DataTable, TableColumn -from bokeh.plotting import figure -from bokeh.palettes import Viridis11 -from bokeh.themes import Theme -from bokeh.transform import factor_cmap, linear_cmap -from bokeh.io import curdoc + import dask from dask.utils import format_bytes -from toolz import pipe, merge +import toolz +from toolz import merge from tornado import escape try: @@ -45,1763 +13,303 @@ except ImportError: np = False -from . 
import components -from .components import ( - DashboardComponent, - ProfileTimePlot, - ProfileServer, - add_periodic_callback, +from .components.worker import counters_doc +from .components.scheduler import ( + systemmonitor_doc, + stealing_doc, + workers_doc, + events_doc, + tasks_doc, + status_doc, + profile_doc, + profile_server_doc, + graph_doc, + individual_task_stream_doc, + individual_progress_doc, + individual_graph_doc, + individual_profile_doc, + individual_profile_server_doc, + individual_nbytes_doc, + individual_memory_use_doc, + individual_cpu_doc, + individual_nprocessing_doc, + individual_workers_doc, + individual_bandwidth_types_doc, + individual_bandwidth_workers_doc, ) from .core import BokehServer -from .worker import SystemMonitor, counters_doc -from .utils import transpose, BOKEH_VERSION, without_property_validation -from ..metrics import time +from .worker import counters_doc +from .proxy import GlobalProxyHandler +from .utils import RequestHandler, redirect from ..utils import log_errors, format_time -from ..diagnostics.progress_stream import color_of, progress_quads, nbytes_bar -from ..diagnostics.progress import AllProgress -from ..diagnostics.graph_layout import GraphLayout -from ..diagnostics.task_stream import TaskStreamPlugin - -try: - from cytoolz.curried import map, concat, groupby, valmap, first -except ImportError: - from toolz.curried import map, concat, groupby, valmap, first -logger = logging.getLogger(__name__) +ns = { + func.__name__: func for func in [format_bytes, format_time, datetime.fromtimestamp] +} -PROFILING = False +rel_path_statics = {"rel_path_statics": "../../"} -from jinja2 import Environment, FileSystemLoader -env = Environment( - loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "templates")) -) +logger = logging.getLogger(__name__) template_variables = { "pages": ["status", "workers", "tasks", "system", "profile", "graph", "info"] } -BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "theme.yaml")) - -nan = float("nan") -inf = float("inf") - - -@without_property_validation -def update(source, data): - """ Update source with data - - This checks a few things first - - 1. If the data is the same, then don't update - 2. If numpy is available and the data is numeric, then convert to numpy - arrays - 3. 
If profiling then perform the update in another callback - """ - if not np or not any(isinstance(v, np.ndarray) for v in source.data.values()): - if source.data == data: - return - if np and len(data[first(data)]) > 10: - d = {} - for k, v in data.items(): - if type(v) is not np.ndarray and isinstance(v[0], Number): - d[k] = np.array(v) - else: - d[k] = v - else: - d = data - - if PROFILING: - curdoc().add_next_tick_callback(lambda: source.data.update(d)) - else: - source.data.update(d) - - -class Occupancy(DashboardComponent): - """ Occupancy (in time) per worker """ - def __init__(self, scheduler, **kwargs): +class Workers(RequestHandler): + def get(self): with log_errors(): - self.scheduler = scheduler - self.source = ColumnDataSource( - { - "occupancy": [0, 0], - "worker": ["a", "b"], - "x": [0.0, 0.1], - "y": [1, 2], - "ms": [1, 2], - "color": ["red", "blue"], - "escaped_worker": ["a", "b"], - } + self.render( + "workers.html", + title="Workers", + scheduler=self.server, + **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) - fig = figure( - title="Occupancy", - tools="", - id="bk-occupancy-plot", - x_axis_type="datetime", - **kwargs - ) - rect = fig.rect( - source=self.source, x="x", width="ms", y="y", height=1, color="color" - ) - rect.nonselection_glyph = None - - fig.xaxis.minor_tick_line_alpha = 0 - fig.yaxis.visible = False - fig.ygrid.visible = False - # fig.xaxis[0].formatter = NumeralTickFormatter(format='0.0s') - fig.x_range.start = 0 - - tap = TapTool(callback=OpenURL(url="./info/worker/@escaped_worker.html")) - - hover = HoverTool() - hover.tooltips = "@worker : @occupancy s." - hover.point_policy = "follow_mouse" - fig.add_tools(hover, tap) - - self.root = fig - - @without_property_validation - def update(self): - with log_errors(): - workers = list(self.scheduler.workers.values()) - - y = list(range(len(workers))) - occupancy = [ws.occupancy for ws in workers] - ms = [occ * 1000 for occ in occupancy] - x = [occ / 500 for occ in occupancy] - total = sum(occupancy) - color = [] - for ws in workers: - if ws in self.scheduler.idle: - color.append("red") - elif ws in self.scheduler.saturated: - color.append("green") - else: - color.append("blue") - - if total: - self.root.title.text = "Occupancy -- total time: %s wall time: %s" % ( - format_time(total), - format_time(total / self.scheduler.total_nthreads), - ) - else: - self.root.title.text = "Occupancy" - - if occupancy: - result = { - "occupancy": occupancy, - "worker": [ws.address for ws in workers], - "ms": ms, - "color": color, - "escaped_worker": [escape.url_escape(ws.address) for ws in workers], - "x": x, - "y": y, - } - update(self.source, result) - - -class ProcessingHistogram(DashboardComponent): - """ How many tasks are on each worker """ - - def __init__(self, scheduler, **kwargs): +class Worker(RequestHandler): + def get(self, worker): + worker = escape.url_unescape(worker) + if worker not in self.server.workers: + self.send_error(404) + return with log_errors(): - self.last = 0 - self.scheduler = scheduler - self.source = ColumnDataSource( - {"left": [1, 2], "right": [10, 10], "top": [0, 0]} + self.render( + "worker.html", + title="Worker: " + worker, + scheduler=self.server, + Worker=worker, + **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) - self.root = figure( - title="Tasks Processing (Histogram)", - id="bk-nprocessing-histogram-plot", - name="processing_hist", - y_axis_label="frequency", - tools="", - **kwargs - ) - - self.root.xaxis.minor_tick_line_alpha = 0 - 
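Since nearly every component in this file funnels its data through this update helper, a condensed version may help when reading the rest of the diff; the PROFILING branch is dropped and numpy is assumed to be importable:

    from numbers import Number
    import numpy as np
    from bokeh.models import ColumnDataSource

    def update(source, data):
        # 1. Skip the update if nothing changed (only meaningful while the
        #    source still holds plain Python lists).
        if not any(isinstance(v, np.ndarray) for v in source.data.values()):
            if source.data == data:
                return
        # 2. For larger payloads, convert numeric columns to numpy arrays,
        #    which Bokeh serializes far more compactly than lists.
        first = next(iter(data))
        if len(data[first]) > 10:
            data = {
                k: np.asarray(v) if len(v) and isinstance(v[0], Number) else v
                for k, v in data.items()
            }
        source.data.update(data)

    src = ColumnDataSource({"x": [1, 2, 3]})
    update(src, {"x": list(range(100))})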
self.root.ygrid.visible = False - self.root.toolbar.logo = None - self.root.toolbar_location = None - - self.root.quad( - source=self.source, - left="left", - right="right", - bottom=0, - top="top", - color="deepskyblue", - fill_alpha=0.5, - ) - - @without_property_validation - def update(self): - L = [len(ws.processing) for ws in self.scheduler.workers.values()] - counts, x = np.histogram(L, bins=40) - self.source.data.update({"left": x[:-1], "right": x[1:], "top": counts}) - - -class NBytesHistogram(DashboardComponent): - """ How many tasks are on each worker """ - - def __init__(self, scheduler, **kwargs): +class Task(RequestHandler): + def get(self, task): + task = escape.url_unescape(task) + if task not in self.server.tasks: + self.send_error(404) + return with log_errors(): - self.last = 0 - self.scheduler = scheduler - self.source = ColumnDataSource( - {"left": [1, 2], "right": [10, 10], "top": [0, 0]} - ) - - self.root = figure( - title="Bytes Stored (Histogram)", - name="nbytes_hist", - id="bk-nbytes-histogram-plot", - y_axis_label="frequency", - tools="", - **kwargs + self.render( + "task.html", + title="Task: " + task, + Task=task, + scheduler=self.server, + **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) - self.root.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") - self.root.xaxis.major_label_orientation = -math.pi / 12 - self.root.xaxis.minor_tick_line_alpha = 0 - self.root.ygrid.visible = False - - self.root.toolbar.logo = None - self.root.toolbar_location = None - - self.root.quad( - source=self.source, - left="left", - right="right", - bottom=0, - top="top", - color="deepskyblue", - fill_alpha=0.5, - ) - - @without_property_validation - def update(self): - nbytes = np.asarray([ws.nbytes for ws in self.scheduler.workers.values()]) - counts, x = np.histogram(nbytes, bins=40) - d = {"left": x[:-1], "right": x[1:], "top": counts} - self.source.data.update(d) - - self.root.title.text = "Bytes stored (Histogram): " + format_bytes(nbytes.sum()) - - -class BandwidthTypes(DashboardComponent): - """ Bar chart showing bandwidth per type """ - - def __init__(self, scheduler, **kwargs): +class Logs(RequestHandler): + def get(self): with log_errors(): - self.last = 0 - self.scheduler = scheduler - self.source = ColumnDataSource( - { - "bandwidth": [1, 2], - "bandwidth-half": [0.5, 1], - "type": ["a", "b"], - "bandwidth_text": ["1", "2"], - } + logs = self.server.get_logs() + self.render( + "logs.html", + title="Logs", + logs=logs, + **toolz.merge(self.extra, rel_path_statics), ) - fig = figure( - title="Bandwidth by Type", - tools="", - id="bk-bandwidth-type-plot", - name="bandwidth_type_histogram", - y_range=["a", "b"], - **kwargs - ) - rect = fig.rect( - source=self.source, - x="bandwidth-half", - y="type", - width="bandwidth", - height=1, - color="blue", - ) - fig.x_range.start = 0 - fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") - rect.nonselection_glyph = None - - fig.xaxis.minor_tick_line_alpha = 0 - fig.ygrid.visible = False - - fig.toolbar.logo = None - fig.toolbar_location = None - hover = HoverTool() - hover.tooltips = "@type: @bandwidth_text / s" - hover.point_policy = "follow_mouse" - fig.add_tools(hover) - - self.fig = fig - - @without_property_validation - def update(self): +class WorkerLogs(RequestHandler): + async def get(self, worker): with log_errors(): - bw = self.scheduler.bandwidth_types - self.fig.y_range.factors = list(sorted(bw)) - result = { - "bandwidth": list(bw.values()), - "bandwidth-half": [b / 2 for b in 
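Both histogram components follow the same three-line recipe: bin the per-worker values with numpy and feed the bin edges and counts straight into a quad glyph. As a sketch with made-up byte counts:

    import numpy as np

    def histogram_columns(values, bins=40):
        counts, edges = np.histogram(values, bins=bins)
        # left/right are the bin edges, top is the frequency; the quad glyph
        # itself pins bottom at 0.
        return {"left": edges[:-1], "right": edges[1:], "top": counts}

    nbytes_per_worker = [2e9, 3e9, 1e9, 4e9, 2.5e9]   # illustrative values
    cols = histogram_columns(nbytes_per_worker)
    # source.data.update(cols) is then all the update method needs to do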
bw.values()], - "type": list(bw.keys()), - "bandwidth_text": list(map(format_bytes, bw.values())), - } - self.fig.title.text = "Bandwidth: " + format_bytes(self.scheduler.bandwidth) - - update(self.source, result) - - -class BandwidthWorkers(DashboardComponent): - """ How many tasks are on each worker """ - - def __init__(self, scheduler, **kwargs): - with log_errors(): - self.last = 0 - self.scheduler = scheduler - self.source = ColumnDataSource( - { - "bandwidth": [1, 2], - "source": ["a", "b"], - "destination": ["a", "b"], - "bandwidth_text": ["1", "2"], - } - ) - - values = [hex(x)[2:] for x in range(64, 256)][::-1] - mapper = linear_cmap( - field_name="bandwidth", - palette=["#" + x + x + "FF" for x in values], - low=0, - high=1, - ) - - fig = figure( - title="Bandwidth by Worker", - tools="", - id="bk-bandwidth-worker-plot", - name="bandwidth_worker_heatmap", - x_range=["a", "b"], - y_range=["a", "b"], - **kwargs + worker = escape.url_unescape(worker) + logs = await self.server.get_worker_logs(workers=[worker]) + logs = logs[worker] + self.render( + "logs.html", + title="Logs: " + worker, + logs=logs, + **toolz.merge(self.extra, rel_path_statics), ) - fig.xaxis.major_label_orientation = -math.pi / 12 - rect = fig.rect( - source=self.source, - x="source", - y="destination", - color=mapper, - height=1, - width=1, - ) - - self.color_map = mapper["transform"] - color_bar = ColorBar( - color_mapper=self.color_map, - label_standoff=12, - border_line_color=None, - location=(0, 0), - ) - color_bar.formatter = NumeralTickFormatter(format="0 b") - fig.add_layout(color_bar, "right") - - fig.toolbar.logo = None - fig.toolbar_location = None - - hover = HoverTool() - hover.tooltips = """ -
-                Source: @source
-                Destination: @destination
-                Bandwidth: @bandwidth_text / s
          - """ - hover.point_policy = "follow_mouse" - fig.add_tools(hover) - - self.fig = fig - - @without_property_validation - def update(self): - with log_errors(): - bw = self.scheduler.bandwidth_workers - if not bw: - return - x, y, value = zip(*[(a, b, c) for (a, b), c in bw.items()]) - - if self.color_map.high < max(value): - self.color_map.high = max(value) - - factors = list(sorted(set(x + y))) - self.fig.x_range.factors = factors - self.fig.y_range.factors = factors - - result = { - "source": x, - "destination": y, - "bandwidth": value, - "bandwidth_text": list(map(format_bytes, value)), - } - self.fig.title.text = "Bandwidth: " + format_bytes(self.scheduler.bandwidth) - update(self.source, result) - -class CurrentLoad(DashboardComponent): - """ How many tasks are on each worker """ - - def __init__(self, scheduler, width=600, **kwargs): +class WorkerCallStacks(RequestHandler): + async def get(self, worker): with log_errors(): - self.last = 0 - self.scheduler = scheduler - self.source = ColumnDataSource( - { - "nprocessing": [1, 2], - "nprocessing-half": [0.5, 1], - "nprocessing-color": ["red", "blue"], - "nbytes": [1, 2], - "nbytes-half": [0.5, 1], - "nbytes_text": ["1B", "2B"], - "cpu": [1, 2], - "cpu-half": [0.5, 1], - "worker": ["a", "b"], - "y": [1, 2], - "nbytes-color": ["blue", "blue"], - "escaped_worker": ["a", "b"], - } - ) - - processing = figure( - title="Tasks Processing", - tools="", - id="bk-nprocessing-plot", - name="processing_hist", - width=int(width / 2), - **kwargs - ) - rect = processing.rect( - source=self.source, - x="nprocessing-half", - y="y", - width="nprocessing", - height=1, - color="nprocessing-color", - ) - processing.x_range.start = 0 - rect.nonselection_glyph = None - - nbytes = figure( - title="Bytes stored", - tools="", - id="bk-nbytes-worker-plot", - width=int(width / 2), - name="nbytes_hist", - **kwargs - ) - rect = nbytes.rect( - source=self.source, - x="nbytes-half", - y="y", - width="nbytes", - height=1, - color="nbytes-color", - ) - rect.nonselection_glyph = None - - cpu = figure( - title="CPU Utilization", - tools="", - id="bk-cpu-worker-plot", - width=int(width / 2), - name="cpu_hist", - **kwargs + worker = escape.url_unescape(worker) + keys = self.server.processing[worker] + call_stack = await self.server.get_call_stack(keys=keys) + self.render( + "call-stack.html", + title="Call Stacks: " + worker, + call_stack=call_stack, + **toolz.merge(self.extra, rel_path_statics), ) - rect = cpu.rect( - source=self.source, - x="cpu-half", - y="y", - width="cpu", - height=1, - color="blue", - ) - rect.nonselection_glyph = None - hundred_span = Span( - location=100, - dimension="height", - line_color="gray", - line_dash="dashed", - line_width=3, - ) - cpu.add_layout(hundred_span) - - nbytes.axis[0].ticker = BasicTicker(mantissas=[1, 256, 512], base=1024) - nbytes.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") - nbytes.xaxis.major_label_orientation = -math.pi / 12 - nbytes.x_range.start = 0 - - for fig in [processing, nbytes]: - fig.xaxis.minor_tick_line_alpha = 0 - fig.yaxis.visible = False - fig.ygrid.visible = False - - tap = TapTool( - callback=OpenURL(url="./info/worker/@escaped_worker.html") - ) - fig.add_tools(tap) - - fig.toolbar.logo = None - fig.toolbar_location = None - fig.yaxis.visible = False - hover = HoverTool() - hover.tooltips = "@worker : @nprocessing tasks" - hover.point_policy = "follow_mouse" - processing.add_tools(hover) - hover = HoverTool() - hover.tooltips = "@worker : @nbytes_text" - hover.point_policy = 
"follow_mouse" - nbytes.add_tools(hover) - - hover = HoverTool() - hover.tooltips = "@worker : @cpu %" - hover.point_policy = "follow_mouse" - cpu.add_tools(hover) - - self.processing_figure = processing - self.nbytes_figure = nbytes - self.cpu_figure = cpu - - processing.y_range = nbytes.y_range - cpu.y_range = nbytes.y_range - - @without_property_validation - def update(self): +class TaskCallStack(RequestHandler): + async def get(self, key): with log_errors(): - workers = list(self.scheduler.workers.values()) - - y = list(range(len(workers))) - - cpu = [int(ws.metrics["cpu"]) for ws in workers] - - nprocessing = [len(ws.processing) for ws in workers] - processing_color = [] - for ws in workers: - if ws in self.scheduler.idle: - processing_color.append("red") - elif ws in self.scheduler.saturated: - processing_color.append("green") - else: - processing_color.append("blue") - - nbytes = [ws.metrics["memory"] for ws in workers] - nbytes_text = [format_bytes(nb) for nb in nbytes] - nbytes_color = [] - max_limit = 0 - for ws, nb in zip(workers, nbytes): - limit = ( - getattr(self.scheduler.workers[ws.address], "memory_limit", inf) - or inf - ) - - if limit > max_limit: - max_limit = limit - - if nb > limit: - nbytes_color.append("red") - elif nb > limit / 2: - nbytes_color.append("orange") - else: - nbytes_color.append("blue") - - now = time() - if any(nprocessing) or self.last + 1 < now: - self.last = now - result = { - "cpu": cpu, - "cpu-half": [c / 2 for c in cpu], - "nprocessing": nprocessing, - "nprocessing-half": [np / 2 for np in nprocessing], - "nprocessing-color": processing_color, - "nbytes": nbytes, - "nbytes-half": [nb / 2 for nb in nbytes], - "nbytes-color": nbytes_color, - "nbytes_text": nbytes_text, - "worker": [ws.address for ws in workers], - "escaped_worker": [escape.url_escape(ws.address) for ws in workers], - "y": y, - } - - self.nbytes_figure.title.text = "Bytes stored: " + format_bytes( - sum(nbytes) - ) - self.nbytes_figure.x_range.end = max_limit - - update(self.source, result) - - -class StealingTimeSeries(DashboardComponent): - def __init__(self, scheduler, **kwargs): - self.scheduler = scheduler - self.source = ColumnDataSource( - {"time": [time(), time() + 1], "idle": [0, 0.1], "saturated": [0, 0.1]} - ) - - x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) - - fig = figure( - title="Idle and Saturated Workers Over Time", - x_axis_type="datetime", - y_range=[-0.1, len(scheduler.workers) + 0.1], - height=150, - tools="", - x_range=x_range, - **kwargs - ) - fig.line(source=self.source, x="time", y="idle", color="red") - fig.line(source=self.source, x="time", y="saturated", color="green") - fig.yaxis.minor_tick_line_color = None - - fig.add_tools( - ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") - ) - - self.root = fig - - @without_property_validation - def update(self): - with log_errors(): - result = { - "time": [time() * 1000], - "idle": [len(self.scheduler.idle)], - "saturated": [len(self.scheduler.saturated)], - } - if PROFILING: - curdoc().add_next_tick_callback( - lambda: self.source.stream(result, 10000) + key = escape.url_unescape(key) + call_stack = await self.server.get_call_stack(keys=[key]) + if not call_stack: + self.write( + "
          Task not actively running. " + "It may be finished or not yet started
          " ) else: - self.source.stream(result, 10000) - - -class StealingEvents(DashboardComponent): - def __init__(self, scheduler, **kwargs): - self.scheduler = scheduler - self.steal = scheduler.extensions["stealing"] - self.last = 0 - self.source = ColumnDataSource( - { - "time": [time() - 20, time()], - "level": [0, 15], - "color": ["white", "white"], - "duration": [0, 0], - "radius": [1, 1], - "cost_factor": [0, 10], - "count": [1, 1], - } - ) - - x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) - - fig = figure( - title="Stealing Events", - x_axis_type="datetime", - y_axis_type="log", - height=250, - tools="", - x_range=x_range, - **kwargs - ) - - fig.circle( - source=self.source, - x="time", - y="cost_factor", - color="color", - size="radius", - alpha=0.5, - ) - fig.yaxis.axis_label = "Cost Multiplier" - - hover = HoverTool() - hover.tooltips = "Level: @level, Duration: @duration, Count: @count, Cost factor: @cost_factor" - hover.point_policy = "follow_mouse" - - fig.add_tools( - hover, - ResetTool(), - PanTool(dimensions="width"), - WheelZoomTool(dimensions="width"), - ) - - self.root = fig - - def convert(self, msgs): - """ Convert a log message to a glyph """ - total_duration = 0 - for msg in msgs: - time, level, key, duration, sat, occ_sat, idl, occ_idl = msg - total_duration += duration - - try: - color = Viridis11[level] - except (KeyError, IndexError): - color = "black" - - radius = math.sqrt(min(total_duration, 10)) * 30 + 2 - - d = { - "time": time * 1000, - "level": level, - "count": len(msgs), - "color": color, - "duration": total_duration, - "radius": radius, - "cost_factor": min(10, self.steal.cost_multipliers[level]), - } - - return d - - @without_property_validation - def update(self): - with log_errors(): - log = self.steal.log - n = self.steal.count - self.last - log = [log[-i] for i in range(1, n + 1) if isinstance(log[-i], list)] - self.last = self.steal.count - - if log: - new = pipe( - log, - map(groupby(1)), - map(dict.values), - concat, - map(self.convert), - list, - transpose, - ) - if PROFILING: - curdoc().add_next_tick_callback( - lambda: self.source.stream(new, 10000) - ) - else: - self.source.stream(new, 10000) - - -class Events(DashboardComponent): - def __init__(self, scheduler, name, height=150, **kwargs): - self.scheduler = scheduler - self.action_ys = dict() - self.last = 0 - self.name = name - self.source = ColumnDataSource( - {"time": [], "action": [], "hover": [], "y": [], "color": []} - ) - - x_range = DataRange1d(follow="end", follow_interval=200000) - - fig = figure( - title=name, - x_axis_type="datetime", - height=height, - tools="", - x_range=x_range, - **kwargs - ) - - fig.circle( - source=self.source, - x="time", - y="y", - color="color", - size=50, - alpha=0.5, - legend="action", - ) - fig.yaxis.axis_label = "Action" - fig.legend.location = "top_left" - - hover = HoverTool() - hover.tooltips = "@action
          @hover" - hover.point_policy = "follow_mouse" - - fig.add_tools( - hover, - ResetTool(), - PanTool(dimensions="width"), - WheelZoomTool(dimensions="width"), - ) - - self.root = fig - - @without_property_validation - def update(self): - with log_errors(): - log = self.scheduler.events[self.name] - n = self.scheduler.event_counts[self.name] - self.last - if log: - log = [log[-i] for i in range(1, n + 1)] - self.last = self.scheduler.event_counts[self.name] - - if log: - actions = [] - times = [] - hovers = [] - ys = [] - colors = [] - for msg in log: - times.append(msg["time"] * 1000) - action = msg["action"] - actions.append(action) - try: - ys.append(self.action_ys[action]) - except KeyError: - self.action_ys[action] = len(self.action_ys) - ys.append(self.action_ys[action]) - colors.append(color_of(action)) - hovers.append("TODO") - - new = { - "time": times, - "action": actions, - "hover": hovers, - "y": ys, - "color": colors, - } - - if PROFILING: - curdoc().add_next_tick_callback( - lambda: self.source.stream(new, 10000) - ) - else: - self.source.stream(new, 10000) - - -class TaskStream(components.TaskStream): - def __init__(self, scheduler, n_rectangles=1000, clear_interval="20s", **kwargs): - self.scheduler = scheduler - self.offset = 0 - es = [p for p in self.scheduler.plugins if isinstance(p, TaskStreamPlugin)] - if not es: - self.plugin = TaskStreamPlugin(self.scheduler) - else: - self.plugin = es[0] - self.index = max(0, self.plugin.index - n_rectangles) - self.workers = dict() - - components.TaskStream.__init__( - self, n_rectangles=n_rectangles, clear_interval=clear_interval, **kwargs - ) - - @without_property_validation - def update(self): - if self.index == self.plugin.index: - return - with log_errors(): - if self.index and len(self.source.data["start"]): - start = min(self.source.data["start"]) - duration = max(self.source.data["duration"]) - boundary = (self.offset + start - duration) / 1000 - else: - boundary = self.offset - rectangles = self.plugin.rectangles( - istart=self.index, workers=self.workers, start_boundary=boundary - ) - n = len(rectangles["name"]) - self.index = self.plugin.index - - if not rectangles["start"]: - return - - # If there has been a significant delay then clear old rectangles - first_end = min(map(add, rectangles["start"], rectangles["duration"])) - if first_end > self.last: - last = self.last - self.last = first_end - if first_end > last + self.clear_interval * 1000: - self.offset = min(rectangles["start"]) - self.source.data.update({k: [] for k in rectangles}) - - rectangles["start"] = [x - self.offset for x in rectangles["start"]] - - # Convert to numpy for serialization speed - if n >= 10 and np: - for k, v in rectangles.items(): - if isinstance(v[0], Number): - rectangles[k] = np.array(v) - - if PROFILING: - curdoc().add_next_tick_callback( - lambda: self.source.stream(rectangles, self.n_rectangles) + self.render( + "call-stack.html", + title="Call Stack: " + key, + call_stack=call_stack, + **toolz.merge(self.extra, rel_path_statics), ) - else: - self.source.stream(rectangles, self.n_rectangles) - - -class GraphPlot(DashboardComponent): - """ - A dynamic node-link diagram for the task graph on the scheduler - - See also the GraphLayout diagnostic at - distributed/diagnostics/graph_layout.py - """ - - def __init__(self, scheduler, **kwargs): - self.scheduler = scheduler - self.layout = GraphLayout(scheduler) - self.invisible_count = 0 # number of invisible nodes - - self.node_source = ColumnDataSource( - {"x": [], "y": [], "name": 
[], "state": [], "visible": [], "key": []} - ) - self.edge_source = ColumnDataSource({"x": [], "y": [], "visible": []}) - - node_view = CDSView( - source=self.node_source, - filters=[GroupFilter(column_name="visible", group="True")], - ) - edge_view = CDSView( - source=self.edge_source, - filters=[GroupFilter(column_name="visible", group="True")], - ) - - node_colors = factor_cmap( - "state", - factors=["waiting", "processing", "memory", "released", "erred"], - palette=["gray", "green", "red", "blue", "black"], - ) - - self.root = figure(title="Task Graph", **kwargs) - self.root.multi_line( - xs="x", - ys="y", - source=self.edge_source, - line_width=1, - view=edge_view, - color="black", - alpha=0.3, - ) - rect = self.root.square( - x="x", - y="y", - size=10, - color=node_colors, - source=self.node_source, - view=node_view, - legend="state", - ) - self.root.xgrid.grid_line_color = None - self.root.ygrid.grid_line_color = None - - hover = HoverTool( - point_policy="follow_mouse", - tooltips="@name: @state", - renderers=[rect], - ) - tap = TapTool(callback=OpenURL(url="info/task/@key.html"), renderers=[rect]) - rect.nonselection_glyph = None - self.root.add_tools(hover, tap) - - @without_property_validation - def update(self): - with log_errors(): - # occasionally reset the column data source to remove old nodes - if self.invisible_count > len(self.node_source.data["x"]) / 2: - self.layout.reset_index() - self.invisible_count = 0 - update = True - else: - update = False - - new, self.layout.new = self.layout.new, [] - new_edges = self.layout.new_edges - self.layout.new_edges = [] - - self.add_new_nodes_edges(new, new_edges, update=update) - - self.patch_updates() - - @without_property_validation - def add_new_nodes_edges(self, new, new_edges, update=False): - if new or update: - node_key = [] - node_x = [] - node_y = [] - node_state = [] - node_name = [] - edge_x = [] - edge_y = [] - - x = self.layout.x - y = self.layout.y - - tasks = self.scheduler.tasks - for key in new: - try: - task = tasks[key] - except KeyError: - continue - xx = x[key] - yy = y[key] - node_key.append(escape.url_escape(key)) - node_x.append(xx) - node_y.append(yy) - node_state.append(task.state) - node_name.append(task.prefix) - - for a, b in new_edges: - try: - edge_x.append([x[a], x[b]]) - edge_y.append([y[a], y[b]]) - except KeyError: - pass - - node = { - "x": node_x, - "y": node_y, - "state": node_state, - "name": node_name, - "key": node_key, - "visible": ["True"] * len(node_x), - } - edge = {"x": edge_x, "y": edge_y, "visible": ["True"] * len(edge_x)} - - if update or not len(self.node_source.data["x"]): - # see https://github.com/bokeh/bokeh/issues/7523 - self.node_source.data.update(node) - self.edge_source.data.update(edge) - else: - self.node_source.stream(node) - self.edge_source.stream(edge) - - @without_property_validation - def patch_updates(self): - """ - Small updates like color changes or lost nodes from task transitions - """ - n = len(self.node_source.data["x"]) - m = len(self.edge_source.data["x"]) - - if self.layout.state_updates: - state_updates = self.layout.state_updates - self.layout.state_updates = [] - updates = [(i, c) for i, c in state_updates if i < n] - self.node_source.patch({"state": updates}) - - if self.layout.visible_updates: - updates = self.layout.visible_updates - updates = [(i, c) for i, c in updates if i < n] - self.visible_updates = [] - self.node_source.patch({"visible": updates}) - self.invisible_count += len(updates) - - if self.layout.visible_edge_updates: - updates = 
self.layout.visible_edge_updates - updates = [(i, c) for i, c in updates if i < m] - self.visible_updates = [] - self.edge_source.patch({"visible": updates}) - - def __del__(self): - self.scheduler.remove_plugin(self.layout) - - -class TaskProgress(DashboardComponent): - """ Progress bars per task type """ - - def __init__(self, scheduler, **kwargs): - self.scheduler = scheduler - ps = [p for p in scheduler.plugins if isinstance(p, AllProgress)] - if ps: - self.plugin = ps[0] - else: - self.plugin = AllProgress(scheduler) - - data = progress_quads( - dict(all={}, memory={}, erred={}, released={}, processing={}) - ) - self.source = ColumnDataSource(data=data) - - x_range = DataRange1d(range_padding=0) - y_range = Range1d(-8, 0) - - self.root = figure( - id="bk-task-progress-plot", - title="Progress", - name="task_progress", - x_range=x_range, - y_range=y_range, - toolbar_location=None, - tools="", - **kwargs - ) - self.root.line( # just to define early ranges - x=[0, 0.9], y=[-1, 0], line_color="#FFFFFF", alpha=0.0 - ) - self.root.quad( - source=self.source, - top="top", - bottom="bottom", - left="left", - right="right", - fill_color="#aaaaaa", - line_color="#aaaaaa", - fill_alpha=0.1, - line_alpha=0.3, - ) - self.root.quad( - source=self.source, - top="top", - bottom="bottom", - left="left", - right="released-loc", - fill_color="color", - line_color="color", - fill_alpha=0.6, - ) - self.root.quad( - source=self.source, - top="top", - bottom="bottom", - left="released-loc", - right="memory-loc", - fill_color="color", - line_color="color", - fill_alpha=1.0, - ) - self.root.quad( - source=self.source, - top="top", - bottom="bottom", - left="memory-loc", - right="erred-loc", - fill_color="black", - fill_alpha=0.5, - line_alpha=0, - ) - self.root.quad( - source=self.source, - top="top", - bottom="bottom", - left="erred-loc", - right="processing-loc", - fill_color="gray", - fill_alpha=0.35, - line_alpha=0, - ) - self.root.text( - source=self.source, - text="show-name", - y="bottom", - x="left", - x_offset=5, - text_font_size=value("10pt"), - ) - self.root.text( - source=self.source, - text="done", - y="bottom", - x="right", - x_offset=-5, - text_align="right", - text_font_size=value("10pt"), - ) - self.root.ygrid.visible = False - self.root.yaxis.minor_tick_line_alpha = 0 - self.root.yaxis.visible = False - self.root.xgrid.visible = False - self.root.xaxis.minor_tick_line_alpha = 0 - self.root.xaxis.visible = False - - hover = HoverTool( - point_policy="follow_mouse", - tooltips=""" -
          - Name: @name
          - All: @all
          - Memory: @memory
          - Erred: @erred
          - Ready: @processing
          - """, - ) - self.root.add_tools(hover) - @without_property_validation - def update(self): - with log_errors(): - state = {"all": valmap(len, self.plugin.all), "nbytes": self.plugin.nbytes} - for k in ["memory", "erred", "released", "processing", "waiting"]: - state[k] = valmap(len, self.plugin.state[k]) - if not state["all"] and not len(self.source.data["all"]): - return - - d = progress_quads(state) - - update(self.source, d) - - totals = { - k: sum(state[k].values()) - for k in ["all", "memory", "erred", "released", "waiting"] - } - totals["processing"] = totals["all"] - sum( - v for k, v in totals.items() if k != "all" - ) - - self.root.title.text = ( - "Progress -- total: %(all)s, " - "in-memory: %(memory)s, processing: %(processing)s, " - "waiting: %(waiting)s, " - "erred: %(erred)s" % totals - ) +class CountsJSON(RequestHandler): + def get(self): + scheduler = self.server + erred = 0 + nbytes = 0 + nthreads = 0 + memory = 0 + processing = 0 + released = 0 + waiting = 0 + waiting_data = 0 + + for ts in scheduler.tasks.values(): + if ts.exception_blame is not None: + erred += 1 + elif ts.state == "released": + released += 1 + if ts.waiting_on: + waiting += 1 + if ts.waiters: + waiting_data += 1 + for ws in scheduler.workers.values(): + nthreads += ws.nthreads + memory += len(ws.has_what) + nbytes += ws.nbytes + processing += len(ws.processing) + + response = { + "bytes": nbytes, + "clients": len(scheduler.clients), + "cores": nthreads, + "erred": erred, + "hosts": len(scheduler.host_info), + "idle": len(scheduler.idle), + "memory": memory, + "processing": processing, + "released": released, + "saturated": len(scheduler.saturated), + "tasks": len(scheduler.tasks), + "unrunnable": len(scheduler.unrunnable), + "waiting": waiting, + "waiting_data": waiting_data, + "workers": len(scheduler.workers), + } + self.write(response) -class MemoryUse(DashboardComponent): - """ The memory usage across the cluster, grouped by task type """ - - def __init__(self, scheduler, **kwargs): - self.scheduler = scheduler - ps = [p for p in scheduler.plugins if isinstance(p, AllProgress)] - if ps: - self.plugin = ps[0] - else: - self.plugin = AllProgress(scheduler) - - self.source = ColumnDataSource( - data=dict( - name=[], - left=[], - right=[], - center=[], - color=[], - percent=[], - MB=[], - text=[], - ) - ) - - self.root = Plot( - id="bk-nbytes-plot", - x_range=DataRange1d(), - y_range=DataRange1d(), - toolbar_location=None, - outline_line_color=None, - **kwargs - ) - self.root.add_glyph( - self.source, - Quad( - top=1, - bottom=0, - left="left", - right="right", - fill_color="color", - fill_alpha=1, - ), - ) +class IdentityJSON(RequestHandler): + def get(self): + self.write(self.server.identity()) - self.root.add_layout(LinearAxis(), "left") - self.root.add_layout(LinearAxis(), "below") - - hover = HoverTool( - point_policy="follow_mouse", - tooltips=""" -
          - Name: @name
          - Percent: @percent
          - MB: @MB
          - """, - ) - self.root.add_tools(hover) - @without_property_validation - def update(self): +class IndexJSON(RequestHandler): + def get(self): with log_errors(): - nb = nbytes_bar(self.plugin.nbytes) - update(self.source, nb) - self.root.title.text = "Memory Use: %0.2f MB" % ( - sum(self.plugin.nbytes.values()) / 1e6 + r = [url for url, _ in routes if url.endswith(".json")] + self.render( + "json-index.html", routes=r, title="Index of JSON routes", **self.extra ) -class WorkerTable(DashboardComponent): - """ Status of the current workers - - This is two plots, a text-based table for each host and a thin horizontal - plot laying out hosts by their current memory use. - """ - - excluded_names = {"executing", "in_flight", "in_memory", "ready", "time"} - - def __init__(self, scheduler, width=800, **kwargs): - self.scheduler = scheduler - self.names = [ - "name", - "address", - "nthreads", - "cpu", - "memory", - "memory_limit", - "memory_percent", - "num_fds", - "read_bytes", - "write_bytes", - "cpu_fraction", - ] - workers = self.scheduler.workers.values() - self.extra_names = sorted( - { - m - for ws in workers - for m, v in ws.metrics.items() - if m not in self.names and isinstance(v, (str, int, float)) - } - - self.excluded_names - ) - - table_names = [ - "name", - "address", - "nthreads", - "cpu", - "memory", - "memory_limit", - "memory_percent", - "num_fds", - "read_bytes", - "write_bytes", - ] - - self.source = ColumnDataSource({k: [] for k in self.names}) - - columns = { - name: TableColumn(field=name, title=name.replace("_percent", " %")) - for name in table_names +class IndividualPlots(RequestHandler): + def get(self): + bokeh_server = self.server.services["dashboard"] + result = { + uri.strip("/").replace("-", " ").title(): uri + for uri in bokeh_server.apps + if uri.lstrip("/").startswith("individual-") and not uri.endswith(".json") } + self.write(result) - formatters = { - "cpu": NumberFormatter(format="0.0 %"), - "memory_percent": NumberFormatter(format="0.0 %"), - "memory": NumberFormatter(format="0 b"), - "memory_limit": NumberFormatter(format="0 b"), - "read_bytes": NumberFormatter(format="0 b"), - "write_bytes": NumberFormatter(format="0 b"), - "num_fds": NumberFormatter(format="0"), - "nthreads": NumberFormatter(format="0"), - } - if BOKEH_VERSION < "0.12.15": - dt_kwargs = {"row_headers": False} - else: - dt_kwargs = {"index_position": None} - - table = DataTable( - source=self.source, - columns=[columns[n] for n in table_names], - reorderable=True, - sortable=True, - width=width, - **dt_kwargs - ) +class _PrometheusCollector(object): + def __init__(self, server): + self.server = server - for name in table_names: - if name in formatters: - table.columns[table_names.index(name)].formatter = formatters[name] + def collect(self): + from prometheus_client.core import GaugeMetricFamily - extra_names = ["name", "address"] + self.extra_names - extra_columns = { - name: TableColumn(field=name, title=name.replace("_percent", "%")) - for name in extra_names - } - - extra_table = DataTable( - source=self.source, - columns=[extra_columns[n] for n in extra_names], - reorderable=True, - sortable=True, - width=width, - **dt_kwargs - ) - - hover = HoverTool( - point_policy="follow_mouse", - tooltips=""" -
          - @worker: @memory_percent
          - """, - ) - - mem_plot = figure( - title="Memory Use (%)", - toolbar_location=None, - x_range=(0, 1), - y_range=(-0.1, 0.1), - height=60, - width=width, - tools="", - **kwargs - ) - mem_plot.circle( - source=self.source, x="memory_percent", y=0, size=10, fill_alpha=0.5 + yield GaugeMetricFamily( + "dask_scheduler_workers", + "Number of workers connected.", + value=len(self.server.workers), ) - mem_plot.ygrid.visible = False - mem_plot.yaxis.minor_tick_line_alpha = 0 - mem_plot.xaxis.visible = False - mem_plot.yaxis.visible = False - mem_plot.add_tools(hover, BoxSelectTool()) - - hover = HoverTool( - point_policy="follow_mouse", - tooltips=""" -
          - @worker: @cpu
          - """, + yield GaugeMetricFamily( + "dask_scheduler_clients", + "Number of clients connected.", + value=len(self.server.clients), ) - - cpu_plot = figure( - title="CPU Use (%)", - toolbar_location=None, - x_range=(0, 1), - y_range=(-0.1, 0.1), - height=60, - width=width, - tools="", - **kwargs + yield GaugeMetricFamily( + "dask_scheduler_received_tasks", + "Number of tasks received at scheduler", + value=len(self.server.tasks), ) - cpu_plot.circle( - source=self.source, x="cpu_fraction", y=0, size=10, fill_alpha=0.5 + yield GaugeMetricFamily( + "dask_scheduler_unrunnable_tasks", + "Number of unrunnable tasks at scheduler", + value=len(self.server.unrunnable), ) - cpu_plot.ygrid.visible = False - cpu_plot.yaxis.minor_tick_line_alpha = 0 - cpu_plot.xaxis.visible = False - cpu_plot.yaxis.visible = False - cpu_plot.add_tools(hover, BoxSelectTool()) - self.cpu_plot = cpu_plot - - if "sizing_mode" in kwargs: - sizing_mode = {"sizing_mode": kwargs["sizing_mode"]} - else: - sizing_mode = {} - - components = [cpu_plot, mem_plot, table] - if self.extra_names: - components.append(extra_table) - - self.root = column(*components, id="bk-worker-table", **sizing_mode) - - @without_property_validation - def update(self): - data = {name: [] for name in self.names + self.extra_names} - for i, (addr, ws) in enumerate( - sorted(self.scheduler.workers.items(), key=lambda kv: kv[1].name) - ): - for name in self.names + self.extra_names: - data[name].append(ws.metrics.get(name, None)) - data["name"][-1] = ws.name if ws.name is not None else i - data["address"][-1] = ws.address - if ws.memory_limit: - data["memory_percent"][-1] = ws.metrics["memory"] / ws.memory_limit - else: - data["memory_percent"][-1] = "" - data["memory_limit"][-1] = ws.memory_limit - data["cpu"][-1] = ws.metrics["cpu"] / 100.0 - data["cpu_fraction"][-1] = ws.metrics["cpu"] / 100.0 / ws.nthreads - data["nthreads"][-1] = ws.nthreads - - self.source.data.update(data) - - -def systemmonitor_doc(scheduler, extra, doc): - with log_errors(): - sysmon = SystemMonitor(scheduler, sizing_mode="stretch_both") - doc.title = "Dask: Scheduler System Monitor" - add_periodic_callback(doc, sysmon, 500) - - for subdoc in sysmon.root.children: - doc.add_root(subdoc) - doc.template = env.get_template("system.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def stealing_doc(scheduler, extra, doc): - with log_errors(): - occupancy = Occupancy(scheduler, height=200, sizing_mode="scale_width") - stealing_ts = StealingTimeSeries(scheduler, sizing_mode="scale_width") - stealing_events = StealingEvents(scheduler, sizing_mode="scale_width") - stealing_events.root.x_range = stealing_ts.root.x_range - doc.title = "Dask: Work Stealing" - add_periodic_callback(doc, occupancy, 500) - add_periodic_callback(doc, stealing_ts, 500) - add_periodic_callback(doc, stealing_events, 500) - - doc.add_root( - column( - occupancy.root, - stealing_ts.root, - stealing_events.root, - sizing_mode="scale_width", - ) - ) - - doc.template = env.get_template("simple.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def events_doc(scheduler, extra, doc): - with log_errors(): - events = Events(scheduler, "all", height=250) - events.update() - add_periodic_callback(doc, events, 500) - doc.title = "Dask: Scheduler Events" - doc.add_root(column(events.root, sizing_mode="scale_width")) - doc.template = env.get_template("simple.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def workers_doc(scheduler, extra, 
doc): - with log_errors(): - table = WorkerTable(scheduler) - table.update() - add_periodic_callback(doc, table, 500) - doc.title = "Dask: Workers" - doc.add_root(table.root) - doc.template = env.get_template("simple.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def tasks_doc(scheduler, extra, doc): - with log_errors(): - ts = TaskStream( - scheduler, - n_rectangles=dask.config.get( - "distributed.scheduler.dashboard.tasks.task-stream-length" - ), - clear_interval="60s", - sizing_mode="stretch_both", - ) - ts.update() - add_periodic_callback(doc, ts, 5000) - doc.title = "Dask: Task Stream" - doc.add_root(ts.root) - doc.template = env.get_template("simple.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def graph_doc(scheduler, extra, doc): - with log_errors(): - graph = GraphPlot(scheduler, sizing_mode="stretch_both") - doc.title = "Dask: Task Graph" - graph.update() - add_periodic_callback(doc, graph, 200) - doc.add_root(graph.root) - - doc.template = env.get_template("simple.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def status_doc(scheduler, extra, doc): - with log_errors(): - task_stream = TaskStream( - scheduler, - n_rectangles=dask.config.get( - "distributed.scheduler.dashboard.status.task-stream-length" - ), - clear_interval="10s", - sizing_mode="stretch_both", - ) - task_stream.update() - add_periodic_callback(doc, task_stream, 100) - - task_progress = TaskProgress(scheduler, sizing_mode="stretch_both") - task_progress.update() - add_periodic_callback(doc, task_progress, 100) - - if len(scheduler.workers) < 50: - current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") - current_load.update() - add_periodic_callback(doc, current_load, 100) - doc.add_root(current_load.nbytes_figure) - doc.add_root(current_load.processing_figure) - else: - nbytes_hist = NBytesHistogram(scheduler, sizing_mode="stretch_both") - nbytes_hist.update() - processing_hist = ProcessingHistogram(scheduler, sizing_mode="stretch_both") - processing_hist.update() - add_periodic_callback(doc, nbytes_hist, 100) - add_periodic_callback(doc, processing_hist, 100) - current_load_fig = row( - nbytes_hist.root, processing_hist.root, sizing_mode="stretch_both" - ) - - doc.add_root(nbytes_hist.root) - doc.add_root(processing_hist.root) - - doc.title = "Dask: Status" - doc.add_root(task_progress.root) - doc.add_root(task_stream.root) - doc.theme = BOKEH_THEME - doc.template = env.get_template("status.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def individual_task_stream_doc(scheduler, extra, doc): - task_stream = TaskStream( - scheduler, n_rectangles=1000, clear_interval="10s", sizing_mode="stretch_both" - ) - task_stream.update() - add_periodic_callback(doc, task_stream, 100) - doc.add_root(task_stream.root) - doc.theme = BOKEH_THEME -def individual_nbytes_doc(scheduler, extra, doc): - current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") - current_load.update() - add_periodic_callback(doc, current_load, 100) - doc.add_root(current_load.nbytes_figure) - doc.theme = BOKEH_THEME - +class PrometheusHandler(RequestHandler): + _initialized = False -def individual_cpu_doc(scheduler, extra, doc): - current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") - current_load.update() - add_periodic_callback(doc, current_load, 100) - doc.add_root(current_load.cpu_figure) - doc.theme = BOKEH_THEME + def __init__(self, *args, **kwargs): + import prometheus_client + 
super(PrometheusHandler, self).__init__(*args, **kwargs) -def individual_nprocessing_doc(scheduler, extra, doc): - current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") - current_load.update() - add_periodic_callback(doc, current_load, 100) - doc.add_root(current_load.processing_figure) - doc.theme = BOKEH_THEME - - -def individual_progress_doc(scheduler, extra, doc): - task_progress = TaskProgress(scheduler, height=160, sizing_mode="stretch_both") - task_progress.update() - add_periodic_callback(doc, task_progress, 100) - doc.add_root(task_progress.root) - doc.theme = BOKEH_THEME - - -def individual_graph_doc(scheduler, extra, doc): - with log_errors(): - graph = GraphPlot(scheduler, sizing_mode="stretch_both") - graph.update() - - add_periodic_callback(doc, graph, 200) - doc.add_root(graph.root) - doc.theme = BOKEH_THEME - - -def individual_profile_doc(scheduler, extra, doc): - with log_errors(): - prof = ProfileTimePlot(scheduler, sizing_mode="scale_width", doc=doc) - doc.add_root(prof.root) - prof.trigger_update() - doc.theme = BOKEH_THEME - - -def individual_profile_server_doc(scheduler, extra, doc): - with log_errors(): - prof = ProfileServer(scheduler, sizing_mode="scale_width", doc=doc) - doc.add_root(prof.root) - prof.trigger_update() - doc.theme = BOKEH_THEME - - -def individual_workers_doc(scheduler, extra, doc): - with log_errors(): - table = WorkerTable(scheduler) - table.update() - add_periodic_callback(doc, table, 500) - doc.add_root(table.root) - doc.theme = BOKEH_THEME + if PrometheusHandler._initialized: + return + prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) -def individual_bandwidth_types(scheduler, extra, doc): - with log_errors(): - bw = BandwidthTypes(scheduler, sizing_mode="stretch_both") - bw.update() - add_periodic_callback(doc, bw, 500) - doc.add_root(bw.fig) - doc.theme = BOKEH_THEME + PrometheusHandler._initialized = True + def get(self): + import prometheus_client -def individual_bandwidth_workers(scheduler, extra, doc): - with log_errors(): - bw = BandwidthWorkers(scheduler, sizing_mode="stretch_both") - bw.update() - add_periodic_callback(doc, bw, 500) - doc.add_root(bw.fig) - doc.theme = BOKEH_THEME + self.write(prometheus_client.generate_latest()) + self.set_header("Content-Type", "text/plain; version=0.0.4") -def profile_doc(scheduler, extra, doc): - with log_errors(): - doc.title = "Dask: Profile" - prof = ProfileTimePlot(scheduler, sizing_mode="scale_width", doc=doc) - doc.add_root(prof.root) - doc.template = env.get_template("simple.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME +class HealthHandler(RequestHandler): + def get(self): + self.write("ok") + self.set_header("Content-Type", "text/plain") - prof.trigger_update() +routes = [ + (r"info", redirect("info/main/workers.html")), + (r"info/main/workers.html", Workers), + (r"info/worker/(.*).html", Worker), + (r"info/task/(.*).html", Task), + (r"info/main/logs.html", Logs), + (r"info/call-stacks/(.*).html", WorkerCallStacks), + (r"info/call-stack/(.*).html", TaskCallStack), + (r"info/logs/(.*).html", WorkerLogs), + (r"json/counts.json", CountsJSON), + (r"json/identity.json", IdentityJSON), + (r"json/index.html", IndexJSON), + (r"individual-plots.json", IndividualPlots), + (r"metrics", PrometheusHandler), + (r"health", HealthHandler), + (r"proxy/(\d+)/(.*?)/(.*)", GlobalProxyHandler), +] -def profile_server_doc(scheduler, extra, doc): - with log_errors(): - doc.title = "Dask: Profile of Event Loop" - prof = ProfileServer(scheduler, 
sizing_mode="scale_width", doc=doc) - doc.add_root(prof.root) - doc.template = env.get_template("simple.html") - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - prof.trigger_update() +def get_handlers(server): + return [(url, cls, {"server": server}) for url, cls in routes] class BokehScheduler(BokehServer): @@ -1853,8 +361,6 @@ def my_server(self): def listen(self, *args, **kwargs): super(BokehScheduler, self).listen(*args, **kwargs) - from .scheduler_html import routes - handlers = [ ( self.prefix + "/" + url, @@ -1884,11 +390,12 @@ def listen(self, *args, **kwargs): "/individual-profile": individual_profile_doc, "/individual-profile-server": individual_profile_server_doc, "/individual-nbytes": individual_nbytes_doc, + "/individual-memory-use": individual_memory_use_doc, "/individual-cpu": individual_cpu_doc, "/individual-nprocessing": individual_nprocessing_doc, "/individual-workers": individual_workers_doc, - "/individual-bandwidth-types": individual_bandwidth_types, - "/individual-bandwidth-workers": individual_bandwidth_workers, + "/individual-bandwidth-types": individual_bandwidth_types_doc, + "/individual-bandwidth-workers": individual_bandwidth_workers_doc, } try: @@ -1896,4 +403,7 @@ def listen(self, *args, **kwargs): except ImportError: pass else: - from . import nvml # noqa: 1708 + from .components.nvml import gpu_memory_doc, gpu_utilization_doc # noqa: 1708 + + applications["/individual-gpu-memory"] = gpu_memory_doc + applications["/individual-gpu-utilization"] = gpu_utilization_doc diff --git a/distributed/dashboard/scheduler_html.py b/distributed/dashboard/scheduler_html.py deleted file mode 100644 index 6ac13915523..00000000000 --- a/distributed/dashboard/scheduler_html.py +++ /dev/null @@ -1,269 +0,0 @@ -from datetime import datetime - -from dask.utils import format_bytes -import toolz -from tornado import escape - -from ..utils import log_errors, format_time -from .proxy import GlobalProxyHandler -from .utils import RequestHandler, redirect - -ns = { - func.__name__: func for func in [format_bytes, format_time, datetime.fromtimestamp] -} - -rel_path_statics = {"rel_path_statics": "../../"} - - -class Workers(RequestHandler): - def get(self): - with log_errors(): - self.render( - "workers.html", - title="Workers", - scheduler=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), - ) - - -class Worker(RequestHandler): - def get(self, worker): - worker = escape.url_unescape(worker) - if worker not in self.server.workers: - self.send_error(404) - return - with log_errors(): - self.render( - "worker.html", - title="Worker: " + worker, - scheduler=self.server, - Worker=worker, - **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), - ) - - -class Task(RequestHandler): - def get(self, task): - task = escape.url_unescape(task) - if task not in self.server.tasks: - self.send_error(404) - return - with log_errors(): - self.render( - "task.html", - title="Task: " + task, - Task=task, - scheduler=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), - ) - - -class Logs(RequestHandler): - def get(self): - with log_errors(): - logs = self.server.get_logs() - self.render( - "logs.html", - title="Logs", - logs=logs, - **toolz.merge(self.extra, rel_path_statics), - ) - - -class WorkerLogs(RequestHandler): - async def get(self, worker): - with log_errors(): - worker = escape.url_unescape(worker) - logs = await self.server.get_worker_logs(workers=[worker]) - logs = logs[worker] - self.render( - 
"logs.html", - title="Logs: " + worker, - logs=logs, - **toolz.merge(self.extra, rel_path_statics), - ) - - -class WorkerCallStacks(RequestHandler): - async def get(self, worker): - with log_errors(): - worker = escape.url_unescape(worker) - keys = self.server.processing[worker] - call_stack = await self.server.get_call_stack(keys=keys) - self.render( - "call-stack.html", - title="Call Stacks: " + worker, - call_stack=call_stack, - **toolz.merge(self.extra, rel_path_statics), - ) - - -class TaskCallStack(RequestHandler): - async def get(self, key): - with log_errors(): - key = escape.url_unescape(key) - call_stack = await self.server.get_call_stack(keys=[key]) - if not call_stack: - self.write( - "
          Task not actively running. " - "It may be finished or not yet started
          " - ) - else: - self.render( - "call-stack.html", - title="Call Stack: " + key, - call_stack=call_stack, - **toolz.merge(self.extra, rel_path_statics), - ) - - -class CountsJSON(RequestHandler): - def get(self): - scheduler = self.server - erred = 0 - nbytes = 0 - nthreads = 0 - memory = 0 - processing = 0 - released = 0 - waiting = 0 - waiting_data = 0 - - for ts in scheduler.tasks.values(): - if ts.exception_blame is not None: - erred += 1 - elif ts.state == "released": - released += 1 - if ts.waiting_on: - waiting += 1 - if ts.waiters: - waiting_data += 1 - for ws in scheduler.workers.values(): - nthreads += ws.nthreads - memory += len(ws.has_what) - nbytes += ws.nbytes - processing += len(ws.processing) - - response = { - "bytes": nbytes, - "clients": len(scheduler.clients), - "cores": nthreads, - "erred": erred, - "hosts": len(scheduler.host_info), - "idle": len(scheduler.idle), - "memory": memory, - "processing": processing, - "released": released, - "saturated": len(scheduler.saturated), - "tasks": len(scheduler.tasks), - "unrunnable": len(scheduler.unrunnable), - "waiting": waiting, - "waiting_data": waiting_data, - "workers": len(scheduler.workers), - } - self.write(response) - - -class IdentityJSON(RequestHandler): - def get(self): - self.write(self.server.identity()) - - -class IndexJSON(RequestHandler): - def get(self): - with log_errors(): - r = [url for url, _ in routes if url.endswith(".json")] - self.render( - "json-index.html", routes=r, title="Index of JSON routes", **self.extra - ) - - -class IndividualPlots(RequestHandler): - def get(self): - bokeh_server = self.server.services["dashboard"] - result = { - uri.strip("/").replace("-", " ").title(): uri - for uri in bokeh_server.apps - if uri.lstrip("/").startswith("individual-") and not uri.endswith(".json") - } - self.write(result) - - -class _PrometheusCollector(object): - def __init__(self, server): - self.server = server - - def collect(self): - from prometheus_client.core import GaugeMetricFamily - - yield GaugeMetricFamily( - "dask_scheduler_workers", - "Number of workers connected.", - value=len(self.server.workers), - ) - yield GaugeMetricFamily( - "dask_scheduler_clients", - "Number of clients connected.", - value=len(self.server.clients), - ) - yield GaugeMetricFamily( - "dask_scheduler_received_tasks", - "Number of tasks received at scheduler", - value=len(self.server.tasks), - ) - yield GaugeMetricFamily( - "dask_scheduler_unrunnable_tasks", - "Number of unrunnable tasks at scheduler", - value=len(self.server.unrunnable), - ) - - -class PrometheusHandler(RequestHandler): - _initialized = False - - def __init__(self, *args, **kwargs): - import prometheus_client - - super(PrometheusHandler, self).__init__(*args, **kwargs) - - if PrometheusHandler._initialized: - return - - prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) - - PrometheusHandler._initialized = True - - def get(self): - import prometheus_client - - self.write(prometheus_client.generate_latest()) - self.set_header("Content-Type", "text/plain; version=0.0.4") - - -class HealthHandler(RequestHandler): - def get(self): - self.write("ok") - self.set_header("Content-Type", "text/plain") - - -routes = [ - (r"info", redirect("info/main/workers.html")), - (r"info/main/workers.html", Workers), - (r"info/worker/(.*).html", Worker), - (r"info/task/(.*).html", Task), - (r"info/main/logs.html", Logs), - (r"info/call-stacks/(.*).html", WorkerCallStacks), - (r"info/call-stack/(.*).html", TaskCallStack), - (r"info/logs/(.*).html", 
WorkerLogs), - (r"json/counts.json", CountsJSON), - (r"json/identity.json", IdentityJSON), - (r"json/index.html", IndexJSON), - (r"individual-plots.json", IndividualPlots), - (r"metrics", PrometheusHandler), - (r"health", HealthHandler), - (r"proxy/(\d+)/(.*?)/(.*)", GlobalProxyHandler), -] - - -def get_handlers(server): - return [(url, cls, {"server": server}) for url, cls in routes] diff --git a/distributed/dashboard/tests/test_components.py b/distributed/dashboard/tests/test_components.py index 5e96d788e45..195c947bdec 100644 --- a/distributed/dashboard/tests/test_components.py +++ b/distributed/dashboard/tests/test_components.py @@ -6,9 +6,7 @@ from tornado import gen from distributed.utils_test import slowinc, gen_cluster - -from distributed.dashboard.components import ( - TaskStream, +from distributed.dashboard.components.shared import ( MemoryUsage, Processing, ProfilePlot, @@ -16,7 +14,7 @@ ) -@pytest.mark.parametrize("Component", [TaskStream, MemoryUsage, Processing]) +@pytest.mark.parametrize("Component", [MemoryUsage, Processing]) def test_basic(Component): c = Component() assert isinstance(c.source, ColumnDataSource) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 1e48a3addec..875f1064503 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -16,10 +16,10 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec, slowinc, div, get_cert -from distributed.dashboard.worker import Counters, BokehWorker -from distributed.dashboard.scheduler import ( - applications, - BokehScheduler, +from distributed.dashboard.worker import BokehWorker +from distributed.dashboard.components.worker import Counters +from distributed.dashboard.scheduler import applications, BokehScheduler +from distributed.dashboard.components.scheduler import ( SystemMonitor, Occupancy, StealingTimeSeries, @@ -32,7 +32,7 @@ ProcessingHistogram, NBytesHistogram, WorkerTable, - GraphPlot, + TaskGraph, ProfileServer, ) @@ -70,7 +70,7 @@ def test_simple(c, s, a, b): @gen_cluster(client=True, worker_kwargs=dict(services={"dashboard": BokehWorker})) def test_basic(c, s, a, b): - for component in [SystemMonitor, Occupancy, StealingTimeSeries]: + for component in [TaskStream, SystemMonitor, Occupancy, StealingTimeSeries]: ss = component(s) ss.update() @@ -443,8 +443,8 @@ def metric(worker): @gen_cluster(client=True) -def test_GraphPlot(c, s, a, b): - gp = GraphPlot(s) +def test_TaskGraph(c, s, a, b): + gp = TaskGraph(s) futures = c.map(inc, range(5)) total = c.submit(sum, futures) yield total @@ -483,8 +483,8 @@ def test_GraphPlot(c, s, a, b): @gen_cluster(client=True) -def test_GraphPlot_clear(c, s, a, b): - gp = GraphPlot(s) +def test_TaskGraph_clear(c, s, a, b): + gp = TaskGraph(s) futures = c.map(inc, range(5)) total = c.submit(sum, futures) yield total @@ -507,9 +507,9 @@ def test_GraphPlot_clear(c, s, a, b): @gen_cluster(client=True, timeout=30) -def test_GraphPlot_complex(c, s, a, b): +def test_TaskGraph_complex(c, s, a, b): da = pytest.importorskip("dask.array") - gp = GraphPlot(s) + gp = TaskGraph(s) x = da.random.random((2000, 2000), chunks=(1000, 1000)) y = ((x + x.T) - x.mean(axis=0)).persist() yield wait(y) @@ -538,12 +538,12 @@ def test_GraphPlot_complex(c, s, a, b): @gen_cluster(client=True) -def test_GraphPlot_order(c, s, a, b): +def test_TaskGraph_order(c, s, a, b): x = c.submit(inc, 1) y 
= c.submit(div, 1, 0) yield wait(y) - gp = GraphPlot(s) + gp = TaskGraph(s) gp.update() assert gp.node_source.data["state"][gp.layout.index[y.key]] == "erred" diff --git a/distributed/dashboard/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py index c490c825ab4..b33fc3ba185 100644 --- a/distributed/dashboard/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -14,8 +14,8 @@ from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec from distributed.dashboard.scheduler import BokehScheduler -from distributed.dashboard.worker import ( - BokehWorker, +from distributed.dashboard.worker import BokehWorker +from distributed.dashboard.components.worker import ( StateTable, CrossFilter, CommunicatingStream, diff --git a/distributed/dashboard/utils.py b/distributed/dashboard/utils.py index 285f6a5772a..b47cb75d6b0 100644 --- a/distributed/dashboard/utils.py +++ b/distributed/dashboard/utils.py @@ -1,14 +1,30 @@ from distutils.version import LooseVersion import os +from numbers import Number import bokeh +from bokeh.io import curdoc from tornado import web from toolz import partition +try: + import numpy as np +except ImportError: + np = False + + +try: + from cytoolz.curried import first +except ImportError: + from toolz.curried import first + BOKEH_VERSION = LooseVersion(bokeh.__version__) dirname = os.path.dirname(__file__) +PROFILING = False + + if BOKEH_VERSION >= "1.0.0": # This decorator is only available in bokeh >= 1.0.0, and doesn't work for # callbacks in Python 2, since the signature introspection won't line up. @@ -48,3 +64,33 @@ def get(self): self.redirect(path) return Redirect + + +@without_property_validation +def update(source, data): + """ Update source with data + + This checks a few things first + + 1. If the data is the same, then don't update + 2. If numpy is available and the data is numeric, then convert to numpy + arrays + 3. 
If profiling then perform the update in another callback + """ + if not np or not any(isinstance(v, np.ndarray) for v in source.data.values()): + if source.data == data: + return + if np and len(data[first(data)]) > 10: + d = {} + for k, v in data.items(): + if type(v) is not np.ndarray and isinstance(v[0], Number): + d[k] = np.array(v) + else: + d[k] = v + else: + d = data + + if PROFILING: + curdoc().add_next_tick_callback(lambda: source.data.update(d)) + else: + source.data.update(d) diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index 402d3fd0a70..4d635388512 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -1,40 +1,20 @@ from functools import partial import logging -import math import os -from bokeh.layouts import row, column, widgetbox -from bokeh.models import ( - ColumnDataSource, - DataRange1d, - HoverTool, - BoxZoomTool, - ResetTool, - PanTool, - WheelZoomTool, - NumeralTickFormatter, - Select, -) - -from bokeh.models.widgets import DataTable, TableColumn -from bokeh.plotting import figure -from bokeh.palettes import RdBu from bokeh.themes import Theme -from dask.utils import format_bytes -from toolz import merge, partition_all - -from .components import ( - DashboardComponent, - ProfileTimePlot, - ProfileServer, - add_periodic_callback, +from toolz import merge + +from .components.worker import ( + status_doc, + crossfilter_doc, + systemmonitor_doc, + counters_doc, + profile_doc, + profile_server_doc, ) from .core import BokehServer -from .utils import transpose, without_property_validation -from ..compatibility import WINDOWS -from ..diagnostics.progress_stream import color_of -from ..metrics import time -from ..utils import log_errors, key_split, format_time +from .utils import RequestHandler, redirect logger = logging.getLogger(__name__) @@ -53,708 +33,110 @@ template_variables = {"pages": ["status", "system", "profile", "crossfilter"]} -class StateTable(DashboardComponent): - """ Currently running tasks """ - - def __init__(self, worker): - self.worker = worker - - names = ["Stored", "Executing", "Ready", "Waiting", "Connections", "Serving"] - self.source = ColumnDataSource({name: [] for name in names}) - - columns = {name: TableColumn(field=name, title=name) for name in names} - - table = DataTable( - source=self.source, columns=[columns[n] for n in names], height=70 - ) - self.root = table - - @without_property_validation - def update(self): - with log_errors(): - w = self.worker - d = { - "Stored": [len(w.data)], - "Executing": ["%d / %d" % (len(w.executing), w.nthreads)], - "Ready": [len(w.ready)], - "Waiting": [len(w.waiting_for_data)], - "Connections": [len(w.in_flight_workers)], - "Serving": [len(w._comms)], - } - self.source.data.update(d) - - -class CommunicatingStream(DashboardComponent): - def __init__(self, worker, height=300, **kwargs): - with log_errors(): - self.worker = worker - names = [ - "start", - "stop", - "middle", - "duration", - "who", - "y", - "hover", - "alpha", - "bandwidth", - "total", - ] - - self.incoming = ColumnDataSource({name: [] for name in names}) - self.outgoing = ColumnDataSource({name: [] for name in names}) - - x_range = DataRange1d(range_padding=0) - y_range = DataRange1d(range_padding=0) - - fig = figure( - title="Peer Communications", - x_axis_type="datetime", - x_range=x_range, - y_range=y_range, - height=height, - tools="", - **kwargs - ) - - fig.rect( - source=self.incoming, - x="middle", - y="y", - width="duration", - height=0.9, - color="red", - alpha="alpha", 
- ) - fig.rect( - source=self.outgoing, - x="middle", - y="y", - width="duration", - height=0.9, - color="blue", - alpha="alpha", +class _PrometheusCollector(object): + def __init__(self, server): + self.worker = server + self.logger = logging.getLogger("distributed.dask_worker") + self.crick_available = True + try: + import crick # noqa: F401 + except ImportError: + self.crick_available = False + self.logger.info( + "Not all prometheus metrics available are exported. Digest-based metrics require crick to be installed" ) - hover = HoverTool(point_policy="follow_mouse", tooltips="""@hover""") - fig.add_tools( - hover, - ResetTool(), - PanTool(dimensions="width"), - WheelZoomTool(dimensions="width"), - ) + def collect(self): + from prometheus_client.core import GaugeMetricFamily - self.root = fig - - self.last_incoming = 0 - self.last_outgoing = 0 - self.who = dict() - - @without_property_validation - def update(self): - with log_errors(): - outgoing = self.worker.outgoing_transfer_log - n = self.worker.outgoing_count - self.last_outgoing - outgoing = [outgoing[-i].copy() for i in range(1, n + 1)] - self.last_outgoing = self.worker.outgoing_count - - incoming = self.worker.incoming_transfer_log - n = self.worker.incoming_count - self.last_incoming - incoming = [incoming[-i].copy() for i in range(1, n + 1)] - self.last_incoming = self.worker.incoming_count - - for [msgs, source] in [ - [incoming, self.incoming], - [outgoing, self.outgoing], - ]: - - for msg in msgs: - if "compressed" in msg: - del msg["compressed"] - del msg["keys"] - - bandwidth = msg["total"] / (msg["duration"] or 0.5) - bw = max(min(bandwidth / 500e6, 1), 0.3) - msg["alpha"] = bw - try: - msg["y"] = self.who[msg["who"]] - except KeyError: - self.who[msg["who"]] = len(self.who) - msg["y"] = self.who[msg["who"]] - - msg["hover"] = "%s / %s = %s/s" % ( - format_bytes(msg["total"]), - format_time(msg["duration"]), - format_bytes(msg["total"] / msg["duration"]), - ) - - for k in ["middle", "duration", "start", "stop"]: - msg[k] = msg[k] * 1000 - - if msgs: - msgs = transpose(msgs) - if ( - len(source.data["stop"]) - and min(msgs["start"]) > source.data["stop"][-1] + 10000 - ): - source.data.update(msgs) - else: - source.stream(msgs, rollover=10000) - - -class CommunicatingTimeSeries(DashboardComponent): - def __init__(self, worker, **kwargs): - self.worker = worker - self.source = ColumnDataSource({"x": [], "in": [], "out": []}) - - x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) - - fig = figure( - title="Communication History", - x_axis_type="datetime", - y_range=[-0.1, worker.total_out_connections + 0.5], - height=150, - tools="", - x_range=x_range, - **kwargs + tasks = GaugeMetricFamily( + "dask_worker_tasks", "Number of tasks at worker.", labels=["state"] ) - fig.line(source=self.source, x="x", y="in", color="red") - fig.line(source=self.source, x="x", y="out", color="blue") - - fig.add_tools( - ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") + tasks.add_metric(["stored"], len(self.worker.data)) + tasks.add_metric(["ready"], len(self.worker.ready)) + tasks.add_metric(["waiting"], len(self.worker.waiting_for_data)) + tasks.add_metric(["serving"], len(self.worker._comms)) + yield tasks + + yield GaugeMetricFamily( + "dask_worker_connections", + "Number of task connections to other workers.", + value=len(self.worker.in_flight_workers), ) - self.root = fig - - @without_property_validation - def update(self): - with log_errors(): - self.source.stream( - { - "x": [time() * 
1000], - "out": [len(self.worker._comms)], - "in": [len(self.worker.in_flight_workers)], - }, - 10000, - ) - - -class ExecutingTimeSeries(DashboardComponent): - def __init__(self, worker, **kwargs): - self.worker = worker - self.source = ColumnDataSource({"x": [], "y": []}) - - x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) - - fig = figure( - title="Executing History", - x_axis_type="datetime", - y_range=[-0.1, worker.nthreads + 0.1], - height=150, - tools="", - x_range=x_range, - **kwargs + yield GaugeMetricFamily( + "dask_worker_threads", + "Number of worker threads.", + value=self.worker.nthreads, ) - fig.line(source=self.source, x="x", y="y") - fig.add_tools( - ResetTool(), PanTool(dimensions="width"), WheelZoomTool(dimensions="width") + yield GaugeMetricFamily( + "dask_worker_latency_seconds", + "Latency of worker connection.", + value=self.worker.latency, ) - self.root = fig - - @without_property_validation - def update(self): - with log_errors(): - self.source.stream( - {"x": [time() * 1000], "y": [len(self.worker.executing)]}, 1000 + # all metrics using digests require crick to be installed + # the following metrics will export NaN, if the corresponding digests are None + if self.crick_available: + yield GaugeMetricFamily( + "dask_worker_tick_duration_median_seconds", + "Median tick duration at worker.", + value=self.worker.digests["tick-duration"].components[1].quantile(50), ) - -class CrossFilter(DashboardComponent): - def __init__(self, worker, **kwargs): - with log_errors(): - self.worker = worker - - quantities = ["nbytes", "duration", "bandwidth", "count", "start", "stop"] - colors = ["inout-color", "type-color", "key-color"] - - # self.source = ColumnDataSource({name: [] for name in names}) - self.source = ColumnDataSource( - { - "nbytes": [1, 2], - "duration": [0.01, 0.02], - "bandwidth": [0.01, 0.02], - "count": [1, 2], - "type": ["int", "str"], - "inout-color": ["blue", "red"], - "type-color": ["blue", "red"], - "key": ["add", "inc"], - "start": [1, 2], - "stop": [1, 2], - } + yield GaugeMetricFamily( + "dask_worker_task_duration_median_seconds", + "Median task runtime at worker.", + value=self.worker.digests["task-duration"].components[1].quantile(50), ) - self.x = Select(title="X-Axis", value="nbytes", options=quantities) - self.x.on_change("value", self.update_figure) - - self.y = Select(title="Y-Axis", value="bandwidth", options=quantities) - self.y.on_change("value", self.update_figure) - - self.size = Select( - title="Size", value="None", options=["None"] + quantities + yield GaugeMetricFamily( + "dask_worker_transfer_bandwidth_median_bytes", + "Bandwidth for transfer at worker in Bytes.", + value=self.worker.digests["transfer-bandwidth"] + .components[1] + .quantile(50), ) - self.size.on_change("value", self.update_figure) - self.color = Select( - title="Color", value="inout-color", options=["black"] + colors - ) - self.color.on_change("value", self.update_figure) - if "sizing_mode" in kwargs: - kw = {"sizing_mode": kwargs["sizing_mode"]} - else: - kw = {} +class PrometheusHandler(RequestHandler): + _initialized = False - self.control = widgetbox( - [self.x, self.y, self.size, self.color], width=200, **kw - ) + def __init__(self, *args, **kwargs): + import prometheus_client - self.last_outgoing = 0 - self.last_incoming = 0 - self.kwargs = kwargs - - self.layout = row(self.control, self.create_figure(**self.kwargs), **kw) - - self.root = self.layout - - @without_property_validation - def update(self): - with log_errors(): - outgoing = 
self.worker.outgoing_transfer_log - n = self.worker.outgoing_count - self.last_outgoing - n = min(n, 1000) - outgoing = [outgoing[-i].copy() for i in range(1, n)] - self.last_outgoing = self.worker.outgoing_count - - incoming = self.worker.incoming_transfer_log - n = self.worker.incoming_count - self.last_incoming - n = min(n, 1000) - incoming = [incoming[-i].copy() for i in range(1, n)] - self.last_incoming = self.worker.incoming_count - - out = [] - - for msg in incoming: - if msg["keys"]: - d = self.process_msg(msg) - d["inout-color"] = "red" - out.append(d) - - for msg in outgoing: - if msg["keys"]: - d = self.process_msg(msg) - d["inout-color"] = "blue" - out.append(d) - - if out: - out = transpose(out) - if ( - len(self.source.data["stop"]) - and min(out["start"]) > self.source.data["stop"][-1] + 10 - ): - self.source.data.update(out) - else: - self.source.stream(out, rollover=1000) - - def create_figure(self, **kwargs): - with log_errors(): - fig = figure(title="", tools="", **kwargs) - - size = self.size.value - if size == "None": - size = 1 - - fig.circle( - source=self.source, - x=self.x.value, - y=self.y.value, - color=self.color.value, - size=10, - alpha=0.5, - hover_alpha=1, - ) - fig.xaxis.axis_label = self.x.value - fig.yaxis.axis_label = self.y.value - - fig.add_tools( - # self.hover, - ResetTool(), - PanTool(), - WheelZoomTool(), - BoxZoomTool(), - ) - return fig + super(PrometheusHandler, self).__init__(*args, **kwargs) - @without_property_validation - def update_figure(self, attr, old, new): - with log_errors(): - fig = self.create_figure(**self.kwargs) - self.layout.children[1] = fig + if PrometheusHandler._initialized: + return - def process_msg(self, msg): - try: + prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) - def func(k): - return msg["keys"].get(k, 0) - - status_key = max(msg["keys"], key=func) - typ = self.worker.types.get(status_key, object).__name__ - keyname = key_split(status_key) - d = { - "nbytes": msg["total"], - "duration": msg["duration"], - "bandwidth": msg["bandwidth"], - "count": len(msg["keys"]), - "type": typ, - "type-color": color_of(typ), - "key": keyname, - "key-color": color_of(keyname), - "start": msg["start"], - "stop": msg["stop"], - } - return d - except Exception as e: - logger.exception(e) - raise - - -class SystemMonitor(DashboardComponent): - def __init__(self, worker, height=150, **kwargs): - self.worker = worker + PrometheusHandler._initialized = True - names = worker.monitor.quantities - self.last = 0 - self.source = ColumnDataSource({name: [] for name in names}) - self.source.data.update(self.get_data()) + def get(self): + import prometheus_client - x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + self.write(prometheus_client.generate_latest()) + self.set_header("Content-Type", "text/plain; version=0.0.4") - tools = "reset,xpan,xwheel_zoom" - self.cpu = figure( - title="CPU", - x_axis_type="datetime", - height=height, - tools=tools, - x_range=x_range, - **kwargs - ) - self.cpu.line(source=self.source, x="time", y="cpu") - self.cpu.yaxis.axis_label = "Percentage" - self.mem = figure( - title="Memory", - x_axis_type="datetime", - height=height, - tools=tools, - x_range=x_range, - **kwargs - ) - self.mem.line(source=self.source, x="time", y="memory") - self.mem.yaxis.axis_label = "Bytes" - self.bandwidth = figure( - title="Bandwidth", - x_axis_type="datetime", - height=height, - x_range=x_range, - tools=tools, - **kwargs - ) - self.bandwidth.line(source=self.source, x="time", 
y="read_bytes", color="red") - self.bandwidth.line(source=self.source, x="time", y="write_bytes", color="blue") - self.bandwidth.yaxis.axis_label = "Bytes / second" - - # self.cpu.yaxis[0].formatter = NumeralTickFormatter(format='0%') - self.bandwidth.yaxis[0].formatter = NumeralTickFormatter(format="0.0b") - self.mem.yaxis[0].formatter = NumeralTickFormatter(format="0.0b") - - plots = [self.cpu, self.mem, self.bandwidth] - - if not WINDOWS: - self.num_fds = figure( - title="Number of File Descriptors", - x_axis_type="datetime", - height=height, - x_range=x_range, - tools=tools, - **kwargs - ) - - self.num_fds.line(source=self.source, x="time", y="num_fds") - plots.append(self.num_fds) - - if "sizing_mode" in kwargs: - kw = {"sizing_mode": kwargs["sizing_mode"]} - else: - kw = {} - - if not WINDOWS: - self.num_fds.y_range.start = 0 - self.mem.y_range.start = 0 - self.cpu.y_range.start = 0 - self.bandwidth.y_range.start = 0 - - self.root = column(*plots, **kw) - self.worker.monitor.update() - - def get_data(self): - d = self.worker.monitor.range_query(start=self.last) - d["time"] = [x * 1000 for x in d["time"]] - self.last = self.worker.monitor.count - return d - - @without_property_validation - def update(self): - with log_errors(): - self.source.stream(self.get_data(), 1000) - - -class Counters(DashboardComponent): - def __init__(self, server, sizing_mode="stretch_both", **kwargs): - self.server = server - self.counter_figures = {} - self.counter_sources = {} - self.digest_figures = {} - self.digest_sources = {} - self.sizing_mode = sizing_mode - - if self.server.digests: - for name in self.server.digests: - self.add_digest_figure(name) - for name in self.server.counters: - self.add_counter_figure(name) - - figures = merge(self.digest_figures, self.counter_figures) - figures = [figures[k] for k in sorted(figures)] - - if len(figures) <= 5: - self.root = column(figures, sizing_mode=sizing_mode) - else: - self.root = column( - *[ - row(*pair, sizing_mode=sizing_mode) - for pair in partition_all(2, figures) - ], - sizing_mode=sizing_mode - ) - - def add_digest_figure(self, name): - with log_errors(): - n = len(self.server.digests[name].intervals) - sources = {i: ColumnDataSource({"x": [], "y": []}) for i in range(n)} - - kwargs = {} - if name.endswith("duration"): - kwargs["x_axis_type"] = "datetime" - - fig = figure( - title=name, tools="", height=150, sizing_mode=self.sizing_mode, **kwargs - ) - fig.yaxis.visible = False - fig.ygrid.visible = False - if name.endswith("bandwidth") or name.endswith("bytes"): - fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0b") - - for i in range(n): - alpha = 0.3 + 0.3 * (n - i) / n - fig.line( - source=sources[i], - x="x", - y="y", - alpha=alpha, - color=RdBu[max(n, 3)][-i], - ) - - fig.xaxis.major_label_orientation = math.pi / 12 - fig.toolbar.logo = None - self.digest_sources[name] = sources - self.digest_figures[name] = fig - return fig - - def add_counter_figure(self, name): - with log_errors(): - n = len(self.server.counters[name].intervals) - sources = { - i: ColumnDataSource({"x": [], "y": [], "y-center": [], "counts": []}) - for i in range(n) - } - - fig = figure( - title=name, - tools="", - height=150, - sizing_mode=self.sizing_mode, - x_range=sorted(map(str, self.server.counters[name].components[0])), - ) - fig.ygrid.visible = False - - for i in range(n): - width = 0.5 + 0.4 * i / n - fig.rect( - source=sources[i], - x="x", - y="y-center", - width=width, - height="y", - alpha=0.3, - color=RdBu[max(n, 3)][-i], - ) - hover = HoverTool( - 
point_policy="follow_mouse", tooltips="""@x : @counts""" - ) - fig.add_tools(hover) - fig.xaxis.major_label_orientation = math.pi / 12 - - fig.toolbar.logo = None - - self.counter_sources[name] = sources - self.counter_figures[name] = fig - return fig - - @without_property_validation - def update(self): - with log_errors(): - for name, fig in self.digest_figures.items(): - digest = self.server.digests[name] - d = {} - for i, d in enumerate(digest.components): - if d.size(): - ys, xs = d.histogram(100) - xs = xs[1:] - if name.endswith("duration"): - xs *= 1000 - self.digest_sources[name][i].data.update({"x": xs, "y": ys}) - fig.title.text = "%s: %d" % (name, digest.size()) - - for name, fig in self.counter_figures.items(): - counter = self.server.counters[name] - d = {} - for i, d in enumerate(counter.components): - if d: - xs = sorted(d) - factor = counter.intervals[0] / counter.intervals[i] - counts = [d[x] for x in xs] - ys = [factor * c for c in counts] - y_centers = [y / 2 for y in ys] - xs = list(map(str, xs)) - d = {"x": xs, "y": ys, "y-center": y_centers, "counts": counts} - self.counter_sources[name][i].data.update(d) - fig.title.text = "%s: %d" % (name, counter.size()) - fig.x_range.factors = list(map(str, xs)) - - -from bokeh.application.handlers.function import FunctionHandler -from bokeh.application import Application - - -def status_doc(worker, extra, doc): - with log_errors(): - statetable = StateTable(worker) - executing_ts = ExecutingTimeSeries(worker, sizing_mode="scale_width") - communicating_ts = CommunicatingTimeSeries(worker, sizing_mode="scale_width") - communicating_stream = CommunicatingStream(worker, sizing_mode="scale_width") - - xr = executing_ts.root.x_range - communicating_ts.root.x_range = xr - communicating_stream.root.x_range = xr - - doc.title = "Dask Worker Internal Monitor" - add_periodic_callback(doc, statetable, 200) - add_periodic_callback(doc, executing_ts, 200) - add_periodic_callback(doc, communicating_ts, 200) - add_periodic_callback(doc, communicating_stream, 200) - doc.add_root( - column( - statetable.root, - executing_ts.root, - communicating_ts.root, - communicating_stream.root, - sizing_mode="scale_width", - ) - ) - doc.template = env.get_template("simple.html") - doc.template_variables["active_page"] = "status" - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def crossfilter_doc(worker, extra, doc): - with log_errors(): - statetable = StateTable(worker) - crossfilter = CrossFilter(worker) - - doc.title = "Dask Worker Cross-filter" - add_periodic_callback(doc, statetable, 500) - add_periodic_callback(doc, crossfilter, 500) - - doc.add_root(column(statetable.root, crossfilter.root)) - doc.template = env.get_template("simple.html") - doc.template_variables["active_page"] = "crossfilter" - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME +class HealthHandler(RequestHandler): + def get(self): + self.write("ok") + self.set_header("Content-Type", "text/plain") -def systemmonitor_doc(worker, extra, doc): - with log_errors(): - sysmon = SystemMonitor(worker, sizing_mode="scale_width") - doc.title = "Dask Worker Monitor" - add_periodic_callback(doc, sysmon, 500) +routes = [ + (r"metrics", PrometheusHandler), + (r"health", HealthHandler), + (r"main", redirect("/status")), +] - doc.add_root(sysmon.root) - doc.template = env.get_template("simple.html") - doc.template_variables["active_page"] = "system" - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - -def counters_doc(server, extra, doc): - with 
log_errors(): - doc.title = "Dask Worker Counters" - counter = Counters(server, sizing_mode="stretch_both") - add_periodic_callback(doc, counter, 500) - - doc.add_root(counter.root) - doc.template = env.get_template("simple.html") - doc.template_variables["active_page"] = "counters" - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def profile_doc(server, extra, doc): - with log_errors(): - doc.title = "Dask Worker Profile" - profile = ProfileTimePlot(server, sizing_mode="scale_width", doc=doc) - profile.trigger_update() - - doc.add_root(profile.root) - doc.template = env.get_template("simple.html") - doc.template_variables["active_page"] = "profile" - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - -def profile_server_doc(server, extra, doc): - with log_errors(): - doc.title = "Dask: Profile of Event Loop" - prof = ProfileServer(server, sizing_mode="scale_width", doc=doc) - doc.add_root(prof.root) - doc.template = env.get_template("simple.html") - # doc.template_variables['active_page'] = '' - doc.template_variables.update(extra) - doc.theme = BOKEH_THEME - - prof.trigger_update() +def get_handlers(server): + return [(url, cls, {"server": server}) for url, cls in routes] class BokehWorker(BokehServer): @@ -768,31 +150,15 @@ def __init__(self, worker, io_loop=None, prefix="", **kwargs): prefix = "/" + prefix self.prefix = prefix - extra = {"prefix": prefix} - - extra.update(template_variables) - - status = Application(FunctionHandler(partial(status_doc, worker, extra))) - crossfilter = Application( - FunctionHandler(partial(crossfilter_doc, worker, extra)) - ) - systemmonitor = Application( - FunctionHandler(partial(systemmonitor_doc, worker, extra)) - ) - counters = Application(FunctionHandler(partial(counters_doc, worker, extra))) - profile = Application(FunctionHandler(partial(profile_doc, worker, extra))) - profile_server = Application( - FunctionHandler(partial(profile_server_doc, worker, extra)) - ) - self.apps = { - "/status": status, - "/counters": counters, - "/crossfilter": crossfilter, - "/system": systemmonitor, - "/profile": profile, - "/profile-server": profile_server, + "/status": status_doc, + "/counters": counters_doc, + "/crossfilter": crossfilter_doc, + "/system": systemmonitor_doc, + "/profile": profile_doc, + "/profile-server": profile_server_doc, } + self.apps = {k: partial(v, worker, self.extra) for k, v in self.apps.items()} self.loop = io_loop or worker.loop self.server = None @@ -808,8 +174,6 @@ def my_server(self): def listen(self, *args, **kwargs): super(BokehWorker, self).listen(*args, **kwargs) - from .worker_html import routes - handlers = [ ( self.prefix + "/" + url, diff --git a/distributed/dashboard/worker_html.py b/distributed/dashboard/worker_html.py deleted file mode 100644 index 27e1f9fe9d2..00000000000 --- a/distributed/dashboard/worker_html.py +++ /dev/null @@ -1,108 +0,0 @@ -import logging -from .utils import RequestHandler, redirect - - -class _PrometheusCollector(object): - def __init__(self, server): - self.worker = server - self.logger = logging.getLogger("distributed.dask_worker") - self.crick_available = True - try: - import crick # noqa: F401 - except ImportError: - self.crick_available = False - self.logger.info( - "Not all prometheus metrics available are exported. 
Digest-based metrics require crick to be installed" - ) - - def collect(self): - from prometheus_client.core import GaugeMetricFamily - - tasks = GaugeMetricFamily( - "dask_worker_tasks", "Number of tasks at worker.", labels=["state"] - ) - tasks.add_metric(["stored"], len(self.worker.data)) - tasks.add_metric(["ready"], len(self.worker.ready)) - tasks.add_metric(["waiting"], len(self.worker.waiting_for_data)) - tasks.add_metric(["serving"], len(self.worker._comms)) - yield tasks - - yield GaugeMetricFamily( - "dask_worker_connections", - "Number of task connections to other workers.", - value=len(self.worker.in_flight_workers), - ) - - yield GaugeMetricFamily( - "dask_worker_threads", - "Number of worker threads.", - value=self.worker.nthreads, - ) - - yield GaugeMetricFamily( - "dask_worker_latency_seconds", - "Latency of worker connection.", - value=self.worker.latency, - ) - - # all metrics using digests require crick to be installed - # the following metrics will export NaN, if the corresponding digests are None - if self.crick_available: - yield GaugeMetricFamily( - "dask_worker_tick_duration_median_seconds", - "Median tick duration at worker.", - value=self.worker.digests["tick-duration"].components[1].quantile(50), - ) - - yield GaugeMetricFamily( - "dask_worker_task_duration_median_seconds", - "Median task runtime at worker.", - value=self.worker.digests["task-duration"].components[1].quantile(50), - ) - - yield GaugeMetricFamily( - "dask_worker_transfer_bandwidth_median_bytes", - "Bandwidth for transfer at worker in Bytes.", - value=self.worker.digests["transfer-bandwidth"] - .components[1] - .quantile(50), - ) - - -class PrometheusHandler(RequestHandler): - _initialized = False - - def __init__(self, *args, **kwargs): - import prometheus_client - - super(PrometheusHandler, self).__init__(*args, **kwargs) - - if PrometheusHandler._initialized: - return - - prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) - - PrometheusHandler._initialized = True - - def get(self): - import prometheus_client - - self.write(prometheus_client.generate_latest()) - self.set_header("Content-Type", "text/plain; version=0.0.4") - - -class HealthHandler(RequestHandler): - def get(self): - self.write("ok") - self.set_header("Content-Type", "text/plain") - - -routes = [ - (r"metrics", PrometheusHandler), - (r"health", HealthHandler), - (r"main", redirect("/status")), -] - - -def get_handlers(server): - return [(url, cls, {"server": server}) for url, cls in routes] diff --git a/distributed/diagnostics/graph_layout.py b/distributed/diagnostics/graph_layout.py index c81c6edcafe..a348d2e04ee 100644 --- a/distributed/diagnostics/graph_layout.py +++ b/distributed/diagnostics/graph_layout.py @@ -7,7 +7,7 @@ class GraphLayout(SchedulerPlugin): This assigns (x, y) locations to all tasks quickly and dynamically as new tasks are added. This scales to a few thousand nodes. - It is commonly used with distributed/bokeh/scheduler.py::GraphPlot, which + It is commonly used with distributed/bokeh/scheduler.py::TaskGraph, which is rendered at /graph on the diagnostic dashboard. """ @@ -113,7 +113,7 @@ def transition(self, key, start, finish, *args, **kwargs): def reset_index(self): """ Reset the index and refill new and new_edges - From time to time GraphPlot wants to remove invisible nodes and reset + From time to time TaskGraph wants to remove invisible nodes and reset all of its indices. This helps. 
""" self.new = [] From 159e6c2eba15b2a6bba9cabccbc994bcb81a8be7 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 15 Oct 2019 21:01:54 -0500 Subject: [PATCH 0506/1550] bump version to 2.6.0 --- docs/source/changelog.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 5b6288885fc..41496953f66 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,21 @@ Changelog ========= +2.6.0 - 2019-10-15 +------------------ + +- Refactor dashboard module (:pr:`3138`) `Jacob Tomlinson`_ +- Use ``setuptools.find_packages`` in ``setup.py`` (:pr:`3150`) `Matthew Rocklin`_ +- Move death timeout logic up to ``Node.start`` (:pr:`3115`) `Matthew Rocklin`_ +- Only include metric in ``WorkerTable`` if it is a scalar (:pr:`3140`) `Matthew Rocklin`_ +- Add ``Nanny(config={...})`` keyword (:pr:`3134`) `Matthew Rocklin`_ +- Xfail ``test_worksapce_concurrency`` on Python 3.6 (:pr:`3132`) `Matthew Rocklin`_ +- Extend Worker plugin API with transition method (:pr:`2994`) `matthieubulte`_ +- Raise exception if the user passes in unused keywords to ``Client`` (:pr:`3117`) `Jonathan De Troye`_ +- Move new ``SSHCluster`` to top level (:pr:`3128`) `Matthew Rocklin`_ +- Bump dask dependency (:pr:`3124`) `Jim Crist`_ + + 2.5.2 - 2019-10-04 ------------------ @@ -1315,3 +1330,5 @@ significantly without many new features. .. _`Arpit Solanki`: https://github.com/arpit1997 .. _`Gil Forsyth`: https://github.com/gforsyth .. _`Philipp Rudiger`: https://github.com/philippjfr +.. _`Jonathan De Troye`: https://github.com/detroyejr +.. _`matthieubulte`: https://github.com/matthieubulte From e7a2e6d41e0b719866769713d8f41cb5fcfbf6e8 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 16 Oct 2019 10:59:43 -0500 Subject: [PATCH 0507/1550] Adds badges to README.rst [skip ci] (#3152) --- README.rst | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index b6f0edd604f..3d9c02915fc 100644 --- a/README.rst +++ b/README.rst @@ -1,7 +1,20 @@ Distributed =========== -A library for distributed computation. See documentation_ for more details. +|Build Status| |Doc Status| |Gitter| |Version Status| |NumFOCUS| +A library for distributed computation. See documentation_ for more details. -.. _documentation: https://distributed.readthedocs.io/en/latest +.. _documentation: https://distributed.dask.org +.. |Build Status| image:: https://travis-ci.org/dask/distributed.svg?branch=master + :target: https://travis-ci.org/dask/distributed +.. |Doc Status| image:: https://readthedocs.org/projects/distributed/badge/?version=latest + :target: https://distributed.dask.org + :alt: Documentation Status +.. |Gitter| image:: https://badges.gitter.im/Join%20Chat.svg + :alt: Join the chat at https://gitter.im/dask/dask + :target: https://gitter.im/dask/dask?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge +.. |Version Status| image:: https://img.shields.io/pypi/v/distributed.svg + :target: https://pypi.python.org/pypi/distributed/ +.. 
|NumFOCUS| image:: https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A + :target: https://www.numfocus.org/ From 8261e93dd98d233c8a5b262f4389b365c42697ee Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 17 Oct 2019 08:44:07 -0400 Subject: [PATCH 0508/1550] Don't overwrite `self.address` if it is present (#3153) Testing out the new `SSHCluster` I was unable to connect to an IP that turned out to be `None` -- the call to `super().__init__()` was overwriting `self.address` with `None` by default. Quick one-line fix to check if the attribute has already been declared in a child class. --- distributed/deploy/spec.py | 2 +- distributed/tests/test_spec.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 distributed/tests/test_spec.py diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 11ff3b44322..72cae01e85c 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -29,7 +29,7 @@ class ProcessInterface: """ def __init__(self, scheduler=None, name=None): - self.address = None + self.address = getattr(self, "address", None) self.external_address = None self.lock = asyncio.Lock() self.status = "created" diff --git a/distributed/tests/test_spec.py b/distributed/tests/test_spec.py new file mode 100644 index 00000000000..38719661fc9 --- /dev/null +++ b/distributed/tests/test_spec.py @@ -0,0 +1,18 @@ +from distributed.deploy.spec import ProcessInterface + + +def test_address_default_none(): + p = ProcessInterface() + assert p.address is None + + +def test_child_address_persists(): + class Child(ProcessInterface): + def __init__(self, address=None): + self.address = address + super().__init__() + + c = Child() + assert c.address is None + c = Child("localhost") + assert c.address == "localhost" From 4a7d16cf4ee6f528bbd1fd79dfbf550cdb3792f4 Mon Sep 17 00:00:00 2001 From: darindf Date: Mon, 21 Oct 2019 13:22:54 -0700 Subject: [PATCH 0509/1550] Removed outdated references to debug scheduler and worker bokeh pages. (#3160) --- docs/source/web.rst | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/docs/source/web.rst b/docs/source/web.rst index c73838c13dd..cfef4902c35 100644 --- a/docs/source/web.rst +++ b/docs/source/web.rst @@ -19,31 +19,12 @@ a normal web page in real time. This web interface is launched by default wherever the scheduler is launched if the scheduler machine has Bokeh_ installed (``conda install bokeh -c bokeh``). -List of Servers ---------------- - -There are a few sets of diagnostic pages served at different ports: +These diagnostic pages are: * Main Scheduler pages at ``http://scheduler-address:8787``. These pages, particularly the ``/status`` page are the main page that most people associate with Dask. These pages are served from a separate standalone Bokeh server application running in a separate process. -* Debug Scheduler pages at ``http://scheduler-address:8788``. These pages - have more detailed diagnostic information about the scheduler. They are - more often used by developers than by users, but may still be of interest - to the performance-conscious. These pages run from inside the scheduler - process, and so compete for resources with the main scheduler. -* Debug Worker pages for each worker at ``http://worker-address:8789``. - These pages have detailed diagnostic information about the worker. 
Like the - diagnostic scheduler pages they are of more utility to developers or to - people looking to understand the performance of their underlying cluster. If - port 8789 is unavailable (for example it is in use by another worker) then a - random port is chosen. A list of all ports can be obtained from looking at - the service ports for each worker in the result of calling - ``client.scheduler_info()`` - -The rest of this document will be about the main pages at -``http://scheduler-address:8787``. The available pages are ``http://scheduler-address:8787//`` where ```` is one of From 7cb76f57701db615ecea18976388c9110009e3a1 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 22 Oct 2019 13:19:13 +0100 Subject: [PATCH 0510/1550] Update CONTRIBUTING.md (#3159) Ref dask/community#17 --- CONTRIBUTING.md | 30 ++---------------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cd35ad7c572..ab4175a59fe 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,29 +1,3 @@ -For more information, see https://docs.dask.org/en/latest/develop.html#contributing-to-code +Dask is a community maintained project. We welcome contributions in the form of bug reports, documentation, code, design proposals, and more. - -## Style -Distributed conforms with the [flake8] and [black] styles. To make sure your -code conforms with these styles, run - -``` shell -$ pip install black flake8 -$ cd path/to/distributed -$ black distributed -$ flake8 distributed -``` - -[flake8]:http://flake8.pycqa.org/en/latest/ -[black]:https://github.com/python/black - -## Docstrings - -Dask Distributed roughly follows the [numpydoc] standard. More information is -available at https://docs.dask.org/en/latest/develop.html#docstrings. - -[numpydoc]:https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt - -## Tests - -Dask employs extensive unit tests to ensure correctness of code both for today -and for the future. Test coverage is expected for all code contributions. More -detail is at https://docs.dask.org/en/latest/develop.html#test +For general information on how to contribute see https://docs.dask.org/en/latest/develop.html. 
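
The dashboard refactor above wires the worker's ``metrics`` and ``health`` routes into its HTTP handlers, and the Prometheus patches that follow extend what the ``metrics`` route reports. As a rough sketch (not part of the patches themselves), the two endpoints could be exercised against a plain ``LocalCluster``; this assumes ``requests`` and ``prometheus_client`` are installed and that the worker advertises its HTTP port under the ``"dashboard"`` service name in ``client.scheduler_info()``.

```python
import requests
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=1, threads_per_worker=1)
client = Client(cluster)

# Locate the first worker's host and dashboard port from the scheduler's
# service registry ("dashboard" as the service name is an assumption here).
(worker_addr, info), = client.scheduler_info()["workers"].items()
host = worker_addr.split("://")[-1].rsplit(":", 1)[0]
port = info["services"]["dashboard"]

# Liveness probe served by HealthHandler
print(requests.get("http://%s:%s/health" % (host, port)).text)  # "ok"

# Prometheus text-format metrics served by PrometheusHandler
metrics = requests.get("http://%s:%s/metrics" % (host, port)).text
print([line for line in metrics.splitlines()
       if line.startswith("dask_worker_tasks")])

client.close()
cluster.close()
```
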
From 70bed6af661610abf3693cce66ee70b614aa50d2 Mon Sep 17 00:00:00 2001 From: darindf Date: Tue, 22 Oct 2019 16:04:22 -0700 Subject: [PATCH 0511/1550] Add Prometheus metric for a worker's executing tasks count (#3163) --- distributed/dashboard/worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index 4d635388512..99b27557694 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -53,6 +53,7 @@ def collect(self): "dask_worker_tasks", "Number of tasks at worker.", labels=["state"] ) tasks.add_metric(["stored"], len(self.worker.data)) + tasks.add_metric(["executing"], len(self.worker.executing)) tasks.add_metric(["ready"], len(self.worker.ready)) tasks.add_metric(["waiting"], len(self.worker.waiting_for_data)) tasks.add_metric(["serving"], len(self.worker._comms)) From 21e1dcdba13b1c2a4cb78ef3022d6acbe8bdc243 Mon Sep 17 00:00:00 2001 From: darindf Date: Wed, 23 Oct 2019 12:08:06 -0700 Subject: [PATCH 0512/1550] Updated Prometheus documentation (#3165) * Updated Prometheus documentation * Changed formatting * fixup --- docs/source/prometheus.rst | 58 ++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst index 5000858a045..8d6759ad1fa 100644 --- a/docs/source/prometheus.rst +++ b/docs/source/prometheus.rst @@ -8,29 +8,37 @@ scheduler and worker metrics in a prometheus text based format. Metrics are avai Available metrics are as following -+---------------------------------------------+----------------------------------------------+ -| Metric name | Description | -+=========================+===================+==============================================+ -| dask_scheduler_workers | Number of workers connected. | -+---------------------------------------------+----------------------------------------------+ -| dask_scheduler_clients | Number of clients connected. | -+---------------------------------------------+----------------------------------------------+ -| dask_scheduler_received_tasks | Number of tasks received at scheduler | -+---------------------------------------------+----------------------------------------------+ -| dask_scheduler_unrunnable_tasks | Number of unrunnable tasks at scheduler | -+---------------------------------------------+----------------------------------------------+ -| dask_worker_tasks | Number of tasks at worker. | -+---------------------------------------------+----------------------------------------------+ -| dask_worker_connections | Number of task connections to other workers. | -+---------------------------------------------+----------------------------------------------+ -| dask_worker_threads | Number of worker threads. | -+---------------------------------------------+----------------------------------------------+ -| dask_worker_latency_seconds | Latency of worker connection. | -+---------------------------------------------+----------------------------------------------+ -| dask_worker_tick_duration_median_seconds | Median tick duration at worker. | -+---------------------------------------------+----------------------------------------------+ -| dask_worker_task_duration_median_seconds | Median task runtime at worker. | -+---------------------------------------------+----------------------------------------------+ -| dask_worker_transfer_bandwidth_median_bytes | Bandwidth for transfer at worker in Bytes. 
| -+---------------------------------------------+----------------------------------------------+ ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| Metric name | Description | Scheduler | Worker | ++=========================+===================+================================================+===========+========+ +| python_gc_objects_collected_total | Objects collected during gc. | Yes | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| python_gc_objects_uncollectable_total | Uncollectable object found during GC. | Yes | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| python_gc_collections_total | Number of times this generation was collected. | Yes | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| python_info | Python platform information. | Yes | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_scheduler_workers | Number of workers connected. | Yes | | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_scheduler_clients | Number of clients connected. | Yes | | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_scheduler_received_tasks | Number of tasks received at scheduler. | Yes | | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_scheduler_unrunnable_tasks | Number of unrunnable tasks at scheduler. | Yes | | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_worker_tasks | Number of tasks at worker. | | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_worker_connections | Number of task connections to other workers. | | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_worker_threads | Number of worker threads. | | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_worker_latency_seconds | Latency of worker connection. | | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_worker_tick_duration_median_seconds | Median tick duration at worker. | | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_worker_task_duration_median_seconds | Median task runtime at worker. | | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ +| dask_worker_transfer_bandwidth_median_bytes | Bandwidth for transfer at worker in Bytes. 
| | Yes | ++---------------------------------------------+------------------------------------------------+-----------+--------+ From 876bca0384ba4e0eef8bbf600cf9b89ac99faa99 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 24 Oct 2019 13:56:10 +0200 Subject: [PATCH 0513/1550] Fix Numba serialization when strides is None (#3166) --- distributed/protocol/numba.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py index ddf43adc182..9b33660e2bd 100644 --- a/distributed/protocol/numba.py +++ b/distributed/protocol/numba.py @@ -18,9 +18,18 @@ def serialize_numba_ndarray(x): @cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray) def deserialize_numba_ndarray(header, frames): (frame,) = frames + shape = header["shape"] + strides = header["strides"] + + # Starting with __cuda_array_interface__ version 2, strides can be None, + # meaning the array is C-contiguous, so we have to calculate it. + if strides is None: + itemsize = np.dtype(header["typestr"]).itemsize + strides = tuple((np.cumprod((1,) + shape[:0:-1]) * itemsize).tolist()) + arr = numba.cuda.devicearray.DeviceNDArray( - header["shape"], - header["strides"], + shape, + strides, np.dtype(header["typestr"]), gpu_data=numba.cuda.as_cuda_array(frame).gpu_data, ) From 97fbaae0bbcfc91fccaaa43db4952bff312bcc12 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Fri, 25 Oct 2019 10:35:41 -0400 Subject: [PATCH 0514/1550] Await cluster in Adaptive.recommendations (#3168) Fixes #3154 This fixes adaptive scaling when we've lost all workers in some cases. --- distributed/deploy/adaptive.py | 8 +++++ distributed/deploy/adaptive_core.py | 1 + distributed/deploy/tests/test_spec_cluster.py | 30 +++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index b8c3429a505..f173e36a396 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -113,6 +113,14 @@ async def target(self): target_duration=self.target_duration ) + async def recommendations(self, target: int) -> dict: + if len(self.plan) != len(self.requested): + # Ensure that the number of planned and requested workers + # are in sync before making recommendations. 
+ await self.cluster + + return await super(Adaptive, self).recommendations(target) + async def workers_to_close(self, target: int): """ Determine which, if any, workers should potentially be removed from diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index db50f109ce3..44a708aca38 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -178,6 +178,7 @@ async def adapt(self) -> None: self._adapting = True try: + target = await self.safe_target() recommendations = await self.recommendations(target) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 1c8a01e98ce..db78b66269e 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -1,5 +1,6 @@ import asyncio import re +from time import sleep import dask from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny @@ -131,6 +132,35 @@ async def test_scale(cleanup): assert len(cluster.workers) == 2 +@pytest.mark.slow +@pytest.mark.asyncio +async def test_adaptive_killed_worker(cleanup): + with dask.config.set({"distributed.deploy.lost-worker-timeout": 0.1}): + + async with SpecCluster( + asynchronous=True, + worker={"cls": Nanny, "options": {"nthreads": 1}}, + scheduler={"cls": Scheduler, "options": {"port": 0}}, + ) as cluster: + + async with Client(cluster, asynchronous=True) as client: + + cluster.adapt(minimum=1, maximum=1) + + # Scale up a cluster with 1 worker. + while len(cluster.workers) != 1: + await asyncio.sleep(0.01) + + future = client.submit(sleep, 0.1) + + # Kill the only worker. + [worker_id] = cluster.workers + await cluster.workers[worker_id].kill() + + # Wait for the worker to re-spawn and finish sleeping. + await future.result(timeout=5) + + @pytest.mark.asyncio async def test_unexpected_closed_worker(cleanup): worker = {"cls": Worker, "options": {"nthreads": 1}} From 888675a2451f0908f2eb18a07ae48a42a9f4fe0d Mon Sep 17 00:00:00 2001 From: Jim Crist Date: Fri, 25 Oct 2019 12:41:41 -0500 Subject: [PATCH 0515/1550] Support automatic TLS (#3164) This adds support for automatically securing cluster communication with TLS. This can be useful for situations where you don't have existing credentials. This required the following changes: - ``Security`` objects now support either paths or contents for all `key`/`cert`/`ca` fields. Due to limitations in Python's ``ssl`` module, if contents are provided they must be written to a temporary directory before being loaded back in. We make sure to use secure methods for doing this. We also change the ``__repr__`` to not show the raw cert values in the case they're stored in memory. - ``Security`` objects now have a classmethod ``temporary``, which can be used to create temporary credentials using self-signed certs. This requires ``cryptography`` to be installed. Most environments will already have ``cryptography``, so this isn't a huge new dependency. ```python >>> sec = Security.temporary() ``` - Both ``Client`` and ``LocalCluster`` now support passing in ``security=True``, which will generate temporary credentials automatically for use with that cluster. This api could be supported by other cluster managers, but for now we restrict to ``LocalCluster`` only. 
```python >>> client = Client(security=True) >>> client ``` --- distributed/client.py | 17 +++- distributed/deploy/local.py | 15 +++- distributed/deploy/tests/test_local.py | 31 +++++-- distributed/security.py | 112 +++++++++++++++++++++++-- distributed/tests/test_security.py | 39 +++++++++ 5 files changed, 192 insertions(+), 22 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ca7ca431c90..9fe32c58df0 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -523,8 +523,10 @@ class Client(Node): Claim this scheduler as the global dask scheduler scheduler_file: string (optional) Path to a file with scheduler information if available - security: (optional) - Optional security information + security: Security or bool, optional + Optional security information. If creating a local cluster can also + pass in ``True``, in which case temporary self-signed credentials will + be created automatically. asynchronous: bool (False by default) Set to True if using this client within async/await functions or within Tornado gen.coroutines. Otherwise this should remain False for normal @@ -659,8 +661,15 @@ def __init__( if security is None: security = getattr(self.cluster, "security", None) - self.security = security or Security() - assert isinstance(self.security, Security) + if security is None: + security = Security() + elif security is True: + security = Security.temporary() + self._startup_kwargs["security"] = security + elif not isinstance(security, Security): + raise TypeError("security must be a Security object") + + self.security = security if name == "worker": self.connection_args = self.security.get_connection_args("worker") diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 29c344f6719..fd1430baa21 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -56,7 +56,10 @@ class LocalCluster(SpecCluster): like ``['feed', 'run_function']`` service_kwargs: Dict[str, Dict] Extra keywords to hand to the running services - security : Security + security : Security or bool, optional + Configures communication security in this cluster. Can be a security + object, or True. If True, temporary self-signed credentials will + be created automatically. 
protocol: str (optional) Protocol to use like ``tcp://``, ``tls://``, ``inproc://`` This defaults to sensible choice given other keyword arguments like @@ -122,7 +125,15 @@ def __init__( self.status = None self.processes = processes - security = security or Security() + + if security is None: + # Falsey values load the default configuration + security = Security() + elif security is True: + # True indicates self-signed temporary credentials should be used + security = Security.temporary() + elif not isinstance(security, Security): + raise TypeError("security must be a Security object") if protocol is None: if host and "://" in host: diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 7a340a9c6f8..452a5795ad7 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -246,6 +246,15 @@ def test_Client_twice(loop): assert c.cluster.scheduler.port != f.cluster.scheduler.port +@pytest.mark.asyncio +async def test_client_constructor_with_temporary_security(cleanup): + async with Client( + security=True, silence_logs=False, dashboard_address=None, asynchronous=True + ) as c: + assert c.cluster.scheduler_address.startswith("tls") + assert c.security == c.cluster.security + + @pytest.mark.asyncio async def test_defaults(cleanup): async with LocalCluster( @@ -695,10 +704,12 @@ def test_adapt_then_manual(loop): assert time() < start + 5 -def test_local_tls(loop): - from distributed.utils_test import tls_only_security - - security = tls_only_security() +@pytest.mark.parametrize("temporary", [True, False]) +def test_local_tls(loop, temporary): + if temporary: + security = True + else: + security = tls_only_security() with LocalCluster( n_workers=0, scheduler_port=8786, @@ -712,7 +723,7 @@ def test_local_tls(loop): loop, assert_can_connect_from_everywhere_4, c.scheduler.port, - connection_args=security.get_connection_args("client"), + connection_args=c.security.get_connection_args("client"), protocol="tls", timeout=3, ) @@ -722,7 +733,7 @@ def test_local_tls(loop): loop, assert_cannot_connect, addr="tcp://127.0.0.1:%d" % c.scheduler.port, - connection_args=security.get_connection_args("client"), + connection_args=c.security.get_connection_args("client"), exception_class=RuntimeError, ) @@ -977,8 +988,12 @@ async def test_repr(cleanup): @pytest.mark.asyncio -async def test_capture_security(cleanup): - security = tls_only_security() +@pytest.mark.parametrize("temporary", [True, False]) +async def test_capture_security(cleanup, temporary): + if temporary: + security = True + else: + security = tls_only_security() async with LocalCluster( n_workers=0, silence_logs=False, diff --git a/distributed/security.py b/distributed/security.py index a42cbeef646..6b7d87b2715 100644 --- a/distributed/security.py +++ b/distributed/security.py @@ -1,3 +1,7 @@ +import datetime +import tempfile +import os + try: import ssl except ImportError: @@ -76,6 +80,67 @@ def __init__(self, **kwargs): self._set_field(kwargs, "tls_worker_key", "distributed.comm.tls.worker.key") self._set_field(kwargs, "tls_worker_cert", "distributed.comm.tls.worker.cert") + @classmethod + def temporary(cls): + """Create a new temporary Security object. + + This creates a new self-signed key/cert pair suitable for securing + communication for all roles in a Dask cluster. These keys/certs exist + only in memory, and are stored in this object. + + This method requires the library ``cryptography`` be installed. 
+ """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + raise ImportError( + "Using `Security.temporary` requires `cryptography`, please " + "install it using either pip or conda" + ) + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + key_contents = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + dask_internal = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "dask-internal")] + ) + altnames = x509.SubjectAlternativeName([x509.DNSName("dask-internal")]) + now = datetime.datetime.utcnow() + cert = ( + x509.CertificateBuilder() + .subject_name(dask_internal) + .issuer_name(dask_internal) + .add_extension(altnames, critical=False) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)) + .sign(key, hashes.SHA256(), default_backend()) + ) + + cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() + + return cls( + require_encryption=True, + tls_ca_file=cert_contents, + tls_client_key=key_contents, + tls_client_cert=cert_contents, + tls_scheduler_key=key_contents, + tls_scheduler_cert=cert_contents, + tls_worker_key=key_contents, + tls_worker_cert=cert_contents, + ) + def _set_field(self, kwargs, field, config_name): if field in kwargs: out = kwargs[field] @@ -84,12 +149,16 @@ def _set_field(self, kwargs, field, config_name): setattr(self, field, out) def __repr__(self): - items = sorted((k, getattr(self, k)) for k in self.__slots__) - return ( - "Security(" - + ", ".join("%s=%r" % (k, v) for k, v in items if v is not None) - + ")" - ) + keys = sorted(self.__slots__) + items = [] + for k in keys: + val = getattr(self, k) + if val is not None: + if isinstance(val, str) and "\n" in val: + items.append((k, "...")) + else: + items.append((k, repr(val))) + return "Security(" + ", ".join("%s=%s" % (k, v) for k, v in items) + ")" def get_tls_config_for_role(self, role): """ @@ -106,14 +175,41 @@ def get_tls_config_for_role(self, role): def _get_tls_context(self, tls, purpose): if tls.get("ca_file") and tls.get("cert"): - ctx = ssl.create_default_context(purpose=purpose, cafile=tls["ca_file"]) + ca = tls["ca_file"] + cert_path = cert = tls["cert"] + key_path = key = tls.get("key") + + if "\n" in ca: + ctx = ssl.create_default_context(purpose=purpose, cadata=ca) + else: + ctx = ssl.create_default_context(purpose=purpose, cafile=ca) + + cert_in_memory = "\n" in cert + key_in_memory = key is not None and "\n" in key + if cert_in_memory or key_in_memory: + with tempfile.TemporaryDirectory() as tempdir: + if cert_in_memory: + cert_path = os.path.join(tempdir, "dask.crt") + with open(cert_path, "w") as f: + f.write(cert) + if key_in_memory: + key_path = os.path.join(tempdir, "dask.pem") + with open(key_path, "w") as f: + f.write(key) + ctx.load_cert_chain(cert_path, key_path) + else: + ctx.load_cert_chain(cert_path, key_path) + + # Bidirectional authentication ctx.verify_mode = ssl.CERT_REQUIRED + # We expect a dedicated CA for the cluster and people using # IP addresses rather than hostnames ctx.check_hostname = False - 
ctx.load_cert_chain(tls["cert"], tls.get("key")) + if tls.get("ciphers"): ctx.set_ciphers(tls.get("ciphers")) + return ctx def get_connection_args(self, role): diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index bfc8358acf1..7496c037ae7 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -379,3 +379,42 @@ def check_encryption_error(): handle_comm, connection_args=sec2.get_listen_args("scheduler"), ) + + +def test_temporary_credentials(): + sec = Security.temporary() + sec_repr = repr(sec) + fields = ["tls_ca_file"] + fields.extend( + "tls_%s_%s" % (role, kind) + for role in ["client", "scheduler", "worker"] + for kind in ["key", "cert"] + ) + for f in fields: + val = getattr(sec, f) + assert "\n" in val + assert val not in sec_repr + + +@gen_test() +def test_tls_temporary_credentials_functional(): + pytest.importorskip("cryptography") + + @gen.coroutine + def handle_comm(comm): + peer_addr = comm.peer_address + assert peer_addr.startswith("tls://") + yield comm.write("hello") + yield comm.close() + + sec = Security.temporary() + + with listen( + "tls://", handle_comm, connection_args=sec.get_listen_args("scheduler") + ) as listener: + comm = yield connect( + listener.contact_address, connection_args=sec.get_connection_args("worker") + ) + msg = yield comm.read() + assert msg == "hello" + comm.abort() From c45023af13591e8502de1c68d4b07834eec8273d Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 28 Oct 2019 14:21:05 -0500 Subject: [PATCH 0516/1550] Avoid swamping high-memory workers with data requests (#3071) Reduces memory pressure on high-memory workers when many other workers are requesting data from the high-memory worker. We throttle the number of responses we'll fulfill simultaneously. --- distributed/nanny.py | 1 + distributed/tests/test_nanny.py | 56 ++++++++++++++++++++++++++++++++- distributed/worker.py | 16 +++++++++- 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 83ca2ebbf80..b21974d0257 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -357,6 +357,7 @@ def memory_monitor(self): except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): return frac = memory / self.memory_limit + if self.memory_terminate_fraction and frac > self.memory_terminate_fraction: logger.warning( "Worker exceeded %d%% memory budget. Restarting", diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index d54cf4e3b14..952d9cb8c52 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -15,7 +15,7 @@ import dask from distributed.diagnostics import SchedulerPlugin -from distributed import Nanny, rpc, Scheduler, Worker, Client +from distributed import Nanny, rpc, Scheduler, Worker, Client, wait from distributed.core import CommClosedError from distributed.metrics import time from distributed.protocol.pickle import dumps @@ -290,6 +290,60 @@ def leak(): assert "memory" in out.lower() +@gen_cluster( + nthreads=[("127.0.0.1", 1)] * 8, + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": 2e8}, + timeout=20, + clean_kwargs={"threads": False}, +) +async def test_nanny_throttle(c, s, *workers): + # Verify that get_data requests are throttled when the worker + # with the data is at high-memory by + # 1. Allocation some data on a worker + # 2. Pausing that worker + # 3. 
Requesting data from that worker from many other workers + a = workers[0] + proc = a.process.pid + size = 1000 + + def data(size): + return b"0" * size + + def patch(dask_worker): + # Patch paused and memory_monitor on the one worker + # This is is very fragile, since a refactor of memory_monitor to + # remove _memory_monitoring will break this test. + dask_worker._memory_monitoring = True + dask_worker.paused = True + + def check(dask_worker): + return dask_worker.paused + + futures = [ + c.submit(data, size, workers=[a.worker_address], pure=False) for i in range(4) + ] + await wait(futures) + await c.run(patch, workers=[a.worker_address]) + paused = await c.run(check, workers=[a.worker_address]) + assert paused[a.worker_address] + + await c.run(lambda: logging.getLogger("distributed.worker").setLevel(logging.DEBUG)) + # Cluster is in the correct state, now for the test. + n = len(workers) + result = c.map( + lambda x, i: x[i], + [futures[0]] * n, + range(n), + workers=[w.worker_address for w in workers[1:]], + ) + await result[0] + wlogs = await c.get_worker_logs(workers=[a.worker_address]) + wlogs = "\n".join(x[1] for x in wlogs[a.worker_address]) + assert "throttling" in wlogs.lower() + + @gen_cluster(nthreads=[], client=True) def test_avoid_memory_monitor_if_zero_limit(c, s): nanny = yield Nanny(s.address, loop=s.loop, memory_limit=0) diff --git a/distributed/worker.py b/distributed/worker.py index 12dfe3fe178..822e9677cf1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1156,10 +1156,24 @@ async def get_data( ): max_connections = max_connections * 2 + if self.paused: + max_connections = 1 + throttle_msg = " Throttling outgoing connections because worker is paused." + else: + throttle_msg = "" + if ( max_connections is not False - and self.outgoing_current_count > max_connections + and self.outgoing_current_count >= max_connections ): + logger.debug( + "Worker %s has too many open connections to respond to data request from %s (%d/%d).%s", + self.address, + who, + self.outgoing_current_count, + max_connections, + throttle_msg, + ) return {"status": "busy"} self.outgoing_current_count += 1 From ec1ffaa6086171ff21acd3ed8d879f293d2dd9b0 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 29 Oct 2019 16:20:40 +0100 Subject: [PATCH 0517/1550] Update UCX variables to use sockcm by default (#3177) --- distributed/comm/ucx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 77f65c661e8..cb3b93fbced 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -21,7 +21,8 @@ os.environ.setdefault("UCX_RNDV_SCHEME", "put_zcopy") os.environ.setdefault("UCX_MEMTYPE_CACHE", "n") -os.environ.setdefault("UCX_TLS", "tcp,rc,cuda_copy,cuda_ipc") +os.environ.setdefault("UCX_TLS", "tcp,sockcm,rc,cuda_copy,cuda_ipc") +os.environ.setdefault("UCX_SOCKADDR_TLS_PRIORITY", "sockcm") logger = logging.getLogger(__name__) MAX_MSG_LOG = 23 From 40d58b2a51f61a89a65db5af3f262d80bc80948f Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 30 Oct 2019 00:12:17 +0100 Subject: [PATCH 0518/1550] Get protocol in Nanny/Worker from scheduler address (#3175) --- distributed/nanny.py | 5 +++++ distributed/tests/test_worker.py | 14 ++++++++++++++ distributed/worker.py | 5 +++++ 3 files changed, 24 insertions(+) diff --git a/distributed/nanny.py b/distributed/nanny.py index b21974d0257..6e58271c33a 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -106,6 +106,11 @@ def __init__( else: 
self.scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port)) + if protocol is None: + protocol_address = self.scheduler_addr.split("://") + if len(protocol_address) == 2: + protocol = protocol_address[0] + if ncores is not None: warnings.warn("the ncores= parameter has moved to nthreads=") nthreads = ncores diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 53aac46216a..4b9c1ace01f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1497,6 +1497,20 @@ async def test_interface_async(loop, Worker): assert all("127.0.0.1" == d["host"] for d in info["workers"].values()) +@pytest.mark.asyncio +@pytest.mark.parametrize("Worker", [Worker, Nanny]) +async def test_protocol_from_scheduler_address(Worker): + ucp = pytest.importorskip("ucp") + + async with Scheduler(protocol="ucx") as s: + assert s.address.startswith("ucx://") + async with Worker(s.address) as w: + assert w.address.startswith("ucx://") + async with Client(s.address, asynchronous=True) as c: + info = c.scheduler_info() + assert info["address"].startswith("ucx://") + + @pytest.mark.asyncio @pytest.mark.parametrize("Worker", [Worker, Nanny]) async def test_worker_listens_on_same_interface_by_default(Worker): diff --git a/distributed/worker.py b/distributed/worker.py index 822e9677cf1..fec0444ba4e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -440,6 +440,11 @@ def __init__( scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port)) self.contact_address = contact_address + if protocol is None: + protocol_address = scheduler_addr.split("://") + if len(protocol_address) == 2: + protocol = protocol_address[0] + # Target interface on which we contact the scheduler by default # TODO: it is unfortunate that we special-case inproc here if not host and not interface and not scheduler_addr.startswith("inproc://"): From 9429ffe735e3767b427fd0b18aaa29d6a7da8513 Mon Sep 17 00:00:00 2001 From: darindf Date: Tue, 29 Oct 2019 16:13:33 -0700 Subject: [PATCH 0519/1550] Add worker and tasks state for Prometheus data collection (#3174) --- distributed/dashboard/scheduler.py | 30 +++++++++++++++++------------- docs/source/prometheus.rst | 4 +--- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 1117fe7bd72..8928d468c5e 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -239,26 +239,30 @@ def __init__(self, server): def collect(self): from prometheus_client.core import GaugeMetricFamily - yield GaugeMetricFamily( - "dask_scheduler_workers", - "Number of workers connected.", - value=len(self.server.workers), - ) yield GaugeMetricFamily( "dask_scheduler_clients", "Number of clients connected.", value=len(self.server.clients), ) - yield GaugeMetricFamily( - "dask_scheduler_received_tasks", - "Number of tasks received at scheduler", - value=len(self.server.tasks), + + tasks = GaugeMetricFamily( + "dask_scheduler_workers", + "Number of workers known by scheduler.", + labels=["state"], ) - yield GaugeMetricFamily( - "dask_scheduler_unrunnable_tasks", - "Number of unrunnable tasks at scheduler", - value=len(self.server.unrunnable), + tasks.add_metric(["connected"], len(self.server.workers)) + tasks.add_metric(["saturated"], len(self.server.saturated)) + tasks.add_metric(["idle"], len(self.server.idle)) + yield tasks + + tasks = GaugeMetricFamily( + "dask_scheduler_tasks", + "Number of tasks known by scheduler.", + 
labels=["state"], ) + tasks.add_metric(["received"], len(self.server.tasks)) + tasks.add_metric(["unrunnable"], len(self.server.unrunnable)) + yield tasks class PrometheusHandler(RequestHandler): diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst index 8d6759ad1fa..097335ee0d7 100644 --- a/docs/source/prometheus.rst +++ b/docs/source/prometheus.rst @@ -23,9 +23,7 @@ Available metrics are as following +---------------------------------------------+------------------------------------------------+-----------+--------+ | dask_scheduler_clients | Number of clients connected. | Yes | | +---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_scheduler_received_tasks | Number of tasks received at scheduler. | Yes | | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_scheduler_unrunnable_tasks | Number of unrunnable tasks at scheduler. | Yes | | +| dask_scheduler_tasks | Number of tasks at scheduler. | Yes | | +---------------------------------------------+------------------------------------------------+-----------+--------+ | dask_worker_tasks | Number of tasks at worker. | | Yes | +---------------------------------------------+------------------------------------------------+-----------+--------+ From 8b7c47d7e93ee74d543f29ada1c7863bbad78310 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 1 Nov 2019 05:30:28 +0100 Subject: [PATCH 0520/1550] Use async def functions for offload to/from_frames (#3171) --- distributed/client.py | 2 +- distributed/comm/tcp.py | 7 +++--- distributed/comm/ucx.py | 3 +-- distributed/comm/utils.py | 16 +++++-------- distributed/core.py | 18 ++++++++------ distributed/deploy/cluster.py | 2 +- distributed/deploy/tests/test_local.py | 16 ++++++++++++- distributed/scheduler.py | 3 +-- distributed/tests/test_core.py | 5 ++-- distributed/tests/test_nanny.py | 2 +- distributed/tests/test_scheduler.py | 16 +++++++++++++ distributed/tests/test_worker.py | 33 +++++++++++--------------- distributed/utils.py | 6 ++--- distributed/utils_comm.py | 4 ++-- distributed/worker.py | 2 +- 15 files changed, 79 insertions(+), 56 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 9fe32c58df0..11aaacdf044 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1298,7 +1298,7 @@ async def _close(self, fast=False): with ignoring(TimeoutError): await gen.with_timeout(timedelta(seconds=2), list(coroutines)) with ignoring(AttributeError): - self.scheduler.close_rpc() + await self.scheduler.close_rpc() self.scheduler = None self.status = "closed" diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index d0322e151d7..f0a24fe4fb7 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -213,14 +213,13 @@ async def read(self, deserializers=None): raise CommClosedError("aborted stream on truncated data") return msg - @gen.coroutine - def write(self, msg, serializers=None, on_error="message"): + async def write(self, msg, serializers=None, on_error="message"): stream = self.stream bytes_since_last_yield = 0 if stream is None: raise CommClosedError - frames = yield to_frames( + frames = await to_frames( msg, serializers=serializers, on_error=on_error, @@ -247,7 +246,7 @@ def write(self, msg, serializers=None, on_error="message"): future = stream.write(frame) bytes_since_last_yield += nbytes(frame) if bytes_since_last_yield > 32e6: - yield future + await future 
bytes_since_last_yield = 0 except StreamClosedError as e: stream = None diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index cb3b93fbced..fede1c91371 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -221,7 +221,6 @@ async def connect(self, address: str, deserialize=True, **connection_args) -> UC class UCXListener(Listener): - # MAX_LISTENERS 256 in ucx-py prefix = UCXConnector.prefix comm_class = UCXConnector.comm_class encrypted = UCXConnector.encrypted @@ -251,7 +250,7 @@ async def serve_forever(client_ep): ucx = UCX( client_ep, local_addr=self.address, - peer_addr=self.address, # TODO: https://github.com/Akshay-Venkatesh/ucx-py/issues/111 + peer_addr=self.address, deserialize=self.deserialize, ) if self.comm_handler: diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 70cd2b4cd27..80e1f163785 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -1,8 +1,6 @@ import logging import socket -from tornado import gen - from .. import protocol from ..utils import get_ip, get_ipv6, nbytes, offload @@ -16,8 +14,7 @@ FRAME_OFFLOAD_THRESHOLD = 10 * 1024 ** 2 # 10 MB -@gen.coroutine -def to_frames(msg, serializers=None, on_error="message", context=None): +async def to_frames(msg, serializers=None, on_error="message", context=None): """ Serialize a message into a list of Distributed protocol frames. """ @@ -34,13 +31,12 @@ def _to_frames(): logger.exception(e) raise - res = yield offload(_to_frames) + res = await offload(_to_frames) - raise gen.Return(res) + return res -@gen.coroutine -def from_frames(frames, deserialize=True, deserializers=None): +async def from_frames(frames, deserialize=True, deserializers=None): """ Unserialize a list of Distributed protocol frames. """ @@ -61,11 +57,11 @@ def _from_frames(): raise if deserialize and size > FRAME_OFFLOAD_THRESHOLD: - res = yield offload(_from_frames) + res = await offload(_from_frames) else: res = _from_frames() - raise gen.Return(res) + return res def get_tcp_server_address(tcp_server): diff --git a/distributed/core.py b/distributed/core.py index 32b509dc170..716a7b035e2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -641,25 +641,29 @@ async def live_comm(self): return comm def close_comms(self): - @gen.coroutine - def _close_comm(comm): + async def _close_comm(comm): # Make sure we tell the peer to close try: if not comm.closed(): - yield comm.write({"op": "close", "reply": False}) - yield comm.close() + await comm.write({"op": "close", "reply": False}) + await comm.close() except EnvironmentError: comm.abort() + tasks = [] for comm in list(self.comms): if comm and not comm.closed(): # IOLoop.current().add_callback(_close_comm, comm) task = asyncio.ensure_future(_close_comm(comm)) + tasks.append(task) for comm in list(self._created): if comm and not comm.closed(): # IOLoop.current().add_callback(_close_comm, comm) task = asyncio.ensure_future(_close_comm(comm)) + tasks.append(task) + self.comms.clear() + return tasks def __getattr__(self, key): async def send_recv_from_rpc(**kwargs): @@ -685,13 +689,13 @@ def close_rpc(self): if self.status != "closed": rpc.active.discard(self) self.status = "closed" - self.close_comms() + return asyncio.gather(*self.close_comms()) def __enter__(self): return self def __exit__(self, *args): - self.close_rpc() + asyncio.ensure_future(self.close_rpc()) def __del__(self): if self.status != "closed": @@ -744,7 +748,7 @@ async def send_recv_from_rpc(**kwargs): return send_recv_from_rpc - def close_rpc(self): + async def 
close_rpc(self): pass # For compatibility with rpc() diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 033f6877684..2631fb502df 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -74,7 +74,7 @@ async def _close(self): for pc in self.periodic_callbacks.values(): pc.stop() - self.scheduler_comm.close_rpc() + await self.scheduler_comm.close_rpc() self.status = "closed" diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 452a5795ad7..e73ccd4721f 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -1,3 +1,4 @@ +import asyncio from functools import partial import gc import subprocess @@ -455,7 +456,7 @@ def test_silent_startup(): if __name__ == "__main__": with LocalCluster(1, dashboard_address=None, scheduler_port=0): - sleep(1.5) + sleep(.1) """ out = subprocess.check_output( @@ -1004,3 +1005,16 @@ async def test_capture_security(cleanup, temporary): ) as cluster: async with Client(cluster, asynchronous=True) as client: assert client.security == cluster.security + + +@pytest.mark.asyncio +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="asyncio.all_tasks not implemented" +) +async def test_no_danglng_asyncio_tasks(cleanup): + start = asyncio.all_tasks() + async with LocalCluster(asynchronous=True, processes=False): + await asyncio.sleep(0.01) + + tasks = asyncio.all_tasks() + assert tasks == start diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f3d3fc92ea8..4319584735e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2642,8 +2642,7 @@ async def restart(self, client=None, timeout=3): "timeout. Continuuing with restart process" ) finally: - for nanny in nannies: - nanny.close_rpc() + await asyncio.gather(*[nanny.close_rpc() for nanny in nannies]) await self.start() diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index cbde7ac240b..99c07226c48 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -299,13 +299,14 @@ def test_rpc_inproc(): yield check_rpc("inproc://", None) -def test_rpc_inputs(): +@pytest.mark.asyncio +async def test_rpc_inputs(): L = [rpc("127.0.0.1:8884"), rpc(("127.0.0.1", 8884)), rpc("tcp://127.0.0.1:8884")] assert all(r.address == "tcp://127.0.0.1:8884" for r in L), L for r in L: - r.close_rpc() + await r.close_rpc() async def check_rpc_message_lifetime(*listen_args): diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 952d9cb8c52..70497bf7909 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -115,7 +115,7 @@ def test_nanny_process_failure(c, s): assert not os.path.exists(second_dir) assert not os.path.exists(first_dir) assert first_dir != n.worker_dir - ww.close_rpc() + yield ww.close_rpc() s.stop() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4b56a4d084f..f57dbfb9e07 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1,3 +1,4 @@ +import asyncio import cloudpickle import pickle from collections import defaultdict @@ -1688,3 +1689,18 @@ def test_get_task_duration(): assert s.get_task_duration(ts_pref2_2) == 0.5 # default assert len(s.unknown_durations) == 1 assert len(s.unknown_durations["prefix_2"]) == 2 + + +@pytest.mark.asyncio +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="asyncio.all_tasks not implemented" +) +async def 
test_no_danglng_asyncio_tasks(cleanup): + start = asyncio.all_tasks() + async with Scheduler(port=0) as s: + async with Worker(s.address, name="0") as a: + async with Client(s.address, asynchronous=True) as c: + await asyncio.sleep(0.01) + + tasks = asyncio.all_tasks() + assert tasks == start diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 4b9c1ace01f..8b81e3afbe6 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -173,7 +173,7 @@ def dont_test_delete_data_with_missing_worker(c, a, b): assert not c.has_what[bad] assert not c.has_what[a.address] - cc.close_rpc() + yield cc.close_rpc() @gen_cluster(client=True) @@ -998,32 +998,27 @@ def test_worker_fds(s): @gen_cluster(nthreads=[]) -def test_service_hosts_match_worker(s): +async def test_service_hosts_match_worker(s): pytest.importorskip("bokeh") from distributed.dashboard import BokehWorker - services = {("dashboard", ":0"): BokehWorker} - - w = yield Worker( + async with Worker( s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://0.0.0.0" - ) - sock = first(w.services["dashboard"].server._http._sockets.values()) - assert sock.getsockname()[0] in ("::", "0.0.0.0") - yield w.close() + ) as w: + sock = first(w.services["dashboard"].server._http._sockets.values()) + assert sock.getsockname()[0] in ("::", "0.0.0.0") - w = yield Worker( + async with Worker( s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://127.0.0.1" - ) - sock = first(w.services["dashboard"].server._http._sockets.values()) - assert sock.getsockname()[0] in ("::", "0.0.0.0") - yield w.close() + ) as w: + sock = first(w.services["dashboard"].server._http._sockets.values()) + assert sock.getsockname()[0] in ("::", "0.0.0.0") - w = yield Worker( + async with Worker( s.address, services={("dashboard", 0): BokehWorker}, host="tcp://127.0.0.1" - ) - sock = first(w.services["dashboard"].server._http._sockets.values()) - assert sock.getsockname()[0] == "127.0.0.1" - yield w.close() + ) as w: + sock = first(w.services["dashboard"].server._http._sockets.values()) + assert sock.getsockname()[0] == "127.0.0.1" @gen_cluster(nthreads=[]) diff --git a/distributed/utils.py b/distributed/utils.py index fbac950df43..c7ab77bcbd8 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1365,6 +1365,6 @@ def is_valid_xml(text): weakref.finalize(_offload_executor, _offload_executor.shutdown) -@gen.coroutine -def offload(fn, *args, **kwargs): - return (yield _offload_executor.submit(fn, *args, **kwargs)) +async def offload(fn, *args, **kwargs): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(_offload_executor, fn, *args, **kwargs) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 53504d11939..e2072189be0 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -84,7 +84,7 @@ async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=No response.update(r["data"]) finally: for r in rpcs.values(): - r.close_rpc() + await r.close_rpc() bad_addresses |= {v for k, v in rev.items() if k not in response} results.update(response) @@ -148,7 +148,7 @@ async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=N ) finally: for r in rpcs.values(): - r.close_rpc() + await r.close_rpc() nbytes = merge(o["nbytes"] for o in out) diff --git a/distributed/worker.py b/distributed/worker.py index fec0444ba4e..d1c35f68eb5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1068,7 +1068,7 @@ 
async def close( address=self.contact_address, safe=safe ), ) - self.scheduler.close_rpc() + await self.scheduler.close_rpc() self._workdir.release() for k, v in self.services.items(): From 5025b124bc7627837844e50d6f9c4b6df7ee36af Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 5 Nov 2019 16:48:19 +0100 Subject: [PATCH 0521/1550] Subprocesses inherit the global dask config (#3192) --- distributed/cli/dask_scheduler.py | 7 +++++++ distributed/cli/dask_worker.py | 6 ++++++ distributed/deploy/ssh.py | 17 ++++++++++++--- distributed/deploy/tests/test_ssh.py | 19 +++++++++++++++++ distributed/process.py | 16 ++++++++++++-- distributed/tests/test_client.py | 12 +++++++++++ distributed/utils.py | 31 ++++++++++++++++++++++++++++ 7 files changed, 103 insertions(+), 5 deletions(-) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 29de26d7b4d..0951b8c3d27 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -9,6 +9,7 @@ import warnings import click +import dask from tornado.ioloop import IOLoop @@ -16,6 +17,7 @@ from distributed.preloading import validate_preload_argv from distributed.security import Security from distributed.cli.utils import check_python_3, install_signal_handlers +from distributed.utils import deserialize_for_cli from distributed.proctitle import ( enable_proctitle_on_children, enable_proctitle_on_current, @@ -174,6 +176,11 @@ def main( } ) + if "DASK_INTERNAL_INHERIT_CONFIG" in os.environ: + config = deserialize_for_cli(os.environ["DASK_INTERNAL_INHERIT_CONFIG"]) + # Update the global config given priority to the existing global config + dask.config.update(dask.config.global_config, config, priority="old") + if not host and (tls_ca_file or tls_cert or tls_key): host = "tls://" diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index fb32fc2e882..0f307398e04 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -20,6 +20,7 @@ enable_proctitle_on_children, enable_proctitle_on_current, ) +from distributed.utils import deserialize_for_cli from toolz import valmap from tornado.ioloop import IOLoop, TimeoutError @@ -359,6 +360,11 @@ def del_pid_file(): with ignoring(TypeError, ValueError): name = int(name) + if "DASK_INTERNAL_INHERIT_CONFIG" in os.environ: + config = deserialize_for_cli(os.environ["DASK_INTERNAL_INHERIT_CONFIG"]) + # Update the global config given priority to the existing global config + dask.config.update(dask.config.global_config, config, priority="old") + nannies = [ t( scheduler, diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 8aa3cc17d97..673cb7ba717 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -4,10 +4,13 @@ import warnings import weakref +import dask + from .spec import SpecCluster, ProcessInterface from ..utils import cli_keywords from ..scheduler import Scheduler as _Scheduler from ..worker import Worker as _Worker +from ..utils import serialize_for_cli logger = logging.getLogger(__name__) @@ -86,6 +89,8 @@ async def start(self): self.proc = await self.connection.create_process( " ".join( [ + 'DASK_INTERNAL_INHERIT_CONFIG="%s"' + % serialize_for_cli(dask.config.global_config), sys.executable, "-m", self.worker_module, @@ -112,7 +117,7 @@ async def start(self): class Scheduler(Process): - """ A Remote Dask Scheduler controled by SSH + """ A Remote Dask Scheduler controlled by SSH Parameters ---------- @@ -141,7 +146,13 @@ async def start(self): self.proc = await 
self.connection.create_process( " ".join( - [sys.executable, "-m", "distributed.cli.dask_scheduler"] + [ + 'DASK_INTERNAL_INHERIT_CONFIG="%s"' + % serialize_for_cli(dask.config.global_config), + sys.executable, + "-m", + "distributed.cli.dask_scheduler", + ] + cli_keywords(self.kwargs, cls=_Scheduler) ) ) @@ -191,7 +202,7 @@ def SSHCluster( The SSHCluster function deploys a Dask Scheduler and Workers for you on a set of machine addresses that you provide. The first address will be used for the scheduler while the rest will be used for the workers (feel free to - repeat the first hostname if you want to have the scheudler and worker + repeat the first hostname if you want to have the scheduler and worker co-habitate one machine.) You may configure the scheduler and workers by passing diff --git a/distributed/deploy/tests/test_ssh.py b/distributed/deploy/tests/test_ssh.py index 3124af4f177..376b0eae3a4 100644 --- a/distributed/deploy/tests/test_ssh.py +++ b/distributed/deploy/tests/test_ssh.py @@ -2,6 +2,7 @@ pytest.importorskip("asyncssh") +import dask from dask.distributed import Client from distributed.deploy.ssh import SSHCluster from distributed.utils_test import loop # noqa: F401 @@ -55,3 +56,21 @@ def test_defer_to_old(loop): from distributed.deploy.old_ssh import SSHCluster as OldSSHCluster assert isinstance(c, OldSSHCluster) + + +@pytest.mark.asyncio +async def test_config_inherited_by_subprocess(loop): + def f(x): + return dask.config.get("foo") + 1 + + with dask.config.set(foo=100): + async with SSHCluster( + ["127.0.0.1"] * 2, + connect_options=dict(known_hosts=None), + asynchronous=True, + scheduler_options={"port": 0, "idle_timeout": "5s"}, + worker_options={"death_timeout": "5s"}, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + result = await client.submit(f, 1) + assert result == 101 diff --git a/distributed/process.py b/distributed/process.py index 889787fe0bf..38527ecd9ab 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -6,6 +6,7 @@ import re import threading import weakref +import dask from .utils import mp_context @@ -71,7 +72,14 @@ def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}): self._process = mp_context.Process( target=self._run, name=name, - args=(target, args, kwargs, parent_alive_pipe, self._keep_child_alive), + args=( + target, + args, + kwargs, + parent_alive_pipe, + self._keep_child_alive, + dask.config.global_config, + ), ) _dangling.add(self._process) self._name = self._process.name @@ -163,7 +171,9 @@ def reset_logger_locks(): handler.createLock() @classmethod - def _run(cls, target, args, kwargs, parent_alive_pipe, _keep_child_alive): + def _run( + cls, target, args, kwargs, parent_alive_pipe, _keep_child_alive, inherit_config + ): # On Python 2 with the fork method, we inherit the _keep_child_alive fd, # whether it is passed or not. 
Therefore, pass it unconditionally and # close it here, so that there are no other references to the pipe lying @@ -176,6 +186,8 @@ def _run(cls, target, args, kwargs, parent_alive_pipe, _keep_child_alive): cls._immediate_exit_when_closed(parent_alive_pipe) threading.current_thread().name = "MainThread" + # Update the global config given priority to the existing global config + dask.config.update(dask.config.global_config, inherit_config, priority="old") target(*args, **kwargs) @classmethod diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c4ebe92b4f1..29aad876ab3 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5654,5 +5654,17 @@ async def test_shutdown_localcluster(cleanup): assert lc.scheduler.status == "closed" +@pytest.mark.asyncio +async def test_config_inherited_by_subprocess(cleanup): + def f(x): + return dask.config.get("foo") + 1 + + with dask.config.set(foo=100): + async with LocalCluster(n_workers=1, asynchronous=True, processes=True) as lc: + async with Client(lc, asynchronous=True) as c: + result = await c.submit(f, 1) + assert result == 101 + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 diff --git a/distributed/utils.py b/distributed/utils.py index c7ab77bcbd8..251e1110be8 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -24,6 +24,7 @@ import warnings import weakref import pkgutil +import base64 import tblib.pickling_support import xml.etree.ElementTree @@ -1368,3 +1369,33 @@ def is_valid_xml(text): async def offload(fn, *args, **kwargs): loop = asyncio.get_event_loop() return await loop.run_in_executor(_offload_executor, fn, *args, **kwargs) + + +def serialize_for_cli(data): + """ Serialize data into a string that can be passthrough cli + + Parameters + ---------- + data: json-serializable object + The data to serialize + Returns + ------- + serialized_data: str + The serialized data as a string + """ + return base64.urlsafe_b64encode(json.dumps(data).encode()).decode() + + +def deserialize_for_cli(data): + """ De-serialize data into the original object + + Parameters + ---------- + data: str + String serialied by serialize_for_cli() + Returns + ------- + deserialized_data: obj + The de-serialized data + """ + return json.loads(base64.urlsafe_b64decode(data.encode()).decode()) From 4687879b0aebaa7435b26e80eb0669031ff2861e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 5 Nov 2019 09:41:01 -0800 Subject: [PATCH 0522/1550] XFail test_open_close_many_workers (#3194) This test is causing intermittent failures, and unfortunately no one is available to resolve it. 
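The pattern applied below is the usual pytest one for flaky tests: mark them as an expected failure instead of skipping them, so they keep running in CI without intermittent crashes turning the build red. A minimal self-contained sketch of that pattern (the test body here is only illustrative, not the real test):

    import random

    import pytest

    @pytest.mark.xfail(reason="TODO: intermittent failures")
    def test_known_flaky():
        # The test still executes; a failure is reported as XFAIL rather
        # than failing the run, and an unexpected pass shows up as XPASS.
        assert random.random() < 0.9
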
--- distributed/tests/test_client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 29aad876ab3..00abbcd25ec 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3576,10 +3576,7 @@ def test_reconnect_timeout(c, s): @pytest.mark.slow @pytest.mark.skipif(WINDOWS, reason="num_fds not supported on windows") -@pytest.mark.skipif( - sys.version_info[0] == 2, reason="Semaphore.acquire doesn't support timeout option" -) -# @pytest.mark.xfail(reason="TODO: intermittent failures") +@pytest.mark.xfail(reason="TODO: intermittent failures") @pytest.mark.parametrize("worker,count,repeat", [(Worker, 100, 5), (Nanny, 10, 20)]) def test_open_close_many_workers(loop, worker, count, repeat): psutil = pytest.importorskip("psutil") From e4a0404f0ed3f853a86adf79c48a60620f3e9fb6 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 5 Nov 2019 11:49:47 -0600 Subject: [PATCH 0523/1550] Drop Python 3.5 (#3179) --- .travis.yml | 3 +- appveyor.yml | 6 +-- .../setup_conda_environment.cmd | 1 - distributed/core.py | 6 +++ distributed/tests/test_client.py | 5 +- distributed/tests/test_nanny.py | 51 +++++++++---------- distributed/tests/test_security.py | 8 +-- distributed/utils_test.py | 2 +- docs/source/install.rst | 10 ++-- setup.py | 3 +- 10 files changed, 44 insertions(+), 51 deletions(-) diff --git a/.travis.yml b/.travis.yml index 1726cffd4f1..5d3cbf0ec0b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,8 +6,7 @@ dist: trusty env: matrix: - - PYTHON=3.5.4 TESTS=true COVERAGE=true PACKAGES="python-blosc lz4" CRICK=true - - PYTHON=3.6 TESTS=true PACKAGES="scikit-learn lz4" TORNADO=5 + - PYTHON=3.6 TESTS=true COVERAGE=true PACKAGES="scikit-learn lz4" TORNADO=5 CRICK=true - PYTHON=3.7 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 matrix: diff --git a/appveyor.yml b/appveyor.yml index 496640b3f30..e32c48f105a 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -11,14 +11,14 @@ environment: matrix: # Since appveyor is quite slow, we only use a single configuration - - PYTHON: "3.5" + - PYTHON: "3.6" ARCH: "64" CONDA_ENV: testenv init: # Use AppVeyor's provided Miniconda: https://www.appveyor.com/docs/installed-software#python - - if "%ARCH%" == "64" set MINICONDA=C:\Miniconda35-x64 - - if "%ARCH%" == "32" set MINICONDA=C:\Miniconda35 + - if "%ARCH%" == "64" set MINICONDA=C:\Miniconda36-x64 + - if "%ARCH%" == "32" set MINICONDA=C:\Miniconda36 - set PATH=%MINICONDA%;%MINICONDA%/Scripts;%MINICONDA%/Library/bin;%PATH% install: diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd index 87e37751548..5efc7358dbe 100644 --- a/continuous_integration/setup_conda_environment.cmd +++ b/continuous_integration/setup_conda_environment.cmd @@ -24,7 +24,6 @@ call deactivate cloudpickle ^ dask ^ dill ^ - futures ^ lz4 ^ ipykernel ^ ipywidgets ^ diff --git a/distributed/core.py b/distributed/core.py index 716a7b035e2..66205dbfd85 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -697,6 +697,12 @@ def __enter__(self): def __exit__(self, *args): asyncio.ensure_future(self.close_rpc()) + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + await self.close_rpc() + def __del__(self): if self.status != "closed": rpc.active.discard(self) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 00abbcd25ec..3d582c4fe8f 100644 --- 
a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3322,10 +3322,7 @@ def test_get_foo_lost_keys(c, s, u, v, w): @pytest.mark.slow @gen_cluster( - client=True, - Worker=Nanny, - worker_kwargs={"death_timeout": "500ms"}, - clean_kwargs={"threads": False, "processes": False}, + client=True, Worker=Nanny, clean_kwargs={"threads": False, "processes": False} ) def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 0) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 70497bf7909..fccfd2efde6 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -30,33 +30,30 @@ @gen_cluster(nthreads=[]) -def test_nanny(s): - n = yield Nanny(s.address, nthreads=2, loop=s.loop) - - with rpc(n.address) as nn: - assert n.is_alive() - assert s.nthreads[n.worker_address] == 2 - assert s.workers[n.worker_address].nanny == n.address - - yield nn.kill() - assert not n.is_alive() - assert n.worker_address not in s.nthreads - assert n.worker_address not in s.workers - - yield nn.kill() - assert not n.is_alive() - assert n.worker_address not in s.nthreads - assert n.worker_address not in s.workers - - yield nn.instantiate() - assert n.is_alive() - assert s.nthreads[n.worker_address] == 2 - assert s.workers[n.worker_address].nanny == n.address - - yield nn.terminate() - assert not n.is_alive() - - yield n.close() +async def test_nanny(s): + async with Nanny(s.address, nthreads=2, loop=s.loop) as n: + async with rpc(n.address) as nn: + assert n.is_alive() + assert s.nthreads[n.worker_address] == 2 + assert s.workers[n.worker_address].nanny == n.address + + await nn.kill() + assert not n.is_alive() + assert n.worker_address not in s.nthreads + assert n.worker_address not in s.workers + + await nn.kill() + assert not n.is_alive() + assert n.worker_address not in s.nthreads + assert n.worker_address not in s.workers + + await nn.instantiate() + assert n.is_alive() + assert s.nthreads[n.worker_address] == 2 + assert s.workers[n.worker_address].nanny == n.address + + await nn.terminate() + assert not n.is_alive() @gen_cluster(nthreads=[]) diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 7496c037ae7..167abc762ae 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -196,9 +196,9 @@ def many_ciphers(ctx): basic_checks(ctx) if sys.version_info >= (3, 6): supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.2"] + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.3"] + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] if len(tls_13_ciphers): assert len(tls_13_ciphers) == 3 @@ -249,9 +249,9 @@ def many_ciphers(ctx): basic_checks(ctx) if sys.version_info >= (3, 6): supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.2"] + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if c["protocol"] == "TLSv1.3"] + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] if len(tls_13_ciphers): assert len(tls_13_ciphers) == 3 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 97cbe783318..81b196639ac 100644 --- 
a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -865,7 +865,7 @@ def test_foo(scheduler, worker1, worker2): nthreads = ncores worker_kwargs = merge( - {"memory_limit": system.MEMORY_LIMIT, "death_timeout": 5}, worker_kwargs + {"memory_limit": system.MEMORY_LIMIT, "death_timeout": 10}, worker_kwargs ) def _(func): diff --git a/docs/source/install.rst b/docs/source/install.rst index 2ca74bb3d5f..7cf4199eecd 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -35,13 +35,9 @@ Notes ----- **Note for Macports users:** There `is a known issue -`_. with python from macports that +`_. with Python from macports that makes executables be placed in a location that is not available by default. A simple solution is to extend the ``PATH`` environment variable to the location -where python from macports install the binaries:: +where Python from macports install the binaries. For example, for Python 3.6:: - $ export PATH=/opt/local/Library/Frameworks/Python.framework/Versions/3.5/bin:$PATH - - or - - $ export PATH=/opt/local/Library/Frameworks/Python.framework/Versions/2.7/bin:$PATH + $ export PATH=/opt/local/Library/Frameworks/Python.framework/Versions/3.6/bin:$PATH diff --git a/setup.py b/setup.py index 5d900199256..310d5322e98 100755 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ url="https://distributed.dask.org", maintainer="Matthew Rocklin", maintainer_email="mrocklin@gmail.com", - python_requires=">=3.5", + python_requires=">=3.6", license="BSD", package_data={ "": ["templates/index.html", "template.html"], @@ -46,7 +46,6 @@ "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", - "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Topic :: Scientific/Engineering", From bc04d7642c0439bf3757b34bd1eb802b9679ba12 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 5 Nov 2019 21:14:42 +0100 Subject: [PATCH 0524/1550] UCX: avoid double init after fork (#3178) --- distributed/comm/tests/test_comms.py | 2 +- distributed/comm/tests/test_ucx.py | 3 - distributed/comm/ucx.py | 122 +++++++++++++++------------ 3 files changed, 70 insertions(+), 57 deletions(-) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 620d4b89c94..301cb4f013f 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -527,7 +527,7 @@ def client_communicate(key, delay=0): @gen_test() def test_ucx_client_server(): pytest.importorskip("distributed.comm.ucx") - import ucp + ucp = pytest.importorskip("ucp") addr = ucp.get_address() yield check_client_server("ucx://" + addr) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 28348369899..e5f4a4ab79b 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -261,9 +261,6 @@ async def test_ping_pong_numba(): @pytest.mark.parametrize("processes", [True, False]) def test_ucx_localcluster(loop, processes): - if processes: - pytest.skip("Known bug, processes=True doesn't work currently") - with LocalCluster( protocol="ucx", dashboard_address=None, diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index fede1c91371..d6ab704ab0f 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -6,6 +6,11 @@ .. 
_UCX: https://github.com/openucx/ucx """ import logging +import concurrent +import os + +import dask +import numpy as np from .addressing import parse_host_port, unparse_host_port from .core import Comm, Connector, Listener, CommClosedError @@ -13,41 +18,48 @@ from .utils import ensure_concrete_host, to_frames, from_frames from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors -from tornado.ioloop import IOLoop -import ucp -import numpy as np - -import os os.environ.setdefault("UCX_RNDV_SCHEME", "put_zcopy") os.environ.setdefault("UCX_MEMTYPE_CACHE", "n") -os.environ.setdefault("UCX_TLS", "tcp,sockcm,rc,cuda_copy,cuda_ipc") +os.environ.setdefault("UCX_TLS", "all") os.environ.setdefault("UCX_SOCKADDR_TLS_PRIORITY", "sockcm") logger = logging.getLogger(__name__) -MAX_MSG_LOG = 23 -# ---------------------------------------------------------------------------- -# Comm Interface -# ---------------------------------------------------------------------------- +# In order to avoid double init when forking/spawning new processes (multiprocess), +# we make sure only to import and initialize UCX once at first use. +ucp = None +cuda_array = None + + +def init_once(): + global ucp, cuda_array + if ucp is not None: + return + + import ucp as _ucp -# Let's find the function, `cuda_array`, to use when allocating new CUDA arrays -try: - import rmm + ucp = _ucp + options = dask.config.get("ucx", default={}) + ucp.init(options=options) - cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) -except ImportError: + # Find the function, `cuda_array()`, to use when allocating new CUDA arrays try: - import numba.cuda + import rmm - cuda_array = lambda n: numba.cuda.device_array((n,), dtype=np.uint8) + cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) except ImportError: + try: + import numba.cuda - def cuda_array(n): - raise RuntimeError( - "In order to send/recv CUDA arrays, Numba or RMM is required" - ) + cuda_array = lambda n: numba.cuda.device_array((n,), dtype=np.uint8) + except ImportError: + + def cuda_array(n): + raise RuntimeError( + "In order to send/recv CUDA arrays, Numba or RMM is required" + ) class UCX(Comm): @@ -84,9 +96,7 @@ class UCX(Comm): 4. Read all the data frames. 
""" - def __init__( - self, ep: ucp.Endpoint, local_addr: str, peer_addr: str, deserialize=True - ): + def __init__(self, ep, local_addr: str, peer_addr: str, deserialize=True): Comm.__init__(self) self._ep = ep if local_addr: @@ -115,26 +125,33 @@ async def write( with log_errors(): if self.closed(): raise CommClosedError("Endpoint is closed -- unable to send message") + try: + if serializers is None: + serializers = ("cuda", "dask", "pickle", "error") + # msg can also be a list of dicts when sending batched messages + frames = await to_frames( + msg, serializers=serializers, on_error=on_error + ) - if serializers is None: - serializers = ("cuda", "dask", "pickle", "error") - # msg can also be a list of dicts when sending batched messages - frames = await to_frames(msg, serializers=serializers, on_error=on_error) - - # Send meta data - await self.ep.send(np.array([len(frames)], dtype=np.uint64)) - await self.ep.send( - np.array( - [hasattr(f, "__cuda_array_interface__") for f in frames], - dtype=np.bool, + # Send meta data + await self.ep.send(np.array([len(frames)], dtype=np.uint64)) + await self.ep.send( + np.array( + [hasattr(f, "__cuda_array_interface__") for f in frames], + dtype=np.bool, + ) ) - ) - await self.ep.send(np.array([nbytes(f) for f in frames], dtype=np.uint64)) - # Send frames - for frame in frames: - if nbytes(frame) > 0: - await self.ep.send(frame) - return sum(map(nbytes, frames)) + await self.ep.send( + np.array([nbytes(f) for f in frames], dtype=np.uint64) + ) + # Send frames + for frame in frames: + if nbytes(frame) > 0: + await self.ep.send(frame) + return sum(map(nbytes, frames)) + except (ucp.exceptions.UCXBaseException): + self.abort() + raise CommClosedError("While writing, the connection was closed") async def read(self, deserializers=("cuda", "dask", "pickle", "error")): with log_errors(): @@ -152,12 +169,12 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): await self.ep.recv(is_cudas) sizes = np.empty(nframes[0], dtype=np.uint64) await self.ep.recv(sizes) - except (ucp.exceptions.UCXCanceled, ucp.exceptions.UCXCloseError): - if self._ep is not None and not self._ep.closed(): - await self._ep.shutdown() - self._ep.close() - self._ep = None - raise CommClosedError("While reading, the connection was canceled") + except ( + ucp.exceptions.UCXBaseException, + concurrent.futures._base.CancelledError, + ): + self.abort() + raise CommClosedError("While reading, the connection was closed") else: # Recv frames frames = [] @@ -181,15 +198,12 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): async def close(self): if self._ep is not None: - if not self._ep.closed(): - await self._ep.signal_shutdown() - self._ep.close() + await self._ep.close() self._ep = None def abort(self): if self._ep is not None: - logger.debug("Destroyed UCX endpoint") - IOLoop.current().add_callback(self._ep.signal_shutdown) + self._ep.abort() self._ep = None @property @@ -211,6 +225,7 @@ class UCXConnector(Connector): async def connect(self, address: str, deserialize=True, **connection_args) -> UCX: logger.debug("UCXConnector.connect: %s", address) ip, port = parse_host_port(address) + init_once() ep = await ucp.create_endpoint(ip, port) return self.comm_class( ep, @@ -256,6 +271,7 @@ async def serve_forever(client_ep): if self.comm_handler: await self.comm_handler(ucx) + init_once() self.ucp_server = ucp.create_listener(serve_forever, port=self._input_port) def stop(self): From 2a3d0072cd906369a6bc5a933094f4d29712783c Mon Sep 17 00:00:00 2001 
From: "James A. Bednar" Date: Wed, 6 Nov 2019 17:56:50 -0600 Subject: [PATCH 0525/1550] Silenced warning when importing while offline (#3203) --- distributed/comm/inproc.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 5235b7535fd..e9bed986ea0 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -4,6 +4,7 @@ import os import threading import weakref +import warnings from tornado import locks from tornado.concurrent import Future @@ -31,7 +32,11 @@ class Manager(object): def __init__(self): self.listeners = weakref.WeakValueDictionary() self.addr_suffixes = itertools.count(1) - self.ip = get_ip() + with warnings.catch_warnings(): + # Avoid immediate warning for unreachable network + # (will still warn for other get_ip() calls when actually used) + warnings.simplefilter("ignore") + self.ip = get_ip() self.lock = threading.Lock() def add_listener(self, addr, listener): From 7b46b92e46358d13e8c9926d4d99f24bb2566fee Mon Sep 17 00:00:00 2001 From: IPetrik Date: Thu, 7 Nov 2019 11:25:19 -0800 Subject: [PATCH 0526/1550] Adds docs to Client methods for resources, actors, and traverse (#2851) --- distributed/client.py | 49 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 11aaacdf044..c7f921ee348 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1429,6 +1429,15 @@ def submit( Higher priorities take precedence fifo_timeout: str timedelta (default '100ms') Allowed amount of time between calls to consider the same priority + resources: dict (defaults to {}) + Defines the `resources` this job requires on the worker; e.g. + ``{'GPU': 2}``. See :doc:`worker resources ` for details + on defining resources. + actor: bool (default False) + Whether this task should exist on the worker as a stateful actor. + See :doc:`actors` for additional details. + actors: bool (default False) + Alias for `actor` Examples -------- @@ -1531,6 +1540,9 @@ def map( workers: set, iterable of sets A set of worker hostnames on which computations may be performed. Leave empty to default to all workers (common case) + allow_other_workers: bool (defaults to False) + Used with `workers`. Indicates whether or not the computations + may be performed on workers that are not in the `workers` set(s). retries: int (default to 0) Number of allowed automatic retries if a task fails priority: Number @@ -1538,6 +1550,16 @@ def map( Higher priorities take precedence fifo_timeout: str timedelta (default '100ms') Allowed amount of time between calls to consider the same priority + resources: dict (defaults to {}) + Defines the `resources` each instance of this mapped task requires + on the worker; e.g. ``{'GPU': 2}``. See + :doc:`worker resources ` for details on defining + resources. + actor: bool (default False) + Whether these tasks should exist on the worker as stateful actors. + See :doc:`actors` for additional details. + actors: bool (default False) + Alias for `actor` **kwargs: dict Extra keywords to send to the function. Large values will be included explicitly in the task graph. @@ -2656,6 +2678,21 @@ def compute( Higher priorities take precedence fifo_timeout: timedelta str (defaults to '60s') Allowed amount of time between calls to consider the same priority + traverse: bool (defaults to True) + By default dask traverses builtin python collections looking for + dask objects passed to ``compute``. 
For large collections this can + be expensive. If none of the arguments contain any dask objects, + set ``traverse=False`` to avoid doing this traversal. + resources: dict (defaults to {}) + Defines the `resources` these tasks require on the worker. Can + specify global resources (``{'GPU': 2}``), or per-task resources + (``{'x': {'GPU': 1}, 'y': {'SSD': 4}}``), but not both. + See :doc:`worker resources ` for details on defining + resources. + actors: bool or dict (default None) + Whether these tasks should exist on the worker as stateful actors. + Specified on a global (True/False) or per-task (``{'x': True, + 'y': False}``) basis. See :doc:`actors` for additional details. **kwargs: Options to pass to the graph optimize calls @@ -2791,7 +2828,17 @@ def persist( Higher priorities take precedence fifo_timeout: timedelta str (defaults to '60s') Allowed amount of time between calls to consider the same priority - kwargs: + resources: dict (defaults to {}) + Defines the `resources` these tasks require on the worker. Can + specify global resources (``{'GPU': 2}``), or per-task resources + (``{'x': {'GPU': 1}, 'y': {'SSD': 4}}``), but not both. + See :doc:`worker resources ` for details on defining + resources. + actors: bool or dict (default None) + Whether these tasks should exist on the worker as stateful actors. + Specified on a global (True/False) or per-task (``{'x': True, + 'y': False}``) basis. See :doc:`actors` for additional details. + **kwargs: Options to pass to the graph optimize calls Returns From 0766d78327a4d09e0f5e015570264a451a6607fa Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 7 Nov 2019 11:25:49 -0800 Subject: [PATCH 0527/1550] Add failing test for concurrent scatter operations (#2244) --- distributed/tests/test_client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 3d582c4fe8f..4d747d8f3ec 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5436,6 +5436,15 @@ def test_tuple_keys(c, s, a, b): assert (yield future) == 3 +@gen_cluster(client=True) +def test_multiple_scatter(c, s, a, b): + for i in range(5): + x = c.scatter(1, direct=True) + + x = yield x + x = yield x + + @gen_cluster(client=True) def test_map_large_kwargs_in_graph(c, s, a, b): np = pytest.importorskip("numpy") From d575ea0bd0784da2368d486f07914ae669824a07 Mon Sep 17 00:00:00 2001 From: Dave Hirschfeld Date: Fri, 8 Nov 2019 05:57:56 +1000 Subject: [PATCH 0528/1550] Expand async docs (#2293) --- .gitignore | 1 + docs/source/asynchronous.rst | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 2d70b7ebd7f..86ee425adff 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ continuous_integration/hdfs-initialized *.lock .#* .idea/ +.vscode/ .pytest_cache/ dask-worker-space/ .vscode/ diff --git a/docs/source/asynchronous.rst b/docs/source/asynchronous.rst index 1ffc8d4f5c3..a49788e1fbe 100644 --- a/docs/source/asynchronous.rst +++ b/docs/source/asynchronous.rst @@ -44,18 +44,24 @@ received information from the scheduler should now be ``await``'ed. result = await client.gather(future) -If you want to reuse the same client in asynchronous and synchronous -environments you can apply the ``asynchronous=True`` keyword at each method -call. 
+ +If you want to use an asynchronous function with a synchronous ``Client`` +(one made without the ``asynchronous=True`` keyword) then you can apply the +``asynchronous=True`` keyword at each method call and use the ``Client.sync`` +function to run the asynchronous function: .. code-block:: python + from dask.distributed import Client + client = Client() # normal blocking client async def f(): - futures = client.map(func, L) - results = await client.gather(futures, asynchronous=True) - return results + future = client.submit(lambda x: x + 1, 10) + result = await client.gather(future, asynchronous=True) + return result + + client.sync(f) Python 2 Compatibility From 2f46ab18372e747dcf0179341ffcdaa5900aa522 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Fri, 8 Nov 2019 09:28:05 -0600 Subject: [PATCH 0529/1550] Adding PatchedDeviceArray to drop stride attribute for cupy<7.0 (#3198) --- distributed/protocol/cupy.py | 13 +++++++++++++ distributed/protocol/tests/test_cupy.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index d85f37d8a1e..d15c719359c 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -5,6 +5,17 @@ from .cuda import cuda_serialize, cuda_deserialize +class PatchedCudaArrayInterface(object): + # TODO: This class wont be necessary + # once Cupy<7.0 is no longer supported + def __init__(self, ary): + cai = ary.__cuda_array_interface__ + cai_cupy_vsn = cupy.ndarray(0).__cuda_array_interface__["version"] + if cai.get("strides") is None and cai_cupy_vsn < 2: + cai.pop("strides", None) + self.__cuda_array_interface__ = cai + + @cuda_serialize.register(cupy.ndarray) def serialize_cupy_ndarray(x): # Making sure `x` is behaving @@ -18,6 +29,8 @@ def serialize_cupy_ndarray(x): @cuda_deserialize.register(cupy.ndarray) def deserialize_cupy_array(header, frames): (frame,) = frames + if not isinstance(frame, cupy.ndarray): + frame = PatchedCudaArrayInterface(frame) arr = cupy.ndarray( header["shape"], dtype=header["typestr"], memptr=cupy.asarray(frame).data ) diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index 10335d14338..57d26ae679b 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -1,4 +1,5 @@ from distributed.protocol import serialize, deserialize +import pickle import pytest cupy = pytest.importorskip("cupy") @@ -12,3 +13,19 @@ def test_serialize_cupy(size, dtype): y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) assert (x == y).all() + + +@pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) +def test_serialize_cupy_from_numba(dtype): + numba = pytest.importorskip("numba") + np = pytest.importorskip("numpy") + + size = 10 + x_np = np.arange(size, dtype=dtype) + x = numba.cuda.to_device(x_np) + header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) + header["type-serialized"] = pickle.dumps(cupy.ndarray) + + y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + + assert (x_np == cupy.asnumpy(y)).all() From f9ae2243e203195dd4069bda577bd1b8edb5ebda Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Nov 2019 13:07:19 -0800 Subject: [PATCH 0530/1550] bump version to 2.7.0 --- docs/source/changelog.rst | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 41496953f66..3cd8e3d287d 100644 
--- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,35 @@ Changelog ========= +2.7.0 - 2019-11-08 +------------------ + +This release drops support for Python 3.5 + +- Adds badges to README.rst [skip ci] (:pr:`3152`) `James Bourbeau`_ +- Don't overwrite `self.address` if it is present (:pr:`3153`) `Gil Forsyth`_ +- Remove outdated references to debug scheduler and worker bokeh pages. (:pr:`3160`) `darindf`_ +- Update CONTRIBUTING.md (:pr:`3159`) `Jacob Tomlinson`_ +- Add Prometheus metric for a worker's executing tasks count (:pr:`3163`) `darindf`_ +- Update Prometheus documentation (:pr:`3165`) `darindf`_ +- Fix Numba serialization when strides is None (:pr:`3166`) `Peter Andreas Entschev`_ +- Await cluster in Adaptive.recommendations (:pr:`3168`) `Simon Boothroyd`_ +- Support automatic TLS (:pr:`3164`) `Jim Crist`_ +- Avoid swamping high-memory workers with data requests (:pr:`3071`) `Tom Augspurger`_ +- Update UCX variables to use sockcm by default (:pr:`3177`) `Peter Andreas Entschev`_ +- Get protocol in Nanny/Worker from scheduler address (:pr:`3175`) `Peter Andreas Entschev`_ +- Add worker and tasks state for Prometheus data collection (:pr:`3174`) `darindf`_ +- Use async def functions for offload to/from_frames (:pr:`3171`) `Mads R. B. Kristensen`_ +- Subprocesses inherit the global dask config (:pr:`3192`) `Mads R. B. Kristensen`_ +- XFail test_open_close_many_workers (:pr:`3194`) `Matthew Rocklin`_ +- Drop Python 3.5 (:pr:`3179`) `James Bourbeau`_ +- UCX: avoid double init after fork (:pr:`3178`) `Mads R. B. Kristensen`_ +- Silence warning when importing while offline (:pr:`3203`) `James A. Bednar`_ +- Adds docs to Client methods for resources, actors, and traverse (:pr:`2851`) `IPetrik`_ +- Add test for concurrent scatter operations (:pr:`2244`) `Matthew Rocklin`_ +- Expand async docs (:pr:`2293`) `Dave Hirschfeld`_ +- Add PatchedDeviceArray to drop stride attribute for cupy<7.0 (:pr:`3198`) `Richard J Zamora`_ + 2.6.0 - 2019-10-15 ------------------ @@ -1332,3 +1361,7 @@ significantly without many new features. .. _`Philipp Rudiger`: https://github.com/philippjfr .. _`Jonathan De Troye`: https://github.com/detroyejr .. _`matthieubulte`: https://github.com/matthieubulte +.. _`darindf`: https://github.com/darindf +.. _`James A. Bednar`: https://github.com/jbednar +.. _`IPetrik`: https://github.com/IPetrik +.. 
_`Simon Boothroyd`: https://github.com/SimonBoothroyd From 763a649e3e272b657d394409e74584505510f064 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Nov 2019 13:18:29 -0800 Subject: [PATCH 0531/1550] Add UCX config values (#3135) --- distributed/comm/tests/test_ucx.py | 52 ++++++++++++++++++++++-------- distributed/comm/ucx.py | 9 +++--- distributed/distributed.yaml | 2 ++ 3 files changed, 45 insertions(+), 18 deletions(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index e5f4a4ab79b..4c2d2d0782e 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -3,19 +3,22 @@ ucp = pytest.importorskip("ucp") -from distributed import Client +from distributed import Client, Worker, Scheduler, wait from distributed.comm import ucx, listen, connect from distributed.comm.registry import backends, get_backend from distributed.comm import ucx, parse_address from distributed.protocol import to_serialize from distributed.deploy.local import LocalCluster from dask.dataframe.utils import assert_eq -from distributed.utils_test import gen_test, loop, inc # noqa: 401 +from distributed.utils_test import gen_test, loop, inc, cleanup # noqa: 401 from .test_comms import check_deserialize -HOST = ucp.get_address() +try: + HOST = ucp.get_address() +except Exception: + HOST = "127.0.0.1" def test_registered(): @@ -225,7 +228,7 @@ async def test_ping_pong_cupy(shape): ), ], ) -async def test_large_cupy(n): +async def test_large_cupy(n, cleanup): cupy = pytest.importorskip("cupy") com, serv_com = await get_comm_pair() @@ -242,7 +245,7 @@ async def test_large_cupy(n): @pytest.mark.asyncio -async def test_ping_pong_numba(): +async def test_ping_pong_numba(cleanup): np = pytest.importorskip("numpy") numba = pytest.importorskip("numba") import numba.cuda @@ -260,18 +263,19 @@ async def test_ping_pong_numba(): @pytest.mark.parametrize("processes", [True, False]) -def test_ucx_localcluster(loop, processes): - with LocalCluster( +@pytest.mark.asyncio +async def test_ucx_localcluster(processes, cleanup): + async with LocalCluster( protocol="ucx", dashboard_address=None, n_workers=2, threads_per_worker=1, processes=processes, - loop=loop, + asynchronous=True, ) as cluster: - with Client(cluster) as client: + async with Client(cluster, asynchronous=True) as client: x = client.submit(inc, 1) - x.result() + await x.result() assert x.key in cluster.scheduler.tasks if not processes: assert any(w.data == {x.key: 2} for w in cluster.workers.values()) @@ -280,9 +284,8 @@ def test_ucx_localcluster(loop, processes): @pytest.mark.slow @pytest.mark.asyncio -async def test_stress(): - import dask.array as da - from distributed import wait +async def test_stress(cleanup): + da = pytest.importorskip("dask.array") chunksize = "10 MB" @@ -300,3 +303,26 @@ async def test_stress(): x = x.rechunk((-1, chunksize)) x = x.persist() await wait(x) + + +@pytest.mark.asyncio +async def test_simple(cleanup): + async with Scheduler(protocol="ucx") as s: + async with Worker(s.address) as a: + async with Client(s.address, asynchronous=True) as c: + result = await c.submit(lambda x: x + 1, 10) + assert result == 11 + + +@pytest.mark.asyncio +async def test_transpose(cleanup): + da = pytest.importorskip("dask.array") + + async with Scheduler(protocol="ucx") as s: + async with Worker(s.address) as a, Worker(s.address) as b: + async with Client(s.address, asynchronous=True) as c: + x = da.ones((10000, 10000), chunks=(1000, 1000)).persist() + await x + + y = (x + x.T).sum() + 
await y diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index d6ab704ab0f..bcf9b19f412 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -7,7 +7,6 @@ """ import logging import concurrent -import os import dask import numpy as np @@ -18,11 +17,11 @@ from .utils import ensure_concrete_host, to_frames, from_frames from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors +import dask +import ucp +import numpy as np -os.environ.setdefault("UCX_RNDV_SCHEME", "put_zcopy") -os.environ.setdefault("UCX_MEMTYPE_CACHE", "n") -os.environ.setdefault("UCX_TLS", "all") -os.environ.setdefault("UCX_SOCKADDR_TLS_PRIORITY", "sockcm") +ucp.init(options=dask.config.get("ucx"), env_takes_precedence=True) logger = logging.getLogger(__name__) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 48484be12a6..92b7c15e157 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -122,3 +122,5 @@ distributed: log-length: 10000 # default length of logs to keep in memory log-format: '%(name)s - %(levelname)s - %(message)s' pdb-on-err: False # enter debug mode on scheduling error + +ucx: {} From 690363b9515230e152d0507b8f785cd598249d68 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Nov 2019 18:03:23 -0800 Subject: [PATCH 0532/1550] Relax test_MultiWorker (#3210) The repr of a class only updates periodically, so we need to relax a check a bit. --- distributed/deploy/tests/test_spec_cluster.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index db78b66269e..19e162ca67b 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -430,7 +430,9 @@ async def test_MultiWorker(cleanup): assert len(cluster.worker_spec) == 2 await client.wait_for_workers(4) - assert "workers=4" in repr(cluster) + while "workers=4" not in repr(cluster): + await asyncio.sleep(0.1) + workers_line = re.search("(Workers.+)", cluster._widget_status()).group(1) assert re.match("Workers.*4", workers_line) From 559db6721e3f32d17606aaafce70eabdbf162198 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Nov 2019 18:03:32 -0800 Subject: [PATCH 0533/1550] Avoid ucp.init at import time (#3211) --- distributed/comm/ucx.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index bcf9b19f412..0a4580eacaa 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -21,7 +21,6 @@ import ucp import numpy as np -ucp.init(options=dask.config.get("ucx"), env_takes_precedence=True) logger = logging.getLogger(__name__) @@ -40,8 +39,7 @@ def init_once(): import ucp as _ucp ucp = _ucp - options = dask.config.get("ucx", default={}) - ucp.init(options=options) + ucp.init(options=dask.config.get("ucx"), env_takes_precedence=True) # Find the function, `cuda_array()`, to use when allocating new CUDA arrays try: From dfc703eb3baa3dbfdc89f1f96fdba7db75964917 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 9 Nov 2019 08:03:45 -0800 Subject: [PATCH 0534/1550] Clean up rpc to avoid intermittent test failure (#3215) --- distributed/tests/test_core.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 99c07226c48..78fbd1211d7 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py 
@@ -255,7 +255,7 @@ async def check_rpc(listen_addr, rpc_addr=None, listen_args=None, connection_arg if rpc_addr is None: rpc_addr = server.address - with rpc(rpc_addr, connection_args=connection_args) as remote: + async with rpc(rpc_addr, connection_args=connection_args) as remote: response = await remote.ping() assert response == b"pong" assert remote.comms @@ -324,7 +324,7 @@ async def check_rpc_message_lifetime(*listen_args): await asyncio.sleep(0.01) assert time() < start + 1 - with rpc(server.address) as remote: + async with rpc(server.address) as remote: obj = CountedObject() res = await remote.echo(x=to_serialize(obj)) assert isinstance(res["result"], CountedObject) @@ -366,7 +366,7 @@ async def g(): server = Server({"ping": pingpong}) server.listen(listen_arg) - with rpc(server.address) as remote: + async with rpc(server.address) as remote: for i in range(10): await g() @@ -392,15 +392,14 @@ async def check_large_packets(listen_arg): server.listen(listen_arg) data = b"0" * int(200e6) # slightly more than 100MB - conn = rpc(server.address) - result = await conn.echo(x=data) - assert result == data + async with rpc(server.address) as conn: + result = await conn.echo(x=data) + assert result == data - d = {"x": data} - result = await conn.echo(x=d) - assert result == d + d = {"x": data} + result = await conn.echo(x=d) + assert result == d - conn.close_comms() server.stop() @@ -419,7 +418,7 @@ async def check_identity(listen_arg): server = Server({}) server.listen(listen_arg) - with rpc(server.address) as remote: + async with rpc(server.address) as remote: a = await remote.identity() b = await remote.identity() assert a["type"] == "Server" @@ -707,11 +706,11 @@ async def f(): server = Server({"echo": echo_serialize}) server.listen("tcp://") - with rpc(server.address, serializers=["msgpack"]) as r: + async with rpc(server.address, serializers=["msgpack"]) as r: with pytest.raises(TypeError): await r.echo(x=to_serialize(inc)) - with rpc(server.address, serializers=["msgpack", "pickle"]) as r: + async with rpc(server.address, serializers=["msgpack", "pickle"]) as r: result = await r.echo(x=to_serialize(inc)) assert result == {"result": inc} From 2936803ef9bd5b434eb6f260b23edd62476715c9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 9 Nov 2019 13:08:51 -0800 Subject: [PATCH 0535/1550] Respect protocol if given to Scheduler (#3212) Previously calling the following dask-scheduler --protocol ucx Without specifying an interface would result in a tcp:// protocol Now we are a bit more forceful about respecting a protocol if given. 
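For illustration only (not part of the patch), the new behaviour boils down to
always rewriting the protocol prefix rather than only adding one when it is
missing; force_protocol below is a hypothetical helper name used just for this
sketch:

    def force_protocol(addr, protocol):
        # Drop any existing "scheme://" prefix and apply the requested one.
        return protocol.rstrip("://") + "://" + addr.split("://")[-1]

    assert force_protocol("tcp://10.0.0.1:8786", "ucx") == "ucx://10.0.0.1:8786"
    assert force_protocol("10.0.0.1:8786", "ucx") == "ucx://10.0.0.1:8786"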
--- distributed/comm/addressing.py | 4 ++-- distributed/comm/tests/test_ucx.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 35d5e1c3407..f0c18b9fbda 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -241,7 +241,7 @@ def address_from_user_args( else: addr = "" - if protocol and "://" not in addr: - addr = protocol.rstrip("://") + "://" + addr + if protocol: + addr = protocol.rstrip("://") + "://" + addr.split("://")[-1] return addr diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 4c2d2d0782e..a9207e72e7a 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -326,3 +326,10 @@ async def test_transpose(cleanup): y = (x + x.T).sum() await y + + +@pytest.mark.asyncio +@pytest.mark.parametrize("port", [0, 1234]) +async def test_ucx_protocol(cleanup, port): + async with Scheduler(protocol="ucx", port=port) as s: + assert s.address.startswith("ucx://") From 86bd07a0c58daf3a8249d8231251ed9c9fba600c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 11 Nov 2019 08:04:20 -0800 Subject: [PATCH 0536/1550] Use legend_field= keyword in bokeh plots (#3218) The old legend= keyword has been deprecated --- distributed/dashboard/components/scheduler.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 5e94a034cb0..77f095bce0d 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -113,7 +113,7 @@ def __init__(self, scheduler, **kwargs): tools="", id="bk-occupancy-plot", x_axis_type="datetime", - **kwargs + **kwargs, ) rect = fig.rect( source=self.source, x="x", width="ms", y="y", height=1, color="color" @@ -193,7 +193,7 @@ def __init__(self, scheduler, **kwargs): name="processing_hist", y_axis_label="frequency", tools="", - **kwargs + **kwargs, ) self.root.xaxis.minor_tick_line_alpha = 0 @@ -236,7 +236,7 @@ def __init__(self, scheduler, **kwargs): id="bk-nbytes-histogram-plot", y_axis_label="frequency", tools="", - **kwargs + **kwargs, ) self.root.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") @@ -290,7 +290,7 @@ def __init__(self, scheduler, **kwargs): id="bk-bandwidth-type-plot", name="bandwidth_type_histogram", y_range=["a", "b"], - **kwargs + **kwargs, ) rect = fig.rect( source=self.source, @@ -364,7 +364,7 @@ def __init__(self, scheduler, **kwargs): name="bandwidth_worker_heatmap", x_range=["a", "b"], y_range=["a", "b"], - **kwargs + **kwargs, ) fig.xaxis.major_label_orientation = -math.pi / 12 rect = fig.rect( @@ -458,7 +458,7 @@ def __init__(self, scheduler, width=600, **kwargs): id="bk-nprocessing-plot", name="processing_hist", width=int(width / 2), - **kwargs + **kwargs, ) rect = processing.rect( source=self.source, @@ -477,7 +477,7 @@ def __init__(self, scheduler, width=600, **kwargs): id="bk-nbytes-worker-plot", width=int(width / 2), name="nbytes_hist", - **kwargs + **kwargs, ) rect = nbytes.rect( source=self.source, @@ -495,7 +495,7 @@ def __init__(self, scheduler, width=600, **kwargs): id="bk-cpu-worker-plot", width=int(width / 2), name="cpu_hist", - **kwargs + **kwargs, ) rect = cpu.rect( source=self.source, @@ -637,7 +637,7 @@ def __init__(self, scheduler, **kwargs): height=150, tools="", x_range=x_range, - **kwargs + **kwargs, ) fig.line(source=self.source, x="time", 
y="idle", color="red") fig.line(source=self.source, x="time", y="saturated", color="green") @@ -691,7 +691,7 @@ def __init__(self, scheduler, **kwargs): height=250, tools="", x_range=x_range, - **kwargs + **kwargs, ) fig.circle( @@ -787,7 +787,7 @@ def __init__(self, scheduler, name, height=150, **kwargs): height=height, tools="", x_range=x_range, - **kwargs + **kwargs, ) fig.circle( @@ -797,7 +797,7 @@ def __init__(self, scheduler, name, height=150, **kwargs): color="color", size=50, alpha=0.5, - legend="action", + **{"legend_field" if BOKEH_VERSION >= "1.4" else "legend": "action"}, ) fig.yaxis.axis_label = "Action" fig.legend.location = "top_left" @@ -958,7 +958,7 @@ def task_stream_figure(clear_interval="20s", **kwargs): x_axis_type="datetime", min_border_right=35, tools="", - **kwargs + **kwargs, ) rect = root.rect( @@ -1058,7 +1058,7 @@ def __init__(self, scheduler, **kwargs): color=node_colors, source=self.node_source, view=node_view, - legend="state", + **{"legend_field" if BOKEH_VERSION >= "1.4" else "legend": "state"}, ) self.root.xgrid.grid_line_color = None self.root.ygrid.grid_line_color = None @@ -1202,7 +1202,7 @@ def __init__(self, scheduler, **kwargs): y_range=y_range, toolbar_location=None, tools="", - **kwargs + **kwargs, ) self.root.line( # just to define early ranges x=[0, 0.9], y=[-1, 0], line_color="#FFFFFF", alpha=0.0 @@ -1368,7 +1368,7 @@ def __init__(self, scheduler, **kwargs): y_range=DataRange1d(), toolbar_location=None, outline_line_color=None, - **kwargs + **kwargs, ) self.root.add_glyph( @@ -1492,7 +1492,7 @@ def __init__(self, scheduler, width=800, **kwargs): reorderable=True, sortable=True, width=width, - **dt_kwargs + **dt_kwargs, ) for name in table_names: @@ -1511,7 +1511,7 @@ def __init__(self, scheduler, width=800, **kwargs): reorderable=True, sortable=True, width=width, - **dt_kwargs + **dt_kwargs, ) hover = HoverTool( @@ -1532,7 +1532,7 @@ def __init__(self, scheduler, width=800, **kwargs): height=60, width=width, tools="", - **kwargs + **kwargs, ) mem_plot.circle( source=self.source, x="memory_percent", y=0, size=10, fill_alpha=0.5 @@ -1561,7 +1561,7 @@ def __init__(self, scheduler, width=800, **kwargs): height=60, width=width, tools="", - **kwargs + **kwargs, ) cpu_plot.circle( source=self.source, x="cpu_fraction", y=0, size=10, fill_alpha=0.5 From 2a25c5f3f16bf238787f8ca223442fdb79f89e0f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 11 Nov 2019 08:04:39 -0800 Subject: [PATCH 0537/1550] Cache psutil.Process object in Nanny (#3207) --- distributed/nanny.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 6e58271c33a..39bffbd80c0 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -349,6 +349,19 @@ async def _(): else: return "OK" + @property + def _psutil_process(self): + pid = self.process.process.pid + try: + proc = self._psutil_process_obj + except AttributeError: + self._psutil_process_obj = psutil.Process(pid) + + if self._psutil_process_obj.pid != pid: + self._psutil_process_obj = psutil.Process(pid) + + return self._psutil_process_obj + def memory_monitor(self): """ Track worker's memory. 
Restart if it goes above terminate fraction """ if self.status != "running": @@ -357,7 +370,7 @@ def memory_monitor(self): if process is None: return try: - proc = psutil.Process(process.pid) + proc = self._psutil_process memory = proc.memory_info().rss except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): return From 81df14f144d5358ed9d37a065c5f5220280a6f78 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 11 Nov 2019 09:23:12 -0800 Subject: [PATCH 0538/1550] Replace gen.sleep with asyncio.sleep (#3208) --- distributed/client.py | 16 ++++++++-------- distributed/comm/core.py | 3 ++- distributed/comm/tests/test_comms.py | 4 ++-- distributed/core.py | 6 +++--- distributed/diagnostics/progress.py | 6 +++--- distributed/nanny.py | 5 +++-- distributed/scheduler.py | 8 ++++---- distributed/utils_test.py | 10 +++++----- distributed/worker.py | 6 +++--- 9 files changed, 33 insertions(+), 31 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index c7f921ee348..6093bc9ec03 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -942,7 +942,7 @@ async def _start(self, timeout=no_default, **kwargs): address = self.cluster.scheduler_address elif self.scheduler_file is not None: while not os.path.exists(self.scheduler_file): - await gen.sleep(0.01) + await asyncio.sleep(0.01) for i in range(10): try: with open(self.scheduler_file) as f: @@ -950,7 +950,7 @@ async def _start(self, timeout=no_default, **kwargs): address = cfg["address"] break except (ValueError, KeyError): # JSON file not yet flushed - await gen.sleep(0.01) + await asyncio.sleep(0.01) elif self._start_arg is None: from .deploy import LocalCluster @@ -976,7 +976,7 @@ async def _start(self, timeout=no_default, **kwargs): while not self.cluster.workers or len(self.cluster.scheduler.workers) < len( self.cluster.workers ): - await gen.sleep(0.01) + await asyncio.sleep(0.01) address = self.cluster.scheduler_address @@ -1017,7 +1017,7 @@ async def _reconnect(self): break except EnvironmentError: # Wait a bit before retrying - await gen.sleep(0.1) + await asyncio.sleep(0.1) timeout = deadline - self.loop.time() else: logger.error( @@ -1092,7 +1092,7 @@ async def _update_scheduler_info(self): async def _wait_for_workers(self, n_workers=0): info = await self.scheduler.identity() while n_workers and len(info["workers"]) < n_workers: - await gen.sleep(0.1) + await asyncio.sleep(0.1) info = await self.scheduler.identity() def wait_for_workers(self, n_workers=0): @@ -1946,7 +1946,7 @@ async def _scatter( start = time() while not nthreads: if nthreads is not None: - await gen.sleep(0.1) + await asyncio.sleep(0.1) if time() > start + timeout: raise gen.TimeoutError("No valid workers found") nthreads = await self.scheduler.ncores(workers=workers) @@ -2280,7 +2280,7 @@ def run_on_scheduler(self, function, *args, **kwargs): >>> async def print_state(dask_scheduler): # doctest: +SKIP ... while True: ... print(dask_scheduler.status) - ... await gen.sleep(1) + ... await asyncio.sleep(1) >>> c.run(print_state, wait=False) # doctest: +SKIP @@ -2370,7 +2370,7 @@ def run(self, function, *args, **kwargs): >>> async def print_state(dask_worker): # doctest: +SKIP ... while True: ... print(dask_worker.status) - ... await gen.sleep(1) + ... 
await asyncio.sleep(1) >>> c.run(print_state, wait=False) # doctest: +SKIP """ diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 256a17de3a5..39a8b123cd3 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod, abstractproperty +import asyncio from datetime import timedelta import logging import weakref @@ -224,7 +225,7 @@ def _raise(error): except EnvironmentError as e: error = str(e) if time() < deadline: - await gen.sleep(0.01) + await asyncio.sleep(0.01) logger.debug("sleeping on connect") else: _raise(error) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 301cb4f013f..5839c3e8871 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -683,7 +683,7 @@ def handle_comm(comm): # Sanity check comm = yield connect( - listener.contact_address, timeout=0.5, connection_args={"ssl_context": cli_ctx} + listener.contact_address, timeout=2, connection_args={"ssl_context": cli_ctx} ) yield comm.close() @@ -696,7 +696,7 @@ def handle_comm(comm): with pytest.raises(EnvironmentError) as excinfo: yield connect( listener.contact_address, - timeout=0.5, + timeout=2, connection_args={"ssl_context": cli_ctx}, ) # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028 diff --git a/distributed/core.py b/distributed/core.py index 66205dbfd85..d096e0b2274 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -464,7 +464,7 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): handler(**merge(extra, msg)) else: logger.error("odd message %s", msg) - await gen.sleep(0) + await asyncio.sleep(0) for func in every_cycle: func() @@ -492,7 +492,7 @@ def close(self): if not self._comms: break else: - yield gen.sleep(0.05) + yield asyncio.sleep(0.05) yield [comm.close() for comm in self._comms] # then forcefully close for cb in self._ongoing_coroutines: cb.cancel() @@ -500,7 +500,7 @@ def close(self): if all(cb.cancelled() for c in self._ongoing_coroutines): break else: - yield gen.sleep(0.01) + yield asyncio.sleep(0.01) self._event_finished.set() diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 4136fd17a5c..48f26570980 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -1,9 +1,9 @@ +import asyncio from collections import defaultdict import logging from timeit import default_timer from toolz import groupby, valmap -from tornado import gen from .plugin import SchedulerPlugin from ..utils import key_split, key_split_group, log_errors, tokey @@ -76,7 +76,7 @@ async def setup(self): keys = self.keys while not keys.issubset(self.scheduler.tasks): - await gen.sleep(0.05) + await asyncio.sleep(0.05) tasks = [self.scheduler.tasks[k] for k in keys] @@ -164,7 +164,7 @@ async def setup(self): keys = self.keys while not keys.issubset(self.scheduler.tasks): - await gen.sleep(0.05) + await asyncio.sleep(0.05) tasks = [self.scheduler.tasks[k] for k in keys] diff --git a/distributed/nanny.py b/distributed/nanny.py index 39bffbd80c0..58491da154c 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -1,3 +1,4 @@ +import asyncio from datetime import timedelta import logging from multiprocessing.queues import Empty @@ -608,7 +609,7 @@ async def kill(self, timeout=2, executor_wait=True): self.child_stop_q.close() while process.is_alive() and loop.time() < deadline: - await gen.sleep(0.05) + await asyncio.sleep(0.05) 
if process.is_alive(): logger.warning( @@ -627,7 +628,7 @@ async def _wait_until_connected(self, uid): try: msg = self.init_result_q.get_nowait() except Empty: - await gen.sleep(delay) + await asyncio.sleep(delay) continue if msg["uid"] != uid: # ensure that we didn't cross queues diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4319584735e..de04804370a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1271,7 +1271,7 @@ async def close(self, comm=None, fast=False, close_workers=False): self.worker_send(worker, {"op": "close"}) for i in range(20): # wait a second for send signals to clear if self.workers: - await gen.sleep(0.05) + await asyncio.sleep(0.05) else: break @@ -2494,7 +2494,7 @@ async def scatter( """ start = time() while not self.workers: - await gen.sleep(0.2) + await asyncio.sleep(0.2) if time() > start + timeout: raise gen.TimeoutError("No workers found") @@ -2649,7 +2649,7 @@ async def restart(self, client=None, timeout=3): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() while time() < start + 10 and len(self.workers) < n_workers: - await gen.sleep(0.01) + await asyncio.sleep(0.01) self.report({"op": "restart"}) @@ -3292,7 +3292,7 @@ async def feed( else: response = function(self, state) await comm.write(response) - await gen.sleep(interval) + await asyncio.sleep(interval) except (EnvironmentError, CommClosedError): pass finally: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 81b196639ac..aef7bde8eee 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -394,7 +394,7 @@ def apply(func, *args, **kwargs): async def geninc(x, delay=0.02): - await gen.sleep(delay) + await asyncio.sleep(delay) return x + 1 @@ -410,7 +410,7 @@ def compile_snippet(code, dedent=True): compile_snippet( """ async def asyncinc(x, delay=0.02): - await gen.sleep(delay) + await asyncio.sleep(delay) return x + 1 """ ) @@ -813,7 +813,7 @@ async def start_cluster( while len(s.workers) < len(nthreads) or any( comm.comm is None for comm in s.stream_comms.values() ): - await gen.sleep(0.01) + await asyncio.sleep(0.01) if time() - start > 5: await asyncio.gather(*[w.close(timeout=1) for w in workers]) await s.close(fast=True) @@ -939,7 +939,7 @@ async def coro(): if all(c.closed() for c in Comm._instances): break else: - await gen.sleep(0.05) + await asyncio.sleep(0.05) else: L = [c for c in Comm._instances if not c.closed()] Comm._instances.clear() @@ -1063,7 +1063,7 @@ def wait_for(predicate, timeout, fail_func=None, period=0.001): async def async_wait_for(predicate, timeout, fail_func=None, period=0.001): deadline = time() + timeout while not predicate(): - await gen.sleep(period) + await asyncio.sleep(period) if time() > deadline: if fail_func is not None: fail_func() diff --git a/distributed/worker.py b/distributed/worker.py index d1c35f68eb5..c0836077a6c 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -829,7 +829,7 @@ async def _register_with_scheduler(self): break except EnvironmentError: logger.info("Waiting to connect to: %26s", self.scheduler.address) - await gen.sleep(0.1) + await asyncio.sleep(0.1) except gen.TimeoutError: logger.info("Timed out when connecting to scheduler") if response["status"] != "OK": @@ -1997,7 +1997,7 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): else: # Exponential backoff to avoid hammering scheduler/worker self.repetitively_busy += 1 - await gen.sleep(0.100 * 1.5 ** self.repetitively_busy) + await 
asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy) # See if anyone new has the data await self.query_who_has(dep) @@ -2586,7 +2586,7 @@ async def memory_monitor(self): del k, v total += weight count += 1 - await gen.sleep(0) + await asyncio.sleep(0) memory = proc.memory_info().rss if total > need and memory > target: # Issue a GC to ensure that the evicted data is actually From 8572887824cf438817e5e85ca7460f5f1d5870ac Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 11 Nov 2019 12:49:27 -0800 Subject: [PATCH 0539/1550] Avoid offloading serialization for small messages (#3224) This was causing a fair amount of either cost or noise when profiling. Fixes https://github.com/dask/distributed/issues/3223 --- distributed/comm/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 80e1f163785..5b15d5c798c 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -1,6 +1,8 @@ import logging import socket +from dask.sizeof import sizeof + from .. import protocol from ..utils import get_ip, get_ipv6, nbytes, offload @@ -31,9 +33,10 @@ def _to_frames(): logger.exception(e) raise - res = await offload(_to_frames) - - return res + if sizeof(msg) > FRAME_OFFLOAD_THRESHOLD: + return await offload(_to_frames) + else: + return _to_frames() async def from_frames(frames, deserialize=True, deserializers=None): From aff524ba2286e614ed9d5cc259bcab5cc921cc4a Mon Sep 17 00:00:00 2001 From: Gabriel Sailer Date: Mon, 11 Nov 2019 23:09:44 +0100 Subject: [PATCH 0540/1550] Add desired_workers metric (#3221) --- distributed/dashboard/scheduler.py | 8 ++++++++ distributed/scheduler.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 8928d468c5e..ecc413b5f0d 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -171,6 +171,7 @@ def get(self): released = 0 waiting = 0 waiting_data = 0 + desired_workers = scheduler.adaptive_target() for ts in scheduler.tasks.values(): if ts.exception_blame is not None: @@ -203,6 +204,7 @@ def get(self): "waiting": waiting, "waiting_data": waiting_data, "workers": len(scheduler.workers), + "desired_workers": desired_workers, } self.write(response) @@ -245,6 +247,12 @@ def collect(self): value=len(self.server.clients), ) + yield GaugeMetricFamily( + "dask_scheduler_desired_workers", + "Number of workers scheduler needs for task graph.", + value=self.server.adaptive_target(), + ) + tasks = GaugeMetricFamily( "dask_scheduler_workers", "Number of workers known by scheduler.", diff --git a/distributed/scheduler.py b/distributed/scheduler.py index de04804370a..6724524344d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3062,7 +3062,7 @@ def _key(group): result = [getattr(ws, attribute) for g in to_close for ws in groups[g]] if result: - logger.info("Suggest closing workers: %s", result) + logger.debug("Suggest closing workers: %s", result) return result From 43b2ed77218861c34dd2c98d88ec5153bd35f5c4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 12 Nov 2019 11:57:43 -0800 Subject: [PATCH 0541/1550] Fail fast when importing distributed.comm.ucx (#3228) Fixes https://github.com/dask/dask/issues/5572 If we don't have UCX installed then we shouldn't try anything further. In this commit we move the `import ucp` line to the top of the file. 
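A minimal sketch of the intended effect, assuming ucp is not installed in the
environment; with the import at module top the failure surfaces immediately:

    import importlib

    try:
        importlib.import_module("distributed.comm.ucx")
    except ImportError as exc:
        # The missing "ucp" dependency is reported at import time,
        # rather than later inside init_once().
        print("UCX comms unavailable:", exc)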
--- distributed/comm/ucx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 0a4580eacaa..58b16eaaf7f 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -5,6 +5,8 @@ .. _UCX: https://github.com/openucx/ucx """ +import ucp + import logging import concurrent @@ -18,7 +20,6 @@ from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors import dask -import ucp import numpy as np From 2060805bbe3074e050ef7ac81c485201caf0618e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 13 Nov 2019 07:29:57 -0800 Subject: [PATCH 0542/1550] Add module name to Future repr (#3231) --- distributed/client.py | 12 ++++-------- distributed/tests/test_client.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 6093bc9ec03..6f5bf786367 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -383,16 +383,12 @@ def __del__(self): def __repr__(self): if self.type: try: - typ = self.type.__name__ + typ = self.type.__module__.split(".")[0] + "." + self.type.__name__ except AttributeError: typ = str(self.type) - return "" % ( - self.status, - typ, - self.key, - ) + return "" % (self.status, typ, self.key) else: - return "" % (self.status, self.key) + return "" % (self.status, self.key) def _repr_html_(self): text = "Future: %s " % html.escape(key_split(self.key)) @@ -405,7 +401,7 @@ def _repr_html_(self): } if self.type: try: - typ = self.type.__name__ + typ = self.type.__module__.split(".")[0] + "." + self.type.__name__ except AttributeError: typ = str(self.type) text += 'type: %s, ' % typ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 4d747d8f3ec..9f8e48298fa 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -334,13 +334,22 @@ def test_retries_dask_array(c, s, a, b): @gen_cluster(client=True) -def test_future_repr(c, s, a, b): +async def test_future_repr(c, s, a, b): + pd = pytest.importorskip("pandas") x = c.submit(inc, 10) + y = c.submit(pd.DataFrame, {"x": [1, 2, 3]}) + await x + await y + for func in [repr, lambda x: x._repr_html_()]: assert str(x.key) in func(x) assert str(x.status) in func(x) assert str(x.status) in repr(c.futures[x.key]) + assert "int" in func(x) + assert "pandas" in func(y) + assert "DataFrame" in func(y) + @gen_cluster(client=True) def test_future_tuple_repr(c, s, a, b): From e0fe1caa9b44919f1c59383d04e593bfc3f9ecda Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 14 Nov 2019 07:59:02 -0800 Subject: [PATCH 0543/1550] Add name to Pub/Sub repr (#3235) --- distributed/pubsub.py | 10 ++++++++++ distributed/tests/test_pubsub.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 0a4053191eb..5ed631c46ec 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -343,6 +343,11 @@ def put(self, msg): """ Publish a message to all subscribers of this topic """ self.loop.add_callback(self._put, msg) + def __repr__(self): + return "".format(self.name) + + __str__ = __repr__ + class Sub(object): """ Subscribe to a Publish/Subscribe topic @@ -426,3 +431,8 @@ def __aiter__(self): def _put(self, msg): self.buffer.append(msg) self.condition.notify() + + def __repr__(self): + return "".format(self.name) + + __str__ = __repr__ diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 9d2b30dab6f..555afb71a73 100644 --- 
a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -125,5 +125,15 @@ def test_timeouts(c, s, a, b): assert stop - start < 1 +@gen_cluster(client=True) +async def test_repr(c, s, a, b): + pub = Pub("my-topic") + sub = Sub("my-topic") + assert "my-topic" in str(pub) + assert "Pub" in str(pub) + assert "my-topic" in str(sub) + assert "Sub" in str(sub) + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_pubsub import * # noqa: F401, F403 From ddbec38ba1ec6de913ccbfcd090f1c85eea1b032 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 14 Nov 2019 10:14:10 -0600 Subject: [PATCH 0544/1550] Import CPU_COUNT from dask.system (#3199) --- distributed/cli/dask_worker.py | 2 +- distributed/deploy/local.py | 10 +++--- distributed/deploy/tests/test_local.py | 3 +- distributed/nanny.py | 2 +- distributed/system.py | 44 +---------------------- distributed/tests/test_system.py | 48 +------------------------- distributed/tests/test_worker.py | 6 ++-- distributed/worker.py | 5 +-- 8 files changed, 17 insertions(+), 103 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 0f307398e04..9070024c430 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -10,9 +10,9 @@ import click import dask from dask.utils import ignoring +from dask.system import CPU_COUNT from distributed import Nanny, Worker from distributed.security import Security -from distributed.system import CPU_COUNT from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.comm import get_address_host_port from distributed.preloading import validate_preload_argv diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index fd1430baa21..8eb55c54997 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -5,8 +5,8 @@ import weakref from dask.utils import factors +from dask.system import CPU_COUNT -from .. 
import system from .spec import SpecCluster from ..nanny import Nanny from ..scheduler import Scheduler @@ -157,12 +157,12 @@ def __init__( n_workers, threads_per_worker = nprocesses_nthreads() else: n_workers = 1 - threads_per_worker = system.CPU_COUNT + threads_per_worker = CPU_COUNT if n_workers is None and threads_per_worker is not None: - n_workers = max(1, system.CPU_COUNT // threads_per_worker) + n_workers = max(1, CPU_COUNT // threads_per_worker) if n_workers and threads_per_worker is None: # Overcommit threads per worker, rather than undercommit - threads_per_worker = max(1, int(math.ceil(system.CPU_COUNT / n_workers))) + threads_per_worker = max(1, int(math.ceil(CPU_COUNT / n_workers))) if n_workers and "memory_limit" not in worker_kwargs: worker_kwargs["memory_limit"] = parse_memory_limit("auto", 1, n_workers) @@ -217,7 +217,7 @@ def start_worker(self, *args, **kwargs): ) -def nprocesses_nthreads(n=system.CPU_COUNT): +def nprocesses_nthreads(n=CPU_COUNT): """ The default breakdown of processes and threads for a given number of cores diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index e73ccd4721f..9b39d8f81f5 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -12,10 +12,11 @@ from tornado import gen import pytest +from dask.system import CPU_COUNT from distributed import Client, Worker, Nanny, get_client from distributed.deploy.local import LocalCluster, nprocesses_nthreads from distributed.metrics import time -from distributed.system import CPU_COUNT, MEMORY_LIMIT +from distributed.system import MEMORY_LIMIT from distributed.utils_test import ( # noqa: F401 clean, cleanup, diff --git a/distributed/nanny.py b/distributed/nanny.py index 58491da154c..11cf0157c10 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -11,6 +11,7 @@ import weakref import dask +from dask.system import CPU_COUNT from tornado import gen from tornado.ioloop import IOLoop, TimeoutError from tornado.locks import Event @@ -23,7 +24,6 @@ from .process import AsyncProcess from .proctitle import enable_proctitle_on_children from .security import Security -from .system import CPU_COUNT from .utils import ( get_ip, mp_context, diff --git a/distributed/system.py b/distributed/system.py index 291248ddded..2b032a34024 100644 --- a/distributed/system.py +++ b/distributed/system.py @@ -1,10 +1,8 @@ -import math -import os import sys import psutil -__all__ = ("memory_limit", "cpu_count", "MEMORY_LIMIT", "CPU_COUNT") +__all__ = ("memory_limit", "MEMORY_LIMIT") def memory_limit(): @@ -41,44 +39,4 @@ def memory_limit(): return limit -def cpu_count(): - """Get the available CPU count for this system. - - Takes the minimum value from the following locations: - - - Total system cpus available on the host. 
- - CPU Affinity (if set) - - Cgroups limit (if set) - """ - count = os.cpu_count() - - # Check CPU affinity if available - try: - affinity_count = len(psutil.Process().cpu_affinity()) - if affinity_count > 0: - count = min(count, affinity_count) - except Exception: - pass - - # Check cgroups if available - if sys.platform == "linux": - # The directory name isn't standardized across linux distros, check both - for dirname in ["cpuacct,cpu", "cpu,cpuacct"]: - try: - with open("/sys/fs/cgroup/%s/cpu.cfs_quota_us" % dirname) as f: - quota = int(f.read()) - with open("/sys/fs/cgroup/%s/cpu.cfs_period_us" % dirname) as f: - period = int(f.read()) - # We round up on fractional CPUs - cgroups_count = math.ceil(quota / period) - if cgroups_count > 0: - count = min(count, cgroups_count) - break - except Exception: - pass - - return count - - MEMORY_LIMIT = memory_limit() -CPU_COUNT = cpu_count() diff --git a/distributed/tests/test_system.py b/distributed/tests/test_system.py index d276613b520..3d44efe781d 100644 --- a/distributed/tests/test_system.py +++ b/distributed/tests/test_system.py @@ -1,57 +1,11 @@ import builtins import io -import os import sys import psutil import pytest -from distributed.system import cpu_count, memory_limit - - -def test_cpu_count(): - count = cpu_count() - assert isinstance(count, int) - assert count <= os.cpu_count() - assert count >= 1 - - -@pytest.mark.parametrize("dirname", ["cpuacct,cpu", "cpu,cpuacct", None]) -def test_cpu_count_cgroups(dirname, monkeypatch): - def mycpu_count(): - # Absurdly high, unlikely to match real value - return 250 - - monkeypatch.setattr(os, "cpu_count", mycpu_count) - - class MyProcess(object): - def cpu_affinity(self): - # No affinity set - return [] - - monkeypatch.setattr(psutil, "Process", MyProcess) - - if dirname: - paths = { - "/sys/fs/cgroup/%s/cpu.cfs_quota_us" % dirname: io.StringIO("2005"), - "/sys/fs/cgroup/%s/cpu.cfs_period_us" % dirname: io.StringIO("10"), - } - builtin_open = builtins.open - - def myopen(path, *args, **kwargs): - if path in paths: - return paths.get(path) - return builtin_open(path, *args, **kwargs) - - monkeypatch.setattr(builtins, "open", myopen) - monkeypatch.setattr(sys, "platform", "linux") - - count = cpu_count() - if dirname: - # Rounds up - assert count == 201 - else: - assert count == 250 +from distributed.system import memory_limit def test_memory_limit(): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 8b81e3afbe6..ff523342243 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -14,6 +14,7 @@ import dask from dask import delayed from dask.utils import format_bytes +from dask.system import CPU_COUNT import pytest from toolz import pluck, sliding_window, first import tornado @@ -29,7 +30,6 @@ get_worker, Reschedule, wait, - system, ) from distributed.compatibility import WINDOWS from distributed.core import rpc @@ -62,7 +62,7 @@ def test_worker_nthreads(): w = Worker("127.0.0.1", 8019) try: - assert w.executor._max_workers == system.CPU_COUNT + assert w.executor._max_workers == CPU_COUNT finally: shutil.rmtree(w.local_directory) @@ -516,7 +516,7 @@ def test_memory_limit_auto(): assert isinstance(a.memory_limit, Number) assert isinstance(b.memory_limit, Number) - if system.CPU_COUNT > 1: + if CPU_COUNT > 1: assert a.memory_limit < b.memory_limit assert c.memory_limit == d.memory_limit diff --git a/distributed/worker.py b/distributed/worker.py index c0836077a6c..3533d460299 100644 --- a/distributed/worker.py +++ 
b/distributed/worker.py @@ -18,6 +18,7 @@ from dask.core import istask from dask.compatibility import apply from dask.utils import format_bytes, funcname +from dask.system import CPU_COUNT try: from cytoolz import pluck, partial, merge, first, keymap @@ -462,7 +463,7 @@ def __init__( warnings.warn("the ncores= parameter has moved to nthreads=") nthreads = ncores - self.nthreads = nthreads or system.CPU_COUNT + self.nthreads = nthreads or CPU_COUNT self.total_resources = resources or {} self.available_resources = (resources or {}).copy() self.death_timeout = parse_timedelta(death_timeout) @@ -3071,7 +3072,7 @@ class Reschedule(Exception): pass -def parse_memory_limit(memory_limit, nthreads, total_cores=system.CPU_COUNT): +def parse_memory_limit(memory_limit, nthreads, total_cores=CPU_COUNT): if memory_limit is None: return None From fd98e30ee129f07367ce4828523f9fd9eb4f7f58 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 14 Nov 2019 10:56:23 -0600 Subject: [PATCH 0545/1550] Efficiently serialize zero strided NumPy arrays (#3180) --- distributed/protocol/numpy.py | 15 ++++++++++++ distributed/protocol/tests/test_numpy.py | 29 ++++++++++++++++++++++++ distributed/tests/test_sizeof.py | 23 +++++++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 distributed/tests/test_sizeof.py diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index c7e48e63b1a..c5061a8f802 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -44,6 +44,13 @@ def serialize_numpy_ndarray(x): else: dt = (0, x.dtype.str) + # Only serialize non-broadcasted data for arrays with zero strided axes + if 0 in x.strides: + broadcast_to = (x.shape, x.flags.writeable) + x = x[tuple(slice(None) if s != 0 else slice(1) for s in x.strides)] + else: + broadcast_to = None + if not x.shape: # 0d array strides = x.strides @@ -68,6 +75,9 @@ def serialize_numpy_ndarray(x): header = {"dtype": dt, "shape": x.shape, "strides": strides} + if broadcast_to is not None: + header["broadcast_to"] = broadcast_to + if x.nbytes > 1e5: frames = frame_split_size([data]) else: @@ -97,6 +107,11 @@ def deserialize_numpy_ndarray(header, frames): header["shape"], dtype=dt, buffer=frames[0], strides=header["strides"] ) + if header.get("broadcast_to"): + shape, writeable = header["broadcast_to"] + x = np.broadcast_to(x, shape) + x.setflags(write=writeable) + return x diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 4fb20d58631..6e4712272d8 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -71,6 +71,7 @@ def test_serialize(): np.zeros((1, 1000, 1000)), np.arange(12)[::2], # non-contiguous array np.ones(shape=(5, 6)).astype(dtype=[("total", " Date: Thu, 14 Nov 2019 08:57:31 -0800 Subject: [PATCH 0546/1550] Cache function deserialization in workers (#3234) This is particularly useful for numba.cuda.jit compiled functions where there is some accumulated state. 
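A standalone sketch of the caching pattern added below, using the standard
library pickle and an illustrative helper name only:

    import functools
    import pickle

    @functools.lru_cache(100)
    def cached_loads(blob):
        # bytes are hashable, so the pickled payload can serve as the cache key
        return pickle.loads(blob)

    blob = pickle.dumps(len)
    assert cached_loads(blob) is cached_loads(blob)  # second call hits the cache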
--- distributed/worker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 3533d460299..2a705320baa 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3143,11 +3143,18 @@ async def get_data_from_worker( job_counter = [0] +import functools + + +@functools.lru_cache(100) +def cached_function_deserialization(func): + return pickle.loads(func) + def _deserialize(function=None, args=None, kwargs=None, task=no_value): """ Deserialize task inputs and regularize to func, args, kwargs """ if function is not None: - function = pickle.loads(function) + function = cached_function_deserialization(function) if args: args = pickle.loads(args) if kwargs: From 27b30fa2965d510d3afddb04807106ecb37f8d8d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 14 Nov 2019 11:32:05 -0800 Subject: [PATCH 0547/1550] Respect ordering of futures in futures_of (#3236) Previously we used to drop ordering and use only a set. Now we maintain a list of when we first see each future. --- distributed/client.py | 28 ++++++++++++++++++++++++---- distributed/tests/test_client.py | 9 +++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 6f5bf786367..4f315e89b6d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4385,9 +4385,27 @@ def redict_collection(c, dsk): def futures_of(o, client=None): - """ Future objects in a collection """ + """ Future objects in a collection + + Parameters + ---------- + o: collection + A possibly nested collection of Dask objects + + Examples + -------- + >>> futures_of(my_dask_dataframe) + [, + ] + + Returns + ------- + futures : List[Future] + A list of futures held by those collections + """ stack = [o] - futures = set() + seen = set() + futures = list() while stack: x = stack.pop() if type(x) in (tuple, set, list): @@ -4397,7 +4415,9 @@ def futures_of(o, client=None): elif type(x) is SubgraphCallable: stack.extend(x.dsk.values()) elif isinstance(x, Future): - futures.add(x) + if x not in seen: + seen.add(x) + futures.append(x) elif dask.is_dask_collection(x): stack.extend(x.__dask_graph__().values()) @@ -4406,7 +4426,7 @@ def futures_of(o, client=None): if bad: raise CancelledError(bad) - return list(futures) + return futures[::-1] def fire_and_forget(obj): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 9f8e48298fa..73906517045 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5678,5 +5678,14 @@ def f(x): assert result == 101 +@gen_cluster(client=True) +async def test_futures_of_sorted(c, s, a, b): + pytest.importorskip("dask.dataframe") + df = await dask.datasets.timeseries(dtypes={"x": int}).persist() + futures = futures_of(df) + for k, f in zip(df.__dask_keys__(), futures): + assert str(k) in str(f) + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 From 81821ad057968c813b2d0f112c6cee735cafb5a4 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 14 Nov 2019 15:10:12 -0600 Subject: [PATCH 0548/1550] Bump dask dependency to 2.7.0 (#3237) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d1335d0b3b1..b17e4620be6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ click >= 6.6 cloudpickle >= 0.2.2 -dask >= 2.5.2 +dask >= 2.7.0 msgpack psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 From 
3670cc2ad9d1d885ddedeae3ac98cd1593d5a2d9 Mon Sep 17 00:00:00 2001 From: rockwellw Date: Thu, 14 Nov 2019 13:10:45 -0800 Subject: [PATCH 0549/1550] Avoid setting inf x_range (#3229) --- distributed/dashboard/components/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 77f095bce0d..cb8f6e6f20e 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -585,7 +585,7 @@ def update(self): or inf ) - if limit > max_limit: + if limit > max_limit and limit != inf: max_limit = limit if nb > limit: From 886189aee41a18a08b15b4f48a65da91620e3cb7 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 14 Nov 2019 13:11:48 -0800 Subject: [PATCH 0550/1550] Clear task stream based on recent behavior (#3200) --- distributed/dashboard/components/scheduler.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index cb8f6e6f20e..a53a0e744e8 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1,7 +1,7 @@ import logging import math from numbers import Number -from operator import add +import operator import os from bokeh.layouts import column, row @@ -873,6 +873,7 @@ def __init__(self, scheduler, n_rectangles=1000, clear_interval="20s", **kwargs) clear_interval = parse_timedelta(clear_interval, default="ms") self.clear_interval = clear_interval self.last = 0 + self.last_seen = 0 self.source, self.root = task_stream_figure(clear_interval, **kwargs) @@ -899,16 +900,31 @@ def update(self): if not rectangles["start"]: return - # If there has been a significant delay then clear old rectangles - first_end = min(map(add, rectangles["start"], rectangles["duration"])) - if first_end > self.last: - last = self.last - self.last = first_end - if first_end > last + self.clear_interval * 1000: - self.offset = min(rectangles["start"]) - self.source.data.update({k: [] for k in rectangles}) + # If it has been a while since we've updated the plot + if time() > self.last_seen + self.clear_interval: + new_start = min(rectangles["start"]) - self.offset + old_start = min(self.source.data["start"]) + old_end = max( + map( + operator.add, + self.source.data["start"], + self.source.data["duration"], + ) + ) + + density = ( + sum(self.source.data["duration"]) + / len(self.workers) + / (old_end - old_start) + ) + + # If whitespace is more than 3x the old width + if (new_start - old_end) > (old_end - old_start) * 2 or density < 0.05: + self.source.data.update({k: [] for k in rectangles}) # clear + self.offset = min(rectangles["start"]) # redefine offset rectangles["start"] = [x - self.offset for x in rectangles["start"]] + self.last_seen = time() # Convert to numpy for serialization speed if n >= 10 and np: @@ -1707,7 +1723,7 @@ def status_doc(scheduler, extra, doc): n_rectangles=dask.config.get( "distributed.scheduler.dashboard.status.task-stream-length" ), - clear_interval="10s", + clear_interval="5s", sizing_mode="stretch_both", ) task_stream.update() From f8aca16c5dc5910618f0a7e1acfb1f284fd49759 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 14 Nov 2019 14:23:37 -0800 Subject: [PATCH 0551/1550] Use the percentage field for profile plots (#3238) Previously we populated this data, but didn't use it correctly in the tooltip --- distributed/profile.py | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/profile.py b/distributed/profile.py index 274dfcd1d20..1bef6450974 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -230,7 +230,7 @@ def traverse(state, start, stop, height): x += width traverse(state, 0, 1, 0) - percentages = ["{:.2f}%".format(100 * w) for w in widths] + percentages = ["{:.1f}%".format(100 * w) for w in widths] return { "left": starts, "right": stops, @@ -423,7 +423,7 @@ def plot_figure(data, **kwargs):
          Percentage:  - @width + @percentage
          """, ) From 4d0d58aade4460fab6e7e85a3548353671036d2c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 14 Nov 2019 14:58:28 -0800 Subject: [PATCH 0552/1550] bump version to 2.8.0 --- docs/source/changelog.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 3cd8e3d287d..9187d0a1579 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,31 @@ Changelog ========= +2.8.0 - 2019-11-14 +------------------ + +- Add UCX config values (:pr:`3135`) `Matthew Rocklin`_ +- Relax test_MultiWorker (:pr:`3210`) `Matthew Rocklin`_ +- Avoid ucp.init at import time (:pr:`3211`) `Matthew Rocklin`_ +- Clean up rpc to avoid intermittent test failure (:pr:`3215`) `Matthew Rocklin`_ +- Respect protocol if given to Scheduler (:pr:`3212`) `Matthew Rocklin`_ +- Use legend_field= keyword in bokeh plots (:pr:`3218`) `Matthew Rocklin`_ +- Cache psutil.Process object in Nanny (:pr:`3207`) `Matthew Rocklin`_ +- Replace gen.sleep with asyncio.sleep (:pr:`3208`) `Matthew Rocklin`_ +- Avoid offloading serialization for small messages (:pr:`3224`) `Matthew Rocklin`_ +- Add desired_workers metric (:pr:`3221`) `Gabriel Sailer`_ +- Fail fast when importing distributed.comm.ucx (:pr:`3228`) `Matthew Rocklin`_ +- Add module name to Future repr (:pr:`3231`) `Matthew Rocklin`_ +- Add name to Pub/Sub repr (:pr:`3235`) `Matthew Rocklin`_ +- Import CPU_COUNT from dask.system (:pr:`3199`) `James Bourbeau`_ +- Efficiently serialize zero strided NumPy arrays (:pr:`3180`) `James Bourbeau`_ +- Cache function deserialization in workers (:pr:`3234`) `Matthew Rocklin`_ +- Respect ordering of futures in futures_of (:pr:`3236`) `Matthew Rocklin`_ +- Bump dask dependency to 2.7.0 (:pr:`3237`) `James Bourbeau`_ +- Avoid setting inf x_range (:pr:`3229`) `rockwellw`_ +- Clear task stream based on recent behavior (:pr:`3200`) `Matthew Rocklin`_ +- Use the percentage field for profile plots (:pr:`3238`) `Matthew Rocklin`_ + 2.7.0 - 2019-11-08 ------------------ @@ -1365,3 +1390,4 @@ significantly without many new features. .. _`James A. Bednar`: https://github.com/jbednar .. _`IPetrik`: https://github.com/IPetrik .. _`Simon Boothroyd`: https://github.com/SimonBoothroyd +.. _`rockwellw`: https://github.com/rockwellw From 8eeb603e7518e692fa2aec940ca59572ed00e24a Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 15 Nov 2019 11:08:18 -0600 Subject: [PATCH 0553/1550] Use inspect.isawaitable where relevant (#3241) Noticed a few places in distributed where `hasattr(x, '__await__')` was checked. This should use `inspect.isawaitable` instead for readability/correctness. Updated accordingly. 
--- distributed/core.py | 3 ++- distributed/deploy/adaptive.py | 3 ++- distributed/scheduler.py | 3 ++- distributed/worker.py | 9 +++++---- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index d096e0b2274..3e1a3b47cc3 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -2,6 +2,7 @@ from collections import defaultdict, deque from concurrent.futures import CancelledError from functools import partial +from inspect import isawaitable import logging import threading import traceback @@ -397,7 +398,7 @@ async def handle_comm(self, comm, shutting_down=shutting_down): logger.debug("Calling into handler %s", handler.__name__) try: result = handler(comm, **msg) - if hasattr(result, "__await__"): + if isawaitable(result): result = asyncio.ensure_future(result) self._ongoing_coroutines.add(result) result = await result diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index f173e36a396..0d295200018 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -1,3 +1,4 @@ +from inspect import isawaitable import logging import math @@ -158,7 +159,7 @@ async def scale_down(self, workers): # close workers more forcefully logger.info("Retiring workers %s", workers) f = self.cluster.scale_down(workers) - if hasattr(f, "__await__"): + if isawaitable(f): await f async def scale_up(self, n): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6724524344d..5cb2267759c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, Set from datetime import timedelta from functools import partial +from inspect import isawaitable import itertools import json import logging @@ -3283,7 +3284,7 @@ async def feed( if teardown: teardown = pickle.loads(teardown) state = setup(self) if setup else None - if hasattr(state, "__await__"): + if isawaitable(state): state = await state try: while self.status == "running": diff --git a/distributed/worker.py b/distributed/worker.py index 2a705320baa..e3ef6b260fa 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -4,6 +4,7 @@ from collections.abc import MutableMapping from datetime import timedelta import heapq +from inspect import isawaitable import logging import os from pickle import PicklingError @@ -748,7 +749,7 @@ async def get_metrics(self): for k, metric in self.metrics.items(): try: result = metric(self) - if hasattr(result, "__await__"): + if isawaitable(result): result = await result custom[k] = result except Exception: # TODO: log error once @@ -761,7 +762,7 @@ async def get_startup_information(self): for k, f in self.startup_information.items(): try: v = f(self) - if hasattr(v, "__await__"): + if isawaitable(v): v = await v result[k] = v except Exception: # TODO: log error once @@ -1057,7 +1058,7 @@ async def close( if hasattr(plugin, "teardown") ] - await asyncio.gather(*[td for td in teardowns if hasattr(td, "__await__")]) + await asyncio.gather(*[td for td in teardowns if isawaitable(td)]) for pc in self.periodic_callbacks.values(): pc.stop() @@ -2301,7 +2302,7 @@ async def plugin_add(self, comm=None, plugin=None, name=None): if hasattr(plugin, "setup"): try: result = plugin.setup(worker=self) - if hasattr(result, "__await__"): + if isawaitable(result): result = await result except Exception as e: msg = error_message(e) From 1a46657af40106fc1fdcc1f2ec6b7f48f5e98566 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 15 Nov 2019 13:54:23 -0600 Subject: 
[PATCH 0554/1550] Remove `gen.coroutine` usage in scheduler (#3242) Use `async`/`await` and `asyncio` idioms throughout. --- distributed/scheduler.py | 116 +++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5cb2267759c..e013bbee9a9 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -25,7 +25,6 @@ from toolz import frequencies, merge, pluck, merge_sorted, first from toolz import valmap, second, compose, groupby from tornado import gen -from tornado.gen import Return from tornado.ioloop import IOLoop import dask @@ -2008,8 +2007,8 @@ def cancel_key(self, key, client, retries=5, force=False): return if ts is None or not ts.who_wants: # no key yet, lets try again in a moment if retries: - self.loop.add_future( - gen.sleep(0.2), lambda _: self.cancel_key(key, client, retries - 1) + self.loop.call_later( + 0.2, lambda: self.cancel_key(key, client, retries - 1) ) return if force or ts.who_wants == {cs}: # no one else wants this key @@ -2701,8 +2700,7 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None): ) return d[worker] - @gen.coroutine - def rebalance(self, comm=None, keys=None, workers=None): + async def rebalance(self, comm=None, keys=None, workers=None): """ Rebalance keys so that each worker stores roughly equal bytes **Policy** @@ -2778,9 +2776,9 @@ def rebalance(self, comm=None, keys=None, workers=None): to_recipients[recipient.address][ts.key].append(sender.address) to_senders[sender.address].append(ts.key) - result = yield { - r: self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items() - } + result = await asyncio.gather( + *(self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items()) + ) for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) @@ -2795,13 +2793,11 @@ def rebalance(self, comm=None, keys=None, workers=None): }, ) - if not all(r["status"] == "OK" for r in result.values()): - raise Return( - { - "status": "missing-data", - "keys": sum([r["keys"] for r in result if "keys" in r], []), - } - ) + if not all(r["status"] == "OK" for r in result): + return { + "status": "missing-data", + "keys": sum([r["keys"] for r in result if "keys" in r], []), + } for sender, recipient, ts in msgs: assert ts.state == "memory" @@ -2812,20 +2808,21 @@ def rebalance(self, comm=None, keys=None, workers=None): ("rebalance", ts.key, time(), sender.address, recipient.address) ) - result = yield { - r: self.rpc(addr=r).delete_data(keys=v, report=False) - for r, v in to_senders.items() - } + await asyncio.gather( + *( + self.rpc(addr=r).delete_data(keys=v, report=False) + for r, v in to_senders.items() + ) + ) for sender, recipient, ts in msgs: ts.who_has.remove(sender) sender.has_what.remove(ts) sender.nbytes -= ts.get_nbytes() - raise Return({"status": "OK"}) + return {"status": "OK"} - @gen.coroutine - def replicate( + async def replicate( self, comm=None, keys=None, @@ -2868,7 +2865,7 @@ def replicate( tasks = {self.tasks[k] for k in keys} missing_data = [ts.key for ts in tasks if not ts.who_has] if missing_data: - raise Return({"status": "missing-data", "keys": missing_data}) + return {"status": "missing-data", "keys": missing_data} # Delete extraneous data if delete: @@ -2879,12 +2876,14 @@ def replicate( for ws in random.sample(del_candidates, len(del_candidates) - n): del_worker_tasks[ws].add(ts) - yield [ - self.rpc(addr=ws.address).delete_data( - keys=[ts.key for ts in tasks], 
report=False + await asyncio.gather( + *( + self.rpc(addr=ws.address).delete_data( + keys=[ts.key for ts in tasks], report=False + ) + for ws, tasks in del_worker_tasks.items() ) - for ws, tasks in del_worker_tasks.items() - ] + ) for ws, tasks in del_worker_tasks.items(): ws.has_what -= tasks @@ -2912,11 +2911,13 @@ def replicate( for ws in random.sample(workers - ts.who_has, count): gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has] - results = yield { - w: self.rpc(addr=w).gather(who_has=who_has) - for w, who_has in gathers.items() - } - for w, v in results.items(): + results = await asyncio.gather( + *( + self.rpc(addr=w).gather(who_has=who_has) + for w, who_has in gathers.items() + ) + ) + for w, v in zip(gathers, results): if v["status"] == "OK": self.add_keys(worker=w, keys=list(gathers[w])) else: @@ -3349,8 +3350,7 @@ def get_ncores(self, comm=None, workers=None): else: return {w: ws.nthreads for w, ws in self.workers.items()} - @gen.coroutine - def get_call_stack(self, comm=None, keys=None): + async def get_call_stack(self, comm=None, keys=None): if keys is not None: stack = list(keys) processing = set() @@ -3370,14 +3370,13 @@ def get_call_stack(self, comm=None, keys=None): workers = {w: None for w in self.workers} if not workers: - raise gen.Return({}) + return {} - else: - response = yield { - w: self.rpc(w).call_stack(keys=v) for w, v in workers.items() - } - response = {k: v for k, v in response.items() if v} - raise gen.Return(response) + results = await asyncio.gather( + *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) + ) + response = {w: r for w, r in zip(workers, results) if r} + return response def get_nbytes(self, comm=None, keys=None, summary=True): with log_errors(): @@ -4614,8 +4613,7 @@ def worker_objective(self, ts, ws): else: return (start_time, ws.nbytes) - @gen.coroutine - def get_profile( + async def get_profile( self, comm=None, workers=None, @@ -4628,15 +4626,17 @@ def get_profile( workers = self.workers else: workers = set(self.workers) & set(workers) - result = yield { - w: self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers - } + results = await asyncio.gather( + *(self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers) + ) + if merge_workers: - result = profile.merge(*result.values()) - raise gen.Return(result) + response = profile.merge(*results) + else: + response = dict(zip(workers, results)) + return response - @gen.coroutine - def get_profile_metadata( + async def get_profile_metadata( self, comm=None, workers=None, @@ -4654,22 +4654,22 @@ def get_profile_metadata( workers = self.workers else: workers = set(self.workers) & set(workers) - result = yield { - w: self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers - } + results = await asyncio.gather( + *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers) + ) - counts = [v["counts"] for v in result.values()] + counts = [v["counts"] for v in results] counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt) counts = [(time, sum(pluck(1, group))) for time, group in counts] keys = set() - for v in result.values(): + for v in results: for t, d in v["keys"]: for k in d: keys.add(k) keys = {k: [] for k in keys} - groups1 = [v["keys"] for v in result.values()] + groups1 = [v["keys"] for v in results] groups2 = list(merge_sorted(*groups1, key=first)) last = 0 @@ -4682,7 +4682,7 @@ def get_profile_metadata( for k, v in d.items(): keys[k][-1][1] += v - raise gen.Return({"counts": counts, "keys": 
keys}) + return {"counts": counts, "keys": keys} async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False): results = await self.broadcast( From 93d631ad84ee8517865a2c761e97ed84c4272ec8 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 15 Nov 2019 22:57:46 +0100 Subject: [PATCH 0555/1550] Fixed cupy array going out of scope (#3240) --- distributed/protocol/cupy.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index d15c719359c..26a5accc6af 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -6,14 +6,34 @@ class PatchedCudaArrayInterface(object): - # TODO: This class wont be necessary - # once Cupy<7.0 is no longer supported + """This class do two things: + 1) Makes sure that __cuda_array_interface__['strides'] + behaves as specified in the protocol. + 2) Makes sure that the cuda context is active + when deallocating the base cuda array. + Notice, this is only needed when the array to deserialize + isn't a native cupy array. + """ + def __init__(self, ary): cai = ary.__cuda_array_interface__ cai_cupy_vsn = cupy.ndarray(0).__cuda_array_interface__["version"] if cai.get("strides") is None and cai_cupy_vsn < 2: cai.pop("strides", None) self.__cuda_array_interface__ = cai + # Save a ref to ary so it won't go out of scope + self.base = ary + + def __del__(self): + # Making sure that the cuda context is active + # when deallocating the base cuda array + try: + import numba.cuda + + numba.cuda.current_context() + except ImportError: + pass + del self.base @cuda_serialize.register(cupy.ndarray) From 029ed174dce249ea1493fe60f1427827322aa2f5 Mon Sep 17 00:00:00 2001 From: He Jia Date: Sat, 16 Nov 2019 12:31:41 +0800 Subject: [PATCH 0556/1550] Fixed typos in pubsub.py (#3244) --- distributed/pubsub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 5ed631c46ec..0a5d82897fd 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -74,7 +74,7 @@ def remove_publisher(self, comm=None, name=None, worker=None): def remove_subscriber(self, comm=None, name=None, worker=None, client=None): if worker: - logger.debug("Add worker subscriber: %s %s", name, worker) + logger.debug("Remove worker subscriber: %s %s", name, worker) self.subscribers[name].remove(worker) for pub in self.publishers[name]: self.scheduler.worker_send( @@ -82,7 +82,7 @@ def remove_subscriber(self, comm=None, name=None, worker=None, client=None): {"op": "pubsub-remove-subscriber", "address": worker, "name": name}, ) elif client: - logger.debug("Add client subscriber: %s %s", name, client) + logger.debug("Remove client subscriber: %s %s", name, client) self.client_subscribers[name].remove(client) if not self.client_subscribers[name]: del self.client_subscribers[name] From 25976b4b8b592388f8b7503dcc4151157553ebaf Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Mon, 18 Nov 2019 21:44:38 -0700 Subject: [PATCH 0557/1550] docs: fix array.shape() -> array.shape (#3247) --- docs/source/efficiency.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/efficiency.rst b/docs/source/efficiency.rst index 39ca16d93d5..94a603ea9a3 100644 --- a/docs/source/efficiency.rst +++ b/docs/source/efficiency.rst @@ -31,7 +31,7 @@ shape we might choose one of the following options: .. 
code-block:: python - >>> x.result().shape() # Slow from lots of data transfer + >>> x.result().shape # Slow from lots of data transfer (1000, 1000) **Fast** From 35551998d7350cd5ae6a5c24970d8437fd8d521d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 19 Nov 2019 09:24:32 -0700 Subject: [PATCH 0558/1550] Add new dashboard plot for memory use by key (#3243) --- distributed/dashboard/components/scheduler.py | 88 ++++++++++++++++++- distributed/dashboard/scheduler.py | 2 + .../dashboard/tests/test_scheduler_bokeh.py | 19 ++++ 3 files changed, 108 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index a53a0e744e8..7f226bf7ddc 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1,3 +1,4 @@ +from collections import defaultdict import logging import math from numbers import Number @@ -36,7 +37,7 @@ from bokeh.transform import factor_cmap, linear_cmap from bokeh.io import curdoc import dask -from dask.utils import format_bytes +from dask.utils import format_bytes, key_split from toolz import pipe from tornado import escape @@ -428,6 +429,82 @@ def update(self): update(self.source, result) +class MemoryByKey(DashboardComponent): + """ Bar chart showing memory use by key prefix""" + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.last = 0 + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "name": ["a", "b"], + "nbytes": [100, 1000], + "count": [1, 2], + "color": ["blue", "blue"], + } + ) + + fig = figure( + title="Memory Use", + tools="", + id="bk-memory-by-key-plot", + name="memory_by_key", + x_range=["a", "b"], + **kwargs, + ) + rect = fig.vbar( + source=self.source, x="name", top="nbytes", width=0.9, color="color" + ) + fig.yaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + fig.xaxis.major_label_orientation = -math.pi / 12 + rect.nonselection_glyph = None + + fig.xaxis.minor_tick_line_alpha = 0 + fig.ygrid.visible = False + + fig.toolbar.logo = None + fig.toolbar_location = None + + hover = HoverTool() + hover.tooltips = "@name: @nbytes_text" + hover.tooltips = """ +
                <div>
+                    <p><b>Name:</b> @name</p>
+                    <p><b>Bytes:</b> @nbytes_text</p>
+                    <p><b>Count:</b> @count objects</p>
+                </div>
          + """ + hover.point_policy = "follow_mouse" + fig.add_tools(hover) + + self.fig = fig + + @without_property_validation + def update(self): + with log_errors(): + counts = defaultdict(int) + nbytes = defaultdict(int) + for ws in self.scheduler.workers.values(): + for ts in ws.has_what: + ks = key_split(ts.key) + counts[ks] += 1 + nbytes[ks] += ts.nbytes + + names = list(sorted(counts)) + self.fig.x_range.factors = names + result = { + "name": names, + "count": [counts[name] for name in names], + "nbytes": [nbytes[name] for name in names], + "nbytes_text": [format_bytes(nbytes[name]) for name in names], + "color": [color_of(name) for name in names], + } + self.fig.title.text = "Total Use: " + format_bytes(sum(nbytes.values())) + + update(self.source, result) + + class CurrentLoad(DashboardComponent): """ How many tasks are on each worker """ @@ -1865,6 +1942,15 @@ def individual_bandwidth_workers_doc(scheduler, extra, doc): doc.theme = BOKEH_THEME +def individual_memory_by_key_doc(scheduler, extra, doc): + with log_errors(): + component = MemoryByKey(scheduler, sizing_mode="stretch_both") + component.update() + add_periodic_callback(doc, component, 500) + doc.add_root(component.fig) + doc.theme = BOKEH_THEME + + def profile_doc(scheduler, extra, doc): with log_errors(): doc.title = "Dask: Profile" diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index ecc413b5f0d..67d08c50bd0 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -36,6 +36,7 @@ individual_workers_doc, individual_bandwidth_types_doc, individual_bandwidth_workers_doc, + individual_memory_by_key_doc, ) from .core import BokehServer from .worker import counters_doc @@ -408,6 +409,7 @@ def listen(self, *args, **kwargs): "/individual-workers": individual_workers_doc, "/individual-bandwidth-types": individual_bandwidth_types_doc, "/individual-bandwidth-workers": individual_bandwidth_workers_doc, + "/individual-memory-by-key": individual_memory_by_key_doc, } try: diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 875f1064503..14d055baa7f 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -11,6 +11,7 @@ from tornado import gen from tornado.httpclient import AsyncHTTPClient, HTTPRequest +import dask from dask.core import flatten from distributed.utils import tokey, format_dashboard_link from distributed.client import wait @@ -34,6 +35,7 @@ WorkerTable, TaskGraph, ProfileServer, + MemoryByKey, ) from distributed.dashboard import scheduler @@ -690,3 +692,20 @@ def test_https_support(c, s, a, b): body = response.body.decode() assert "bokeh" in body.lower() assert not re.search("href=./", body) # no absolute links + + +@gen_cluster( + client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} +) +async def test_memory_by_key(c, s, a, b): + mbk = MemoryByKey(s) + + da = pytest.importorskip("dask.array") + x = (da.random.random((20, 20), chunks=(10, 10)) + 1).persist(optimize_graph=False) + await x + + y = await dask.delayed(inc)(1).persist() + + mbk.update() + assert mbk.source.data["name"] == ["add", "inc"] + assert mbk.source.data["nbytes"] == [x.nbytes, sys.getsizeof(1)] From be4e9661edc01dc3040871803da8b671c2188066 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 21 Nov 2019 15:04:15 +0100 Subject: [PATCH 0559/1550] Skip numba.cuda tests if CUDA is not available (#3255) 
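For reference, the same guard can be written once per test module instead of inside each test. The snippet below is only an illustration and is not part of this patch; it assumes pytest and numba are installed and uses the standard pytest.importorskip and skipif helpers.

    import pytest

    cuda = pytest.importorskip("numba.cuda")

    # Skip every test in this module when no CUDA device is present
    pytestmark = pytest.mark.skipif(
        not cuda.is_available(), reason="CUDA is not available"
    )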
--- distributed/protocol/tests/test_numba.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/protocol/tests/test_numba.py b/distributed/protocol/tests/test_numba.py index 794db58b3c9..69ea73310d4 100644 --- a/distributed/protocol/tests/test_numba.py +++ b/distributed/protocol/tests/test_numba.py @@ -7,6 +7,9 @@ @pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) def test_serialize_cupy(dtype): + if not cuda.is_available(): + pytest.skip("CUDA is not available") + ary = np.arange(100, dtype=dtype) x = cuda.to_device(ary) header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) From 5b33d54cd9afd174f827bcd139a581aa215f95a7 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 22 Nov 2019 06:13:07 -0600 Subject: [PATCH 0560/1550] Fix NumPy writeable serialization bug (#3253) * Add failing test * Pass broadcasted shape to np.ndarray --- distributed/protocol/numpy.py | 14 ++++++-------- distributed/protocol/tests/test_numpy.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index c5061a8f802..9a1f493c333 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -46,7 +46,7 @@ def serialize_numpy_ndarray(x): # Only serialize non-broadcasted data for arrays with zero strided axes if 0 in x.strides: - broadcast_to = (x.shape, x.flags.writeable) + broadcast_to = x.shape x = x[tuple(slice(None) if s != 0 else slice(1) for s in x.strides)] else: broadcast_to = None @@ -103,14 +103,12 @@ def deserialize_numpy_ndarray(header, frames): else: dt = np.dtype(dt) - x = np.ndarray( - header["shape"], dtype=dt, buffer=frames[0], strides=header["strides"] - ) - if header.get("broadcast_to"): - shape, writeable = header["broadcast_to"] - x = np.broadcast_to(x, shape) - x.setflags(write=writeable) + shape = header["broadcast_to"] + else: + shape = header["shape"] + + x = np.ndarray(shape, dtype=dt, buffer=frames[0], strides=header["strides"]) return x diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 6e4712272d8..432b749e27e 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -288,3 +288,22 @@ def test_non_zero_strided_array(): header, frames = serialize(x) assert "broadcast_to" not in header assert sum(map(nbytes, frames)) == x.nbytes + + +def test_serialize_writeable_array_readonly_base_object(): + # Regression test for https://github.com/dask/distributed/issues/3252 + + x = np.arange(3) + # Create array which doesn't own it's own memory + y = np.broadcast_to(x, (3, 3)) + + # Make y writeable and it's base object (x) read-only + y.setflags(write=True) + x.setflags(write=False) + + # Serialize / deserialize y + z = deserialize(*serialize(y)) + np.testing.assert_equal(z, y) + + # Ensure z and y have the same flags (including WRITEABLE) + assert z.flags == y.flags From 42a51ce086e9d8badbe68e139db471ce7b28befa Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 22 Nov 2019 15:41:15 -0600 Subject: [PATCH 0561/1550] Fix hanging worker when the scheduler leaves (#3250) Closes https://github.com/dask/distributed/issues/2880 --- distributed/tests/test_nanny.py | 15 +++++++++++++++ distributed/worker.py | 6 ++++++ 2 files changed, 21 insertions(+) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index fccfd2efde6..cacd98477e0 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -1,3 
+1,4 @@ +import asyncio import gc import logging import os @@ -129,6 +130,20 @@ def test_run(s): yield n.close() +@pytest.mark.slow +@gen_cluster(config={"distributed.comm.timeouts.connect": "1s"}) +async def test_no_hang_when_scheduler_closes(s, a, b): + # https://github.com/dask/distributed/issues/2880 + with captured_logger("tornado.application", logging.ERROR) as logger: + await s.close() + await asyncio.sleep(1.2) + assert a.status == "closed" + assert b.status == "closed" + + out = logger.getvalue() + assert "Timed out trying to connect" not in out + + @pytest.mark.slow @gen_cluster( Worker=Nanny, nthreads=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False} diff --git a/distributed/worker.py b/distributed/worker.py index e3ef6b260fa..e5666aca561 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -882,6 +882,12 @@ async def heartbeat(self): ) self.bandwidth_workers.clear() self.bandwidth_types.clear() + except IOError as e: + # Scheduler is gone. Respect distributed.comm.timeouts.connect + if "Timed out trying to connect" in str(e): + await self.close(report=False) + else: + raise e except CommClosedError: logger.warning("Heartbeat to scheduler failed") finally: From 507659d79434845e50d48c247ff42d5efd336686 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 22 Nov 2019 22:46:55 -0600 Subject: [PATCH 0562/1550] bump version to 2.8.1 --- docs/source/changelog.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 9187d0a1579..21a57806533 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,20 @@ Changelog ========= +2.8.1 - 2019-11-22 +------------------ + +- Fix hanging worker when the scheduler leaves (:pr:`3250`) `Tom Augspurger`_ +- Fix NumPy writeable serialization bug (:pr:`3253`) `James Bourbeau`_ +- Skip ``numba.cuda`` tests if CUDA is not available (:pr:`3255`) `Peter Andreas Entschev`_ +- Add new dashboard plot for memory use by key (:pr:`3243`) `Matthew Rocklin`_ +- Fix ``array.shape()`` -> ``array.shape`` (:pr:`3247`) `Jed Brown`_ +- Fixed typos in ``pubsub.py`` (:pr:`3244`) `He Jia`_ +- Fixed cupy array going out of scope (:pr:`3240`) `Mads R. B. Kristensen`_ +- Remove ``gen.coroutine`` usage in scheduler (:pr:`3242`) `Jim Crist-Harif`_ +- Use ``inspect.isawaitable`` where relevant (:pr:`3241`) `Jim Crist-Harif`_ + + 2.8.0 - 2019-11-14 ------------------ @@ -1391,3 +1405,6 @@ significantly without many new features. .. _`IPetrik`: https://github.com/IPetrik .. _`Simon Boothroyd`: https://github.com/SimonBoothroyd .. _`rockwellw`: https://github.com/rockwellw +.. _`Jed Brown`: https://github.com/jedbrown +.. _`He Jia`: https://github.com/HerculesJack +.. 
_`Jim Crist-Harif`: https://github.com/jcrist From bacfa51bf95dddcb75beaa62141e487b52e604cf Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 23 Nov 2019 14:57:22 -0700 Subject: [PATCH 0563/1550] Add validate options to configuraation (#3258) --- distributed/distributed.yaml | 2 ++ distributed/scheduler.py | 4 +++- distributed/worker.py | 4 +++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 92b7c15e157..59bf6f8dc1c 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -24,6 +24,7 @@ distributed: preload: [] preload-argv: [] default-task-durations: {} # How long we expect function names to run ("1h", "1s") (helps for long tasks) + validate: False # Check scheduler state at every step for debugging dashboard: status: task-stream-length: 1000 @@ -44,6 +45,7 @@ distributed: preload: [] preload-argv: [] daemon: True + validate: False # Check worker state at every step for debugging lifetime: duration: null # Time after which to gracefully shutdown the worker stagger: 0 seconds # Random amount by which to stagger lifetimes diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e013bbee9a9..94bb5ea756f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -840,7 +840,7 @@ def __init__( service_kwargs=None, allowed_failures=None, extensions=None, - validate=False, + validate=None, scheduler_file=None, security=None, worker_ttl=None, @@ -860,6 +860,8 @@ def __init__( if allowed_failures is None: allowed_failures = dask.config.get("distributed.scheduler.allowed-failures") self.allowed_failures = allowed_failures + if validate is None: + validate = dask.config.get("distributed.scheduler.validate") self.validate = validate self.status = None self.proc = psutil.Process() diff --git a/distributed/worker.py b/distributed/worker.py index e5666aca561..76ee7a723d5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -320,7 +320,7 @@ def __init__( nanny=None, plugins=(), low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), - validate=False, + validate=None, profile_cycle_interval=None, lifetime=None, lifetime_stagger=None, @@ -386,6 +386,8 @@ def __init__( self.target_message_size = 50e6 # 50 MB self.log = deque(maxlen=100000) + if validate is None: + validate = dask.config.get("distributed.scheduler.validate") self.validate = validate self._transitions = { From d2b43fbfed1fa0385c054987a170b09f32af4b9a Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Mon, 25 Nov 2019 07:17:48 -0500 Subject: [PATCH 0564/1550] Fix dev requirements for pytest. (#3264) The minimum version specified in `setup.cfg` is newer than `dev-requirements.txt`. --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 3c4cf7954a3..cd79b3e4317 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,6 +7,6 @@ pyzmq >= 16.0.2 ipython >= 5.0.0 jupyter_client >= 4.4.0 ipykernel >= 4.5.2 -pytest >= 3.0.5 +pytest >= 3.2 prometheus_client >= 0.6.0 jupyter-server-proxy >= 1.1.0 From a285267d3c4042cc3aada675d6dabb793e43c500 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Mon, 25 Nov 2019 15:55:53 -0500 Subject: [PATCH 0565/1550] Use `DeviceBuffer` from newer RMM releases (#3261) If a newer version of RMM is around, use `DeviceBuffer` instead. Otherwise fallback to `device_array`. This is significantly faster to allocate. 
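The selection logic amounts to the sketch below. This is a standalone illustration rather than the module code itself; it assumes the optional rmm and numba dependencies referenced in the diff, and the helper name is made up.

    import numpy as np

    def choose_cuda_allocator():
        try:
            import rmm

            if hasattr(rmm, "DeviceBuffer"):  # rmm >= 0.12.0
                return lambda n: rmm.DeviceBuffer(size=n)
            return lambda n: rmm.device_array(n, dtype=np.uint8)  # pre-0.12.0
        except ImportError:
            import numba.cuda

            return lambda n: numba.cuda.device_array((n,), dtype=np.uint8)

Preferring the raw DeviceBuffer avoids building a full device ndarray just to receive bytes, which is likely where the allocation savings come from.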
--- distributed/comm/ucx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 58b16eaaf7f..2bdcff5e958 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -46,7 +46,10 @@ def init_once(): try: import rmm - cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) + if hasattr(rmm, "DeviceBuffer"): + cuda_array = lambda n: rmm.DeviceBuffer(size=n) + else: # pre-0.12.0 + cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) except ImportError: try: import numba.cuda From 1d9aaac6c67e6a6e7b49d3d1ec3216a4994482f1 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 26 Nov 2019 16:27:34 +0100 Subject: [PATCH 0566/1550] Robust gather in case of connection failures (#3246) --- distributed/comm/core.py | 2 +- distributed/scheduler.py | 14 ++- distributed/tests/test_scheduler.py | 145 ++++++++++++++++++++++++++- distributed/tests/test_utils_comm.py | 43 +++++++- distributed/worker.py | 59 +++++++---- 5 files changed, 232 insertions(+), 31 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 39a8b123cd3..0befb36d712 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -59,7 +59,7 @@ def read(self, deserializers=None): """ @abstractmethod - def write(self, msg, on_error=None): + def write(self, msg, serializers=None, on_error=None): """ Write a message (a Python object). diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 94bb5ea756f..26b4cf1c67e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2547,7 +2547,7 @@ async def gather(self, comm=None, keys=None, serializers=None): (self.tasks[key].state if key in self.tasks else None) for key in missing_keys ] - logger.debug( + logger.exception( "Couldn't gather keys %s state: %s workers: %s", missing_keys, missing_states, @@ -2555,17 +2555,21 @@ async def gather(self, comm=None, keys=None, serializers=None): ) result = {"status": "error", "keys": missing_keys} with log_errors(): + # Remove suspicious workers from the scheduler but allow them to + # reconnect. 
for worker in missing_workers: - self.remove_worker(address=worker) # this is extreme + self.remove_worker(address=worker, close=False) for key, workers in missing_keys.items(): - if not workers: - continue - ts = self.tasks[key] + # Task may already be gone if it was held by a + # `missing_worker` + ts = self.tasks.get(key) logger.exception( "Workers don't have promised key: %s, %s", str(workers), str(key), ) + if not workers or ts is None: + continue for worker in workers: ws = self.workers.get(worker) if ws is not None and ts in ws.has_what: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f57dbfb9e07..4e0e9a8710c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -17,7 +17,8 @@ import pytest from distributed import Nanny, Worker, Client, wait, fire_and_forget -from distributed.core import connect, rpc +from distributed.comm import Comm +from distributed.core import connect, rpc, ConnectionPool from distributed.scheduler import Scheduler, TaskState from distributed.client import wait from distributed.metrics import time @@ -1704,3 +1705,145 @@ async def test_no_danglng_asyncio_tasks(cleanup): tasks = asyncio.all_tasks() assert tasks == start + + +class BrokenComm(Comm): + peer_address = None + local_address = None + + def close(self): + pass + + def closed(self): + pass + + def abort(self): + pass + + def read(self, deserializers=None): + raise EnvironmentError + + def write(self, msg, serializers=None, on_error=None): + raise EnvironmentError + + +class FlakyConnectionPool(ConnectionPool): + def __init__(self, *args, failing_connections=0, **kwargs): + self.cnn_count = 0 + self.failing_connections = failing_connections + super(FlakyConnectionPool, self).__init__(*args, **kwargs) + + async def connect(self, *args, **kwargs): + self.cnn_count += 1 + if self.cnn_count > self.failing_connections: + return await super(FlakyConnectionPool, self).connect(*args, **kwargs) + else: + return BrokenComm() + + +@gen_cluster(client=True) +async def test_gather_failing_cnn_recover(c, s, a, b): + orig_rpc = s.rpc + x = await c.scatter({"x": 1}, workers=a.address) + + s.rpc = FlakyConnectionPool(failing_connections=1) + res = await s.gather(keys=["x"]) + assert res["status"] == "OK" + + +@gen_cluster(client=True) +async def test_gather_failing_cnn_error(c, s, a, b): + orig_rpc = s.rpc + x = await c.scatter({"x": 1}, workers=a.address) + + s.rpc = FlakyConnectionPool(failing_connections=10) + res = await s.gather(keys=["x"]) + assert res["status"] == "error" + assert list(res["keys"]) == ["x"] + + +@gen_cluster(client=True) +async def test_gather_no_workers(c, s, a, b): + await asyncio.sleep(1) + x = await c.scatter({"x": 1}, workers=a.address) + + await a.close() + await b.close() + + res = await s.gather(keys=["x"]) + assert res["status"] == "error" + assert list(res["keys"]) == ["x"] + + +@gen_cluster(client=True, client_kwargs={"direct_to_workers": False}) +async def test_gather_allow_worker_reconnect(c, s, a, b): + """ + Test that client resubmissions allow failed workers to reconnect and re-use + their results. Failure scenario would be a connection issue during result + gathering. + Upon connection failure, the worker is flagged as suspicious and removed + from the scheduler. If the worker is healthy and reconnencts we want to use + its results instead of recomputing them. 
+ """ + # GH3246 + ALREADY_CALCULATED = [] + + import time + + def inc_slow(x): + # Once the graph below is rescheduled this computation runs again. We + # need to sleep for at least 0.5 seconds to give the worker a chance to + # reconnect (Heartbeat timing) + if x in ALREADY_CALCULATED: + time.sleep(0.5) + ALREADY_CALCULATED.append(x) + return x + 1 + + x = c.submit(inc_slow, 1) + y = c.submit(inc_slow, 2) + + def reducer(x, y): + return x + y + + z = c.submit(reducer, x, y) + + s.rpc = FlakyConnectionPool(failing_connections=4) + + with captured_logger(logging.getLogger("distributed.scheduler")) as sched_logger: + with captured_logger(logging.getLogger("distributed.client")) as client_logger: + with captured_logger( + logging.getLogger("distributed.worker") + ) as worker_logger: + # Gather using the client (as an ordinary user would) + # Upon a missing key, the client will reschedule the computations + res = await c.gather(z) + + assert res == 5 + + sched_logger = sched_logger.getvalue() + client_logger = client_logger.getvalue() + worker_logger = worker_logger.getvalue() + + # Ensure that the communication was done via the scheduler, i.e. we actually hit a bad connection + assert s.rpc.cnn_count > 0 + + assert "Encountered connection issue during data collection" in worker_logger + + # The reducer task was actually not found upon first collection. The client will reschedule the graph + assert "Couldn't gather 1 keys, rescheduling" in client_logger + # There will also be a `Unexpected worker completed task` message but this + # is rather an artifact and not the intention + assert "Workers don't have promised key" in sched_logger + + # Once the worker reconnects, it will also submit the keys it holds such + # that the scheduler again knows about the result. + # The final reduce step should then be used from the re-connected worker + # instead of recomputing it. 
+ + starts = [] + finish_processing_transitions = 0 + for transition in s.transition_log: + key, start, finish, recommendations, timestamp = transition + if "reducer" in key and finish == "processing": + finish_processing_transitions += 1 + assert finish_processing_transitions == 1 diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 224b4b7f181..f66d3ba62d5 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,6 +1,5 @@ -import pytest - -from distributed.core import rpc +from distributed.core import ConnectionPool +from distributed.comm import Comm from distributed.utils_test import gen_cluster from distributed.utils_comm import pack_data, gather_from_workers @@ -12,9 +11,9 @@ def test_pack_data(): assert pack_data({"a": ["x"], "b": "y"}, data) == {"a": [1], "b": "y"} -@pytest.mark.xfail(reason="rpc now needs to be a connection pool") @gen_cluster(client=True) def test_gather_from_workers_permissive(c, s, a, b): + rpc = ConnectionPool() x = yield c.scatter({"x": 1}, workers=a.address) data, missing, bad_workers = yield gather_from_workers( @@ -23,3 +22,39 @@ def test_gather_from_workers_permissive(c, s, a, b): assert data == {"x": 1} assert list(missing) == ["y"] + + +class BrokenComm(Comm): + peer_address = None + local_address = None + + def close(self): + pass + + def closed(self): + pass + + def abort(self): + pass + + def read(self, deserializers=None): + raise EnvironmentError + + def write(self, msg, serializers=None, on_error=None): + raise EnvironmentError + + +class BrokenConnectionPool(ConnectionPool): + async def connect(self, *args, **kwargs): + return BrokenComm() + + +@gen_cluster(client=True) +def test_gather_from_workers_permissive_flaky(c, s, a, b): + x = yield c.scatter({"x": 1}, workers=a.address) + + rpc = BrokenConnectionPool() + data, missing, bad_workers = yield gather_from_workers({"x": [a.address]}, rpc=rpc) + + assert missing == {"x": [a.address]} + assert bad_workers == [a.address] diff --git a/distributed/worker.py b/distributed/worker.py index 76ee7a723d5..fb46dc2391c 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3125,27 +3125,46 @@ async def get_data_from_worker( if deserializers is None: deserializers = rpc.deserializers - comm = await rpc.connect(worker) - comm.name = "Ephemeral Worker->Worker for gather" - try: - response = await send_recv( - comm, - serializers=serializers, - deserializers=deserializers, - op="get_data", - keys=keys, - who=who, - max_connections=max_connections, - ) + retry_count = 0 + max_retries = 3 + + while True: + comm = await rpc.connect(worker) + comm.name = "Ephemeral Worker->Worker for gather" try: - status = response["status"] - except KeyError: - raise ValueError("Unexpected response", response) - else: - if status == "OK": - await comm.write("OK") - finally: - rpc.reuse(worker, comm) + response = await send_recv( + comm, + serializers=serializers, + deserializers=deserializers, + op="get_data", + keys=keys, + who=who, + max_connections=max_connections, + ) + try: + status = response["status"] + except KeyError: + raise ValueError("Unexpected response", response) + else: + if status == "OK": + await comm.write("OK") + break + except (EnvironmentError, CommClosedError): + if retry_count < max_retries: + await asyncio.sleep(0.1 * (2 ** retry_count)) + retry_count += 1 + logger.info( + "Encountered connection issue during data collection of keys %s on worker %s. 
Retrying (%s / %s)", + keys, + worker, + retry_count, + max_retries, + ) + continue + else: + raise + finally: + rpc.reuse(worker, comm) return response From ec29c04a317cf5d0c418ee9803723839e04a16d2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 26 Nov 2019 08:17:08 -0800 Subject: [PATCH 0567/1550] Use base-2 values for byte-valued axes in dashboard (#3267) --- distributed/dashboard/components/scheduler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 7f226bf7ddc..a02ca73b63d 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -24,6 +24,7 @@ LinearAxis, NumeralTickFormatter, BoxZoomTool, + AdaptiveTicker, BasicTicker, NumberFormatter, BoxSelectTool, @@ -241,6 +242,7 @@ def __init__(self, scheduler, **kwargs): ) self.root.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + self.root.xaxis.ticker = AdaptiveTicker(mantissas=[1, 256, 512], base=1024) self.root.xaxis.major_label_orientation = -math.pi / 12 self.root.xaxis.minor_tick_line_alpha = 0 @@ -303,6 +305,7 @@ def __init__(self, scheduler, **kwargs): ) fig.x_range.start = 0 fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + fig.xaxis.ticker = AdaptiveTicker(mantissas=[1, 256, 512], base=1024) rect.nonselection_glyph = None fig.xaxis.minor_tick_line_alpha = 0 @@ -457,6 +460,7 @@ def __init__(self, scheduler, **kwargs): source=self.source, x="name", top="nbytes", width=0.9, color="color" ) fig.yaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + fig.yaxis.ticker = AdaptiveTicker(mantissas=[1, 256, 512], base=1024) fig.xaxis.major_label_orientation = -math.pi / 12 rect.nonselection_glyph = None From 6af261770072313392b0a4cac90d01bfb68f078c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 26 Nov 2019 08:17:41 -0800 Subject: [PATCH 0568/1550] Set x_range in CPU plot based on the number of threads (#3266) --- distributed/dashboard/components/scheduler.py | 19 +++++++++---------- .../dashboard/tests/test_scheduler_bokeh.py | 2 ++ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index a02ca73b63d..d5aa2c1d338 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -19,7 +19,6 @@ Range1d, Plot, Quad, - Span, value, LinearAxis, NumeralTickFormatter, @@ -576,6 +575,7 @@ def __init__(self, scheduler, width=600, **kwargs): id="bk-cpu-worker-plot", width=int(width / 2), name="cpu_hist", + x_range=(0, None), **kwargs, ) rect = cpu.rect( @@ -587,21 +587,13 @@ def __init__(self, scheduler, width=600, **kwargs): color="blue", ) rect.nonselection_glyph = None - hundred_span = Span( - location=100, - dimension="height", - line_color="gray", - line_dash="dashed", - line_width=3, - ) - cpu.add_layout(hundred_span) nbytes.axis[0].ticker = BasicTicker(mantissas=[1, 256, 512], base=1024) nbytes.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") nbytes.xaxis.major_label_orientation = -math.pi / 12 nbytes.x_range.start = 0 - for fig in [processing, nbytes]: + for fig in [processing, nbytes, cpu]: fig.xaxis.minor_tick_line_alpha = 0 fig.yaxis.visible = False fig.ygrid.visible = False @@ -698,6 +690,13 @@ def update(self): sum(nbytes) ) self.nbytes_figure.x_range.end = max_limit + if self.scheduler.workers: + self.cpu_figure.x_range.end = ( + max(ws.nthreads or 1 for ws in 
self.scheduler.workers.values()) + * 100 + ) + else: + self.cpu_figure.x_range.end = 100 update(self.source, result) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 14d055baa7f..fe92c805efa 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -274,6 +274,8 @@ def test_CurrentLoad(c, s, a, b): assert all(len(L) == 2 for L in d.values()) assert all(d["nbytes"]) + assert cl.cpu_figure.x_range.end == 200 + @gen_cluster(client=True) def test_ProcessingHistogram(c, s, a, b): From 856bba7ca913163493d406a87205629f1d021a23 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 26 Nov 2019 17:07:11 +0000 Subject: [PATCH 0569/1550] Fix layout scaling on profile plots (#3268) --- distributed/dashboard/components/scheduler.py | 4 ++-- distributed/dashboard/components/shared.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index d5aa2c1d338..5b01a99f7a6 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1957,7 +1957,7 @@ def individual_memory_by_key_doc(scheduler, extra, doc): def profile_doc(scheduler, extra, doc): with log_errors(): doc.title = "Dask: Profile" - prof = ProfileTimePlot(scheduler, sizing_mode="scale_width", doc=doc) + prof = ProfileTimePlot(scheduler, sizing_mode="stretch_both", doc=doc) doc.add_root(prof.root) doc.template = env.get_template("simple.html") doc.template_variables.update(extra) @@ -1969,7 +1969,7 @@ def profile_doc(scheduler, extra, doc): def profile_server_doc(scheduler, extra, doc): with log_errors(): doc.title = "Dask: Profile of Event Loop" - prof = ProfileServer(scheduler, sizing_mode="scale_width", doc=doc) + prof = ProfileServer(scheduler, sizing_mode="stretch_both", doc=doc) doc.add_root(prof.root) doc.template = env.get_template("simple.html") doc.template_variables.update(extra) diff --git a/distributed/dashboard/components/shared.py b/distributed/dashboard/components/shared.py index 882db411434..ec3a207fc8b 100644 --- a/distributed/dashboard/components/shared.py +++ b/distributed/dashboard/components/shared.py @@ -312,12 +312,12 @@ def cb(attr, old, new): self.ts_source = ColumnDataSource({"time": [], "count": []}) self.ts_plot = figure( title="Activity over time", - height=100, + height=150, x_axis_type="datetime", active_drag="xbox_select", y_range=[0, 1 / profile_interval], tools="xpan,xwheel_zoom,xbox_select,reset", - **kwargs + sizing_mode="stretch_width", ) self.ts_plot.line("time", "count", source=self.ts_source) self.ts_plot.circle( @@ -367,6 +367,7 @@ def select_cb(attr, old, new): self.reset_button, self.update_button, sizing_mode="scale_width", + height=250, ), self.profile_plot, self.ts_plot, @@ -464,12 +465,12 @@ def cb(attr, old, new): self.ts_source = ColumnDataSource({"time": [], "count": []}) self.ts_plot = figure( title="Activity over time", - height=100, + height=150, x_axis_type="datetime", active_drag="xbox_select", y_range=[0, 1 / profile_interval], tools="xpan,xwheel_zoom,xbox_select,reset", - **kwargs + sizing_mode="stretch_width", ) self.ts_plot.line("time", "count", source=self.ts_source) self.ts_plot.circle( From b17a1ad04168dd12dac3ce5f31a13226d3d7e52a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 26 Nov 2019 14:45:14 -0800 Subject: [PATCH 0570/1550] Add offload size to configuration 
(#3270) --- distributed/comm/utils.py | 10 +++++++--- distributed/distributed.yaml | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 5b15d5c798c..4862aace207 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -1,7 +1,9 @@ import logging import socket +import dask from dask.sizeof import sizeof +from dask.utils import parse_bytes from .. import protocol from ..utils import get_ip, get_ipv6, nbytes, offload @@ -13,7 +15,9 @@ # Offload (de)serializing large frames to improve event loop responsiveness. # We use at most 4 threads to allow for parallel processing of large messages. -FRAME_OFFLOAD_THRESHOLD = 10 * 1024 ** 2 # 10 MB +FRAME_OFFLOAD_THRESHOLD = dask.config.get("distributed.comm.offload") +if isinstance(FRAME_OFFLOAD_THRESHOLD, str): + FRAME_OFFLOAD_THRESHOLD = parse_bytes(FRAME_OFFLOAD_THRESHOLD) async def to_frames(msg, serializers=None, on_error="message", context=None): @@ -33,7 +37,7 @@ def _to_frames(): logger.exception(e) raise - if sizeof(msg) > FRAME_OFFLOAD_THRESHOLD: + if FRAME_OFFLOAD_THRESHOLD and sizeof(msg) > FRAME_OFFLOAD_THRESHOLD: return await offload(_to_frames) else: return _to_frames() @@ -59,7 +63,7 @@ def _from_frames(): logger.error("truncated data stream (%d bytes): %s", size, datastr) raise - if deserialize and size > FRAME_OFFLOAD_THRESHOLD: + if deserialize and FRAME_OFFLOAD_THRESHOLD and size > FRAME_OFFLOAD_THRESHOLD: res = await offload(_from_frames) else: res = _from_frames() diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 59bf6f8dc1c..ae42162bb2f 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -73,6 +73,7 @@ distributed: comm: compression: auto + offload: 10MiB # Size after which we choose to offload serialization to another thread default-scheme: tcp socket-backlog: 2048 recent-messages-log-length: 0 # number of messages to keep for debugging From 71c998d23b879c6fd8543c3937847318a23c5447 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 26 Nov 2019 14:45:48 -0800 Subject: [PATCH 0571/1550] Remove memory use plot (#3269) This was replaced by another more traditional bar chart --- distributed/dashboard/components/__init__.py | 1 - distributed/dashboard/components/scheduler.py | 90 +------------------ distributed/dashboard/components/shared.py | 77 ---------------- distributed/dashboard/scheduler.py | 2 - .../dashboard/tests/test_components.py | 3 +- .../dashboard/tests/test_scheduler_bokeh.py | 14 --- distributed/diagnostics/progress_stream.py | 46 ---------- .../diagnostics/tests/test_progress_stream.py | 24 +---- 8 files changed, 3 insertions(+), 254 deletions(-) diff --git a/distributed/dashboard/components/__init__.py b/distributed/dashboard/components/__init__.py index 12f57b352b1..a66be2eced6 100644 --- a/distributed/dashboard/components/__init__.py +++ b/distributed/dashboard/components/__init__.py @@ -29,7 +29,6 @@ import toolz from distributed.dashboard.utils import without_property_validation, BOKEH_VERSION -from distributed.diagnostics.progress_stream import nbytes_bar from distributed import profile from distributed.utils import log_errors, parse_timedelta diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 5b01a99f7a6..4049860bed5 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -17,10 +17,7 @@ TapTool, OpenURL, Range1d, - Plot, - Quad, value, - 
LinearAxis, NumeralTickFormatter, BoxZoomTool, AdaptiveTicker, @@ -62,7 +59,7 @@ ) from distributed.metrics import time from distributed.utils import log_errors, format_time, parse_timedelta -from distributed.diagnostics.progress_stream import color_of, progress_quads, nbytes_bar +from distributed.diagnostics.progress_stream import color_of, progress_quads from distributed.diagnostics.progress import AllProgress from distributed.diagnostics.graph_layout import GraphLayout from distributed.diagnostics.task_stream import TaskStreamPlugin @@ -1434,83 +1431,6 @@ def update(self): ) -class MemoryUse(DashboardComponent): - """ The memory usage across the cluster, grouped by task type """ - - def __init__(self, scheduler, **kwargs): - self.scheduler = scheduler - ps = [p for p in scheduler.plugins if isinstance(p, AllProgress)] - if ps: - self.plugin = ps[0] - else: - self.plugin = AllProgress(scheduler) - - self.source = ColumnDataSource( - data=dict( - name=[], - left=[], - right=[], - center=[], - color=[], - percent=[], - MB=[], - text=[], - ) - ) - - self.root = Plot( - id="bk-nbytes-plot", - x_range=DataRange1d(), - y_range=DataRange1d(), - toolbar_location=None, - outline_line_color=None, - **kwargs, - ) - - self.root.add_glyph( - self.source, - Quad( - top=1, - bottom=0, - left="left", - right="right", - fill_color="color", - fill_alpha=1, - ), - ) - - self.root.add_layout(LinearAxis(), "left") - self.root.add_layout(LinearAxis(), "below") - - hover = HoverTool( - point_policy="follow_mouse", - tooltips=""" -
                <div>
-                    <span style="...">Name:</span>&nbsp;
-                    <span style="...">@name</span>
-                </div>
-                <div>
-                    <span style="...">Percent:</span>&nbsp;
-                    <span style="...">@percent</span>
-                </div>
-                <div>
-                    <span style="...">MB:</span>&nbsp;
-                    <span style="...">@MB</span>
-                </div>
          - """, - ) - self.root.add_tools(hover) - - @without_property_validation - def update(self): - with log_errors(): - nb = nbytes_bar(self.plugin.nbytes) - update(self.source, nb) - self.root.title.text = "Memory Use: %0.2f MB" % ( - sum(self.plugin.nbytes.values()) / 1e6 - ) - - class WorkerTable(DashboardComponent): """ Status of the current workers @@ -1860,14 +1780,6 @@ def individual_nbytes_doc(scheduler, extra, doc): doc.theme = BOKEH_THEME -def individual_memory_use_doc(scheduler, extra, doc): - memory_use = MemoryUse(scheduler, sizing_mode="stretch_both") - memory_use.update() - add_periodic_callback(doc, memory_use, 100) - doc.add_root(memory_use.root) - doc.theme = BOKEH_THEME - - def individual_cpu_doc(scheduler, extra, doc): current_load = CurrentLoad(scheduler, sizing_mode="stretch_both") current_load.update() diff --git a/distributed/dashboard/components/shared.py b/distributed/dashboard/components/shared.py index ec3a207fc8b..d7554e6bb30 100644 --- a/distributed/dashboard/components/shared.py +++ b/distributed/dashboard/components/shared.py @@ -4,12 +4,9 @@ from bokeh.layouts import row, column from bokeh.models import ( ColumnDataSource, - Plot, DataRange1d, - LinearAxis, HoverTool, Range1d, - Quad, Button, Select, NumeralTickFormatter, @@ -26,7 +23,6 @@ BOKEH_VERSION, update, ) -from distributed.diagnostics.progress_stream import nbytes_bar from distributed import profile from distributed.utils import log_errors, parse_timedelta from distributed.compatibility import WINDOWS @@ -41,79 +37,6 @@ profile_interval = parse_timedelta(profile_interval, default="ms") -class MemoryUsage(DashboardComponent): - """ The memory usage across the cluster, grouped by task type """ - - def __init__(self, **kwargs): - self.source = ColumnDataSource( - data=dict( - name=[], - left=[], - right=[], - center=[], - color=[], - percent=[], - MB=[], - text=[], - ) - ) - - self.root = Plot( - id="bk-nbytes-plot", - x_range=DataRange1d(), - y_range=DataRange1d(), - toolbar_location=None, - outline_line_color=None, - **kwargs - ) - - self.root.add_glyph( - self.source, - Quad( - top=1, - bottom=0, - left="left", - right="right", - fill_color="color", - fill_alpha=1, - ), - ) - - self.root.add_layout(LinearAxis(), "left") - self.root.add_layout(LinearAxis(), "below") - - hover = HoverTool( - point_policy="follow_mouse", - tooltips=""" -
                <div>
-                    <span style="...">Name:</span>&nbsp;
-                    <span style="...">@name</span>
-                </div>
-                <div>
-                    <span style="...">Percent:</span>&nbsp;
-                    <span style="...">@percent</span>
-                </div>
-                <div>
-                    <span style="...">MB:</span>&nbsp;
-                    <span style="...">@MB</span>
-                </div>
          - """, - ) - self.root.add_tools(hover) - - @without_property_validation - def update(self, messages): - with log_errors(): - msg = messages["progress"] - if not msg: - return - nb = nbytes_bar(msg["nbytes"]) - update(self.source, nb) - self.root.title.text = "Memory Use: %0.2f MB" % ( - sum(msg["nbytes"].values()) / 1e6 - ) - - class Processing(DashboardComponent): """ Processing and distribution per core diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 67d08c50bd0..2c1ec38c4f3 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -30,7 +30,6 @@ individual_profile_doc, individual_profile_server_doc, individual_nbytes_doc, - individual_memory_use_doc, individual_cpu_doc, individual_nprocessing_doc, individual_workers_doc, @@ -403,7 +402,6 @@ def listen(self, *args, **kwargs): "/individual-profile": individual_profile_doc, "/individual-profile-server": individual_profile_server_doc, "/individual-nbytes": individual_nbytes_doc, - "/individual-memory-use": individual_memory_use_doc, "/individual-cpu": individual_cpu_doc, "/individual-nprocessing": individual_nprocessing_doc, "/individual-workers": individual_workers_doc, diff --git a/distributed/dashboard/tests/test_components.py b/distributed/dashboard/tests/test_components.py index 195c947bdec..3e6a696cc6b 100644 --- a/distributed/dashboard/tests/test_components.py +++ b/distributed/dashboard/tests/test_components.py @@ -7,14 +7,13 @@ from distributed.utils_test import slowinc, gen_cluster from distributed.dashboard.components.shared import ( - MemoryUsage, Processing, ProfilePlot, ProfileTimePlot, ) -@pytest.mark.parametrize("Component", [MemoryUsage, Processing]) +@pytest.mark.parametrize("Component", [Processing]) def test_basic(Component): c = Component() assert isinstance(c.source, ColumnDataSource) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index fe92c805efa..0f262ec5809 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -28,7 +28,6 @@ Events, TaskStream, TaskProgress, - MemoryUse, CurrentLoad, ProcessingHistogram, NBytesHistogram, @@ -248,19 +247,6 @@ def test_TaskProgress_empty(c, s, a, b): assert not any(len(v) for v in tp.source.data.values()) -@gen_cluster(client=True) -def test_MemoryUse(c, s, a, b): - mu = MemoryUse(s) - - futures = c.map(slowinc, range(10), delay=0.001) - yield wait(futures) - - mu.update() - d = dict(mu.source.data) - assert all(len(L) == 1 for L in d.values()) - assert d["name"] == ["slowinc"] - - @gen_cluster(client=True) def test_CurrentLoad(c, s, a, b): cl = CurrentLoad(s) diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index 038237b89e2..d127ecfeb7e 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -55,52 +55,6 @@ async def progress_stream(address, interval): return comm -def nbytes_bar(nbytes): - """ Convert nbytes message into rectangle placements - - >>> nbytes_bar({'inc': 1000, 'dec': 3000}) # doctest: +NORMALIZE_WHITESPACE - {'names': ['dec', 'inc'], - 'left': [0, 0.75], - 'center': [0.375, 0.875], - 'right': [0.75, 1.0]} - """ - total = sum(nbytes.values()) - names = sorted(nbytes) - - d = { - "name": [], - "text": [], - "left": [], - "right": [], - "center": [], - "color": [], - "percent": [], - "MB": [], - } - - if not total: - return d - - right = 0 
- for name in names: - left = right - right = nbytes[name] / total + left - center = (right + left) / 2 - d["MB"].append(nbytes[name] / 1000000) - d["percent"].append(round(nbytes[name] / total * 100, 2)) - d["left"].append(left) - d["right"].append(right) - d["center"].append(center) - d["color"].append(color_of(name)) - d["name"].append(name) - if right - left > 0.1: - d["text"].append(name) - else: - d["text"].append("") - - return d - - def progress_quads(msg, nrows=8, ncols=3): """ diff --git a/distributed/diagnostics/tests/test_progress_stream.py b/distributed/diagnostics/tests/test_progress_stream.py index 56da9e974c1..77b3922a42e 100644 --- a/distributed/diagnostics/tests/test_progress_stream.py +++ b/distributed/diagnostics/tests/test_progress_stream.py @@ -4,11 +4,7 @@ from dask import delayed from distributed.client import wait -from distributed.diagnostics.progress_stream import ( - progress_quads, - nbytes_bar, - progress_stream, -) +from distributed.diagnostics.progress_stream import progress_quads, progress_stream from distributed.utils_test import div, gen_cluster, inc @@ -88,24 +84,6 @@ def test_progress_stream(c, s, a, b): yield comm.close() -def test_nbytes_bar(): - nbytes = {"inc": 1000, "dec": 3000} - expected = { - "name": ["dec", "inc"], - "left": [0, 0.75], - "center": [0.375, 0.875], - "right": [0.75, 1.0], - "percent": [75, 25], - "MB": [0.003, 0.001], - "text": ["dec", "inc"], - } - - result = nbytes_bar(nbytes) - color = result.pop("color") - assert len(set(color)) == 2 - assert result == expected - - def test_progress_quads_many_functions(): funcnames = ["fn%d" % i for i in range(1000)] msg = { From 0b68318112b13d70a9cdd741e5db00da2ec6a8f5 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 26 Nov 2019 14:46:07 -0800 Subject: [PATCH 0572/1550] Enable saving profile information from server threads (#3271) Previously we could only save information from the worker threads and had to rely ont the dashboard to effectively operate on the administrative threads of the scheduler and workers. Now we can do both. --- distributed/client.py | 16 ++++++++++++++++ distributed/scheduler.py | 11 ++++++++++- distributed/tests/test_client.py | 16 ++++++++++++++++ distributed/worker.py | 7 +++++-- 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 4f315e89b6d..3b2eb6e2863 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3259,6 +3259,8 @@ def profile( merge_workers=True, plot=False, filename=None, + server=False, + scheduler=False, ): """ Collect statistical profiling information about recent work @@ -3271,6 +3273,14 @@ def profile( stop: time workers: list List of workers to restrict profile information + server : bool + If true, return the profile of the worker's administrative thread + rather than the worker threads. + This is useful when profiling Dask itself, rather than user code. + scheduler: bool + If true, return the profile information from the scheduler's + administrative thread rather than the workers. + This is useful when profiling Dask's scheduling itself. 
plot: boolean or string Whether or not to return a plot object filename: str @@ -3293,6 +3303,8 @@ def profile( stop=stop, plot=plot, filename=filename, + server=server, + scheduler=scheduler, ) async def _profile( @@ -3304,6 +3316,8 @@ async def _profile( merge_workers=True, plot=False, filename=None, + server=False, + scheduler=False, ): if isinstance(workers, (str, Number)): workers = [workers] @@ -3314,6 +3328,8 @@ async def _profile( merge_workers=merge_workers, start=start, stop=stop, + server=server, + scheduler=scheduler, ) if filename: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 26b4cf1c67e..26b6bffa970 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4623,6 +4623,8 @@ async def get_profile( self, comm=None, workers=None, + scheduler=False, + server=False, merge_workers=True, start=None, stop=None, @@ -4632,8 +4634,15 @@ async def get_profile( workers = self.workers else: workers = set(self.workers) & set(workers) + + if scheduler: + return profile.get_profile(self.io_loop.profile, start=start, stop=stop) + results = await asyncio.gather( - *(self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers) + *( + self.rpc(w).profile(start=start, stop=stop, key=key, server=server) + for w in workers + ) ) if merge_workers: diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 73906517045..e97cb023a48 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5687,5 +5687,21 @@ async def test_futures_of_sorted(c, s, a, b): assert str(k) in str(f) +@gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": "10ms"}) +async def test_profile_server(c, s, a, b): + x = c.map(slowinc, range(10), delay=0.01, workers=a.address) + await wait(x) + + await asyncio.gather( + c.run(slowinc, 1, delay=0.5), c.run_on_scheduler(slowdec, 1, delay=0.5) + ) + + p = await c.profile(server=True) # All worker servers + assert "slowinc" in str(p) + + p = await c.profile(scheduler=True) # Scheduler + assert "slowdec" in str(p) + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 diff --git a/distributed/worker.py b/distributed/worker.py index fb46dc2391c..8509f317a61 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2655,12 +2655,15 @@ def trigger_profile(self): if self.digests is not None: self.digests["profile-duration"].add(stop - start) - def get_profile(self, comm=None, start=None, stop=None, key=None): + def get_profile(self, comm=None, start=None, stop=None, key=None, server=False): now = time() + self.scheduler_delay - if key is None: + if server: + history = self.io_loop.profile + elif key is None: history = self.profile_history else: history = [(t, d[key]) for t, d in self.profile_keys_history if key in d] + if start is None: istart = 0 else: From 9a937535379fe786f08ce936962310307d5d99ae Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 27 Nov 2019 09:27:57 -0800 Subject: [PATCH 0573/1550] Make profile coroutines consistent between Scheduler and Worker (#3277) Previously scheduler get_profile(_metadata) methods were asynchronous while worker methods were synchronous. This caused consistency issues with dashboard plots. Now we make the worker methods asynchronous (needlessly) for consistency. 
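The shape of the change is easiest to see in a toy form. The class and function names below are illustrative stand-ins, not the actual distributed objects; the point is that wrapping the synchronous worker-side code in async def lets a caller such as the dashboard await either server type identically.

    import asyncio

    class WorkerLike:
        async def get_profile(self):
            # nothing to await, but the interface now matches the scheduler's
            return {"count": 1}

    class SchedulerLike:
        def __init__(self, workers):
            self.workers = workers

        async def get_profile(self):
            # genuinely asynchronous: fan out to every worker
            profiles = await asyncio.gather(*(w.get_profile() for w in self.workers))
            return {"count": sum(p["count"] for p in profiles)}

    async def refresh(server):
        # dashboard-style caller: identical for both kinds of server
        return await server.get_profile()

    print(asyncio.run(refresh(SchedulerLike([WorkerLike(), WorkerLike()]))))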
--- distributed/dashboard/components/shared.py | 4 ++-- distributed/tests/test_worker.py | 10 +++++----- distributed/worker.py | 6 ++++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/distributed/dashboard/components/shared.py b/distributed/dashboard/components/shared.py index d7554e6bb30..611d281dd5e 100644 --- a/distributed/dashboard/components/shared.py +++ b/distributed/dashboard/components/shared.py @@ -321,11 +321,11 @@ def update(self, state, metadata=None): def trigger_update(self, update_metadata=True): async def cb(): with log_errors(): - prof = self.server.get_profile( + prof = await self.server.get_profile( key=self.key, start=self.start, stop=self.stop ) if update_metadata: - metadata = self.server.get_profile_metadata() + metadata = await self.server.get_profile_metadata() else: metadata = None if isinstance(prof, gen.Future): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index ff523342243..f67701f671a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1079,7 +1079,7 @@ def test_statistical_profiling_2(c, s, a, b): y = (x + x * 2) - x.sum().persist() yield wait(y) - profile = a.get_profile() + profile = yield a.get_profile() text = str(profile) if profile["count"] and "sum" in text and "random" in text: break @@ -1165,15 +1165,15 @@ def test_statistical_profiling_cycle(c, s, a, b): end = time() assert len(a.profile_history) > 3 - x = a.get_profile(start=time() + 10, stop=time() + 20) + x = yield a.get_profile(start=time() + 10, stop=time() + 20) assert not x["count"] - x = a.get_profile(start=0, stop=time()) + x = yield a.get_profile(start=0, stop=time()) actual = sum(p["count"] for _, p in a.profile_history) + a.profile_recent["count"] - x2 = a.get_profile(start=0, stop=time()) + x2 = yield a.get_profile(start=0, stop=time()) assert x["count"] <= actual <= x2["count"] - y = a.get_profile(start=end - 0.300, stop=time()) + y = yield a.get_profile(start=end - 0.300, stop=time()) assert 0 < y["count"] <= x["count"] diff --git a/distributed/worker.py b/distributed/worker.py index 8509f317a61..1eb8252534d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2655,7 +2655,9 @@ def trigger_profile(self): if self.digests is not None: self.digests["profile-duration"].add(stop - start) - def get_profile(self, comm=None, start=None, stop=None, key=None, server=False): + async def get_profile( + self, comm=None, start=None, stop=None, key=None, server=False + ): now = time() + self.scheduler_delay if server: history = self.io_loop.profile @@ -2696,7 +2698,7 @@ def get_profile(self, comm=None, start=None, stop=None, key=None, server=False): return prof - def get_profile_metadata(self, comm=None, start=0, stop=None): + async def get_profile_metadata(self, comm=None, start=0, stop=None): if stop is None: add_recent = True now = time() + self.scheduler_delay From 664d1339cbe159ed06065c211da1ca5f5bd8675f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 27 Nov 2019 11:10:53 -0800 Subject: [PATCH 0574/1550] Improve bandwidth workers plot (#3273) * Use new maximum for bandwidth workers at every step * Use base 2 values for bandwidth workers plot * Use worker names in bandwidth plot * add an additional decimal place to workerbandwidth colorbar We often have values like 1GB and 2GB, which are too coarse, and so can be repeated * switch around order of workers on y-axis This matches intuition from matrices where the diagonal goes from upper left to lower right --- 
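As a rough sketch of the relabelling and ordering described in the bullets above (a hypothetical helper, not the dashboard code itself): addresses are translated to worker names, the colormap maximum is recomputed on every update, and the y-axis factors are reversed so the diagonal runs from upper left to lower right.

    def bandwidth_heatmap_data(bandwidth_workers, name_of):
        # bandwidth_workers: {(source_address, destination_address): bytes per second}
        # name_of: maps a worker address to its display name
        x, y, value = zip(
            *[(name_of(a), name_of(b), v) for (a, b), v in bandwidth_workers.items()]
        )
        factors = sorted(set(x + y))
        return {
            "x_factors": factors,
            "y_factors": factors[::-1],  # reversed, so it reads like a matrix
            "color_high": max(value),  # rescaled on every update, not only upward
            "data": {"source": x, "destination": y, "bandwidth": value},
        }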
distributed/dashboard/components/scheduler.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 4049860bed5..2b27c708111 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -383,7 +383,10 @@ def __init__(self, scheduler, **kwargs): border_line_color=None, location=(0, 0), ) - color_bar.formatter = NumeralTickFormatter(format="0 b") + color_bar.formatter = NumeralTickFormatter(format="0.0 b") + color_bar.ticker = AdaptiveTicker( + mantissas=[1, 64, 128, 256, 512], base=1024 + ) fig.add_layout(color_bar, "right") fig.toolbar.logo = None @@ -408,14 +411,21 @@ def update(self): bw = self.scheduler.bandwidth_workers if not bw: return - x, y, value = zip(*[(a, b, c) for (a, b), c in bw.items()]) - if self.color_map.high < max(value): - self.color_map.high = max(value) + def name(address): + ws = self.scheduler.workers[address] + if ws.name is not None: + return str(ws.name) + else: + return address + + x, y, value = zip(*[(name(a), name(b), c) for (a, b), c in bw.items()]) + + self.color_map.high = max(value) factors = list(sorted(set(x + y))) self.fig.x_range.factors = factors - self.fig.y_range.factors = factors + self.fig.y_range.factors = factors[::-1] result = { "source": x, From 447b2c2ac2458090a0b8baf2f11774b89b6d78f9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 27 Nov 2019 11:11:04 -0800 Subject: [PATCH 0575/1550] Worker profile server (#3274) * Add profile-server button to Worker dashboard navbar * Set worker's profile server plots to stretch_both layout --- distributed/dashboard/components/worker.py | 4 ++-- distributed/dashboard/worker.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/distributed/dashboard/components/worker.py b/distributed/dashboard/components/worker.py index 9dc2b2ec82f..440e7279e3b 100644 --- a/distributed/dashboard/components/worker.py +++ b/distributed/dashboard/components/worker.py @@ -638,7 +638,7 @@ def counters_doc(server, extra, doc): def profile_doc(server, extra, doc): with log_errors(): doc.title = "Dask Worker Profile" - profile = ProfileTimePlot(server, sizing_mode="scale_width", doc=doc) + profile = ProfileTimePlot(server, sizing_mode="stretch_both", doc=doc) profile.trigger_update() doc.add_root(profile.root) @@ -651,7 +651,7 @@ def profile_doc(server, extra, doc): def profile_server_doc(server, extra, doc): with log_errors(): doc.title = "Dask: Profile of Event Loop" - prof = ProfileServer(server, sizing_mode="scale_width", doc=doc) + prof = ProfileServer(server, sizing_mode="stretch_both", doc=doc) doc.add_root(prof.root) doc.template = env.get_template("simple.html") # doc.template_variables['active_page'] = '' diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index 99b27557694..5a34a261bf1 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -30,7 +30,9 @@ BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "theme.yaml")) -template_variables = {"pages": ["status", "system", "profile", "crossfilter"]} +template_variables = { + "pages": ["status", "system", "profile", "crossfilter", "profile-server"] +} class _PrometheusCollector(object): From f7976d9397ab7ed4ce1224779914b18dc7e0f2a9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 27 Nov 2019 13:51:37 -0800 Subject: [PATCH 0576/1550] Remove dask-submit and dask-remote (#3280) These were 
seemingly unused for years, and have gotten far out of date. --- distributed/cli/dask_remote.py | 23 ----- distributed/cli/dask_submit.py | 32 ------- distributed/cli/tests/test_dask_remote.py | 15 --- distributed/cli/tests/test_dask_submit.py | 15 --- distributed/submit.py | 94 ------------------- distributed/tests/test_submit_cli.py | 56 ----------- .../tests/test_submit_remote_client.py | 63 ------------- docs/source/index.rst | 1 - docs/source/submitting-applications.rst | 71 -------------- setup.py | 2 - 10 files changed, 372 deletions(-) delete mode 100644 distributed/cli/dask_remote.py delete mode 100644 distributed/cli/dask_submit.py delete mode 100644 distributed/cli/tests/test_dask_remote.py delete mode 100644 distributed/cli/tests/test_dask_submit.py delete mode 100644 distributed/submit.py delete mode 100644 distributed/tests/test_submit_cli.py delete mode 100644 distributed/tests/test_submit_remote_client.py delete mode 100644 docs/source/submitting-applications.rst diff --git a/distributed/cli/dask_remote.py b/distributed/cli/dask_remote.py deleted file mode 100644 index 9fcfe7f3763..00000000000 --- a/distributed/cli/dask_remote.py +++ /dev/null @@ -1,23 +0,0 @@ -import click -from distributed.cli.utils import check_python_3, install_signal_handlers -from distributed.submit import _remote - - -@click.command() -@click.option("--host", type=str, default=None, help="IP or hostname of this server") -@click.option( - "--port", type=int, default=8788, show_default=True, help="Remote Client Port" -) -@click.version_option() -def main(host, port): - _remote(host, port) - - -def go(): - install_signal_handlers() - check_python_3() - main() - - -if __name__ == "__main__": - go() diff --git a/distributed/cli/dask_submit.py b/distributed/cli/dask_submit.py deleted file mode 100644 index 071dd5bbe32..00000000000 --- a/distributed/cli/dask_submit.py +++ /dev/null @@ -1,32 +0,0 @@ -import sys -import click -from tornado import gen -from tornado.ioloop import IOLoop -from distributed.cli.utils import check_python_3, install_signal_handlers -from distributed.submit import _submit - - -@click.command() -@click.argument("remote_client_address", type=str, required=True) -@click.argument("filepath", type=str, required=True) -@click.version_option() -def main(remote_client_address, filepath): - @gen.coroutine - def f(): - stdout, stderr = yield _submit(remote_client_address, filepath) - if stdout: - sys.stdout.write(str(stdout)) - if stderr: - sys.stderr.write(str(stderr)) - - IOLoop.instance().run_sync(f) - - -def go(): - install_signal_handlers() - check_python_3() - main() - - -if __name__ == "__main__": - go() diff --git a/distributed/cli/tests/test_dask_remote.py b/distributed/cli/tests/test_dask_remote.py deleted file mode 100644 index 14da80f949c..00000000000 --- a/distributed/cli/tests/test_dask_remote.py +++ /dev/null @@ -1,15 +0,0 @@ -from click.testing import CliRunner -from distributed.cli.dask_remote import main - - -def test_dask_remote(): - runner = CliRunner() - result = runner.invoke(main, ["--help"]) - assert "--host TEXT IP or hostname of this server" in result.output - assert result.exit_code == 0 - - -def test_version_option(): - runner = CliRunner() - result = runner.invoke(main, ["--version"]) - assert result.exit_code == 0 diff --git a/distributed/cli/tests/test_dask_submit.py b/distributed/cli/tests/test_dask_submit.py deleted file mode 100644 index 8f5f961ea96..00000000000 --- a/distributed/cli/tests/test_dask_submit.py +++ /dev/null @@ -1,15 +0,0 @@ -from click.testing 
import CliRunner -from distributed.cli.dask_submit import main - - -def test_submit_runs_as_a_cli(): - runner = CliRunner() - result = runner.invoke(main, ["--help"]) - assert result.exit_code == 0 - assert "Usage: main [OPTIONS] REMOTE_CLIENT_ADDRESS FILEPATH" in result.output - - -def test_version_option(): - runner = CliRunner() - result = runner.invoke(main, ["--version"]) - assert result.exit_code == 0 diff --git a/distributed/submit.py b/distributed/submit.py deleted file mode 100644 index 4cd7fb197a9..00000000000 --- a/distributed/submit.py +++ /dev/null @@ -1,94 +0,0 @@ -import logging -import os -import socket -import subprocess -import tempfile -import sys - -from tornado import gen - -from tornado.ioloop import IOLoop - -from .core import rpc, Server -from .security import Security -from .utils import get_ip - - -logger = logging.getLogger("distributed.remote") - - -class RemoteClient(Server): - def __init__( - self, - ip=None, - local_dir=tempfile.mkdtemp(prefix="client-"), - loop=None, - security=None, - **kwargs - ): - self.ip = ip or get_ip() - self.loop = loop or IOLoop.current() - self.local_dir = local_dir - handlers = {"upload_file": self.upload_file, "execute": self.execute} - - self.security = security or Security() - assert isinstance(self.security, Security) - self.listen_args = self.security.get_listen_args("scheduler") - - super(RemoteClient, self).__init__(handlers, io_loop=self.loop, **kwargs) - - @gen.coroutine - def _start(self, port=0): - self.listen(port, listen_args=self.listen_args) - - def start(self, port=0): - self.loop.add_callback(self._start, port) - logger.info("Remote Client is running at {0}:{1}".format(self.ip, port)) - - @gen.coroutine - def execute(self, stream=None, filename=None): - script_path = os.path.join(self.local_dir, filename) - cmd = "{0} {1}".format(sys.executable, script_path) - process = subprocess.Popen( - cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - out, err = process.communicate() - return_code = process.returncode - raise gen.Return({"stdout": out, "stderr": err, "returncode": return_code}) - - def upload_file(self, stream, filename=None, file_payload=None): - out_filename = os.path.join(self.local_dir, filename) - if isinstance(file_payload, str): - file_payload = file_payload.encode() - with open(out_filename, "wb") as f: - f.write(file_payload) - return {"status": "OK", "nbytes": len(file_payload)} - - @gen.coroutine - def _close(self): - self.stop() - - -def _remote(host, port, loop=IOLoop.current(), client=RemoteClient): - host = host or get_ip() - if ":" in host and port == 8788: - host, port = host.rsplit(":", 1) - port = int(port) - ip = socket.gethostbyname(host) - remote_client = client(ip=ip, loop=loop) - remote_client.start(port=port) - loop.start() - loop.close() - remote_client.stop() - logger.info("End remote client at %s:%d", host, port) - - -@gen.coroutine -def _submit(remote_client_address, filepath, connection_args=None): - rc = rpc(remote_client_address, connection_args=connection_args) - remote_file = os.path.basename(filepath) - with open(filepath, "rb") as f: - bytes_read = f.read() - yield rc.upload_file(filename=remote_file, file_payload=bytes_read) - result = yield rc.execute(filename=remote_file) - raise gen.Return((result["stdout"], result["stderr"])) diff --git a/distributed/tests/test_submit_cli.py b/distributed/tests/test_submit_cli.py deleted file mode 100644 index edc16e0a61e..00000000000 --- a/distributed/tests/test_submit_cli.py +++ /dev/null @@ -1,56 +0,0 @@ -from 
unittest.mock import Mock - -from tornado import gen -from tornado.ioloop import IOLoop -from distributed.submit import RemoteClient, _submit, _remote -from distributed.utils_test import ( # noqa: F401 - valid_python_script, - invalid_python_script, - loop, -) - - -def test_dask_submit_cli_writes_result_to_stdout(loop, tmpdir, valid_python_script): - @gen.coroutine - def test(): - remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) - yield remote_client._start() - - out, err = yield _submit( - "127.0.0.1:{0}".format(remote_client.port), str(valid_python_script) - ) - assert b"hello world!" in out - yield remote_client._close() - - loop.run_sync(test, timeout=5) - - -def test_dask_submit_cli_writes_traceback_to_stdout( - loop, tmpdir, invalid_python_script -): - @gen.coroutine - def test(): - remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) - yield remote_client._start() - - out, err = yield _submit( - "127.0.0.1:{0}".format(remote_client.port), str(invalid_python_script) - ) - assert b"Traceback" in err - yield remote_client._close() - - loop.run_sync(test, timeout=5) - - -def test_cli_runs_remote_client(): - mock_remote_client = Mock(spec=RemoteClient) - mock_ioloop = Mock(spec=IOLoop.current()) - - _remote("127.0.0.1:8799", 8788, loop=mock_ioloop, client=mock_remote_client) - - mock_remote_client.assert_called_once_with(ip="127.0.0.1", loop=mock_ioloop) - mock_remote_client().start.assert_called_once_with(port=8799) - - assert mock_ioloop.start.called - assert mock_ioloop.close.called - assert mock_remote_client().stop.called diff --git a/distributed/tests/test_submit_remote_client.py b/distributed/tests/test_submit_remote_client.py deleted file mode 100644 index e6527d8319b..00000000000 --- a/distributed/tests/test_submit_remote_client.py +++ /dev/null @@ -1,63 +0,0 @@ -import os - -from tornado import gen - -from distributed import rpc -from distributed.submit import RemoteClient -from distributed.utils_test import ( # noqa: F401 - loop, - valid_python_script, - invalid_python_script, -) - - -def test_remote_client_uploads_a_file(loop, tmpdir): - @gen.coroutine - def test(): - remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) - yield remote_client._start(0) - remote_process = rpc(remote_client.address) - upload = yield remote_process.upload_file( - filename="script.py", file_payload="x=1" - ) - - assert upload == {"status": "OK", "nbytes": 3} - assert tmpdir.join("script.py").read() == "x=1" - - yield remote_client._close() - - loop.run_sync(test, timeout=5) - - -def test_remote_client_execution_outputs_to_stdout(loop, tmpdir): - @gen.coroutine - def test(): - remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) - yield remote_client._start(0) - rr = rpc(remote_client.address) - yield rr.upload_file(filename="script.py", file_payload='print("hello world!")') - - message = yield rr.execute(filename="script.py") - assert message["stdout"] == b"hello world!" 
+ os.linesep.encode() - assert message["returncode"] == 0 - - yield remote_client._close() - - loop.run_sync(test, timeout=5) - - -def test_remote_client_execution_outputs_stderr(loop, tmpdir, invalid_python_script): - @gen.coroutine - def test(): - remote_client = RemoteClient(ip="127.0.0.1", local_dir=str(tmpdir)) - yield remote_client._start(0) - rr = rpc(remote_client.address) - yield rr.upload_file(filename="script.py", file_payload="a+1") - - message = yield rr.execute(filename="script.py") - assert b"'a' is not defined" in message["stderr"] - assert message["returncode"] == 1 - - yield remote_client._close() - - loop.run_sync(test, timeout=5) diff --git a/docs/source/index.rst b/docs/source/index.rst index ee32738f826..47419e014ec 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -116,7 +116,6 @@ Contents publish queues resources - submitting-applications task-launch tls web diff --git a/docs/source/submitting-applications.rst b/docs/source/submitting-applications.rst deleted file mode 100644 index 8b5ab1d61c8..00000000000 --- a/docs/source/submitting-applications.rst +++ /dev/null @@ -1,71 +0,0 @@ -Submitting Applications -======================= - -The ``dask-submit`` cli can be used to submit an application to the dask cluster -running remotely. If your code depends on resources that can only be access -from cluster running dask, ``dask-submit`` provides a mechanism to send the script -to the cluster for execution from a different machine. - -For example, S3 buckets could not be visible from your local machine and hence any -attempt to create a dask graph from local machine may not work. - - -Submitting dask Applications with ``dask-submit`` -------------------------------------------------- - -In order to remotely submit scripts to the cluster from a local machine or a CI/CD -environment, we need to run a remote client on the same machine as the scheduler:: - - #scheduler machine - dask-remote --port 8788 - - -After making sure the ``dask-remote`` is running, you can submit a script by:: - - #local machine - dask-submit : - - -Some of the commonly used arguments are: - -- ``REMOTE_CLIENT_ADDRESS``: host name where ``dask-remote`` client is running -- ``FILEPATH``: Local path to file containing dask application - -For example, given the following dask application saved in a file called -``script.py``: - -.. code-block:: python - - # script.py - from distributed import Client - - def inc(x): - return x + 1 - - if __name__=='__main__': - client = Client('127.0.0.1:8786') - x = client.submit(inc, 10) - print(x.result()) - - -We can submit this application from a local machine by running:: - - dask-submit : script.py - - -CLI Options ------------ - -.. note:: - - The command line documentation here may differ depending on your installed - version. We recommend referring to the output of ``dask-remote --help`` - and ``dask-submit --help``. - -.. click:: distributed.cli.dask_remote:main - :prog: dask-remote - :show-nested: - -.. 
click:: distributed.cli.dask_submit:main - :prog: dask-submit - :show-nested: \ No newline at end of file diff --git a/setup.py b/setup.py index 310d5322e98..e8c419cb147 100755 --- a/setup.py +++ b/setup.py @@ -54,8 +54,6 @@ entry_points=""" [console_scripts] dask-ssh=distributed.cli.dask_ssh:go - dask-submit=distributed.cli.dask_submit:go - dask-remote=distributed.cli.dask_remote:go dask-scheduler=distributed.cli.dask_scheduler:go dask-worker=distributed.cli.dask_worker:go """, From 1eb3a9dbdaae96aebb33d9be843906c684ca52f4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 27 Nov 2019 17:54:10 -0800 Subject: [PATCH 0577/1550] Make Listener.start asynchronous (#3278) Previously the Listener.start method was synchronous, expecting comms to be able to start themselves and get an address immediately. This was a challenge when we tried adding asyncio comms, which needed to await the server creation. Now we make Listener.start asynchronous, and make sure that we await all calls to it within the codebase. --- distributed/comm/core.py | 8 +- distributed/comm/inproc.py | 5 +- distributed/comm/tcp.py | 2 +- distributed/comm/tests/test_comms.py | 627 +++++++++++++-------------- distributed/comm/ucx.py | 2 +- distributed/core.py | 4 +- distributed/nanny.py | 2 +- distributed/scheduler.py | 2 +- distributed/tests/test_batched.py | 174 ++++---- distributed/tests/test_core.py | 330 +++++++------- distributed/tests/test_security.py | 62 ++- distributed/worker.py | 2 +- 12 files changed, 598 insertions(+), 622 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 0befb36d712..11f74a1aba8 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -131,7 +131,7 @@ def __repr__(self): class Listener(ABC): @abstractmethod - def start(self): + async def start(self): """ Start listening for incoming connections. """ @@ -157,11 +157,11 @@ def contact_address(self): address such as 'tcp://0.0.0.0:123'. 
""" - def __enter__(self): - self.start() + async def __aenter__(self): + await self.start() return self - def __exit__(self, *exc): + async def __aexit__(self, *exc): self.stop() diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index e9bed986ea0..e46c2804ed1 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -1,3 +1,4 @@ +import asyncio from collections import deque, namedtuple import itertools import logging @@ -265,9 +266,9 @@ async def _listen(self): def connect_threadsafe(self, conn_req): self.loop.add_callback(self.listen_q.put_nowait, conn_req) - def start(self): + async def start(self): self.loop = IOLoop.current() - self.loop.add_callback(self._listen) + self._listen_future = asyncio.ensure_future(self._listen()) self.manager.add_listener(self.address, self) def stop(self): diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index f0a24fe4fb7..40b6e8104b3 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -395,7 +395,7 @@ def __init__( self.tcp_server = None self.bound_address = None - def start(self): + async def start(self): self.tcp_server = TCPServer(max_buffer_size=MAX_BUFFER_SIZE, **self.server_args) self.tcp_server.handle_stream = self._handle_stream backlog = int(dask.config.get("distributed.comm.socket-backlog")) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 5839c3e8871..465be11c7a4 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -1,3 +1,4 @@ +import asyncio from functools import partial import os import sys @@ -6,13 +7,12 @@ import pytest -from tornado import gen, ioloop, locks, queues +from tornado import ioloop, locks, queues from tornado.concurrent import Future from distributed.metrics import time from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import ( - gen_test, requires_ipv6, has_ipv6, get_cert, @@ -73,21 +73,21 @@ def check_tls_extra(info): ) -@gen.coroutine -def get_comm_pair(listen_addr, listen_args=None, connect_args=None, **kwargs): +@pytest.mark.asyncio +async def get_comm_pair(listen_addr, listen_args=None, connect_args=None, **kwargs): q = queues.Queue() def handle_comm(comm): q.put(comm) listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) - listener.start() + await listener.start() - comm = yield connect( + comm = await connect( listener.contact_address, connection_args=connect_args, **kwargs ) - serv_comm = yield q.get() - raise gen.Return((comm, serv_comm)) + serv_comm = await q.get() + return (comm, serv_comm) def get_tcp_comm_pair(**kwargs): @@ -103,15 +103,14 @@ def get_inproc_comm_pair(**kwargs): return get_comm_pair("inproc://", **kwargs) -@gen.coroutine -def debug_loop(): +async def debug_loop(): """ Debug helper """ while True: loop = ioloop.IOLoop.current() print(".", loop, loop._handlers) - yield gen.sleep(0.50) + await asyncio.sleep(0.50) # @@ -205,23 +204,22 @@ def test_get_local_address_for(): # -@gen_test() -def test_tcp_specific(): +@pytest.mark.asyncio +async def test_tcp_specific(): """ Test concrete TCP API. 
""" - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): assert comm.peer_address.startswith("tcp://" + host) assert comm.extra_info == {} - msg = yield comm.read() + msg = await comm.read() msg["op"] = "pong" - yield comm.write(msg) - yield comm.close() + await comm.write(msg) + await comm.close() listener = tcp.TCPListener("localhost", handle_comm) - listener.start() + await listener.start() host, port = listener.get_host_port() assert host in ("localhost", "127.0.0.1", "::1") assert port > 0 @@ -229,49 +227,47 @@ def handle_comm(comm): connector = tcp.TCPConnector() l = [] - @gen.coroutine - def client_communicate(key, delay=0): + async def client_communicate(key, delay=0): addr = "%s:%d" % (host, port) - comm = yield connector.connect(addr) + comm = await connector.connect(addr) assert comm.peer_address == "tcp://" + addr assert comm.extra_info == {} - yield comm.write({"op": "ping", "data": key}) + await comm.write({"op": "ping", "data": key}) if delay: - yield gen.sleep(delay) - msg = yield comm.read() + await asyncio.sleep(delay) + msg = await comm.read() assert msg == {"op": "pong", "data": key} l.append(key) - yield comm.close() + await comm.close() - yield client_communicate(key=1234) + await client_communicate(key=1234) # Many clients at once N = 100 futures = [client_communicate(key=i, delay=0.05) for i in range(N)] - yield futures + await asyncio.gather(*futures) assert set(l) == {1234} | set(range(N)) -@gen_test() -def test_tls_specific(): +@pytest.mark.asyncio +async def test_tls_specific(): """ Test concrete TLS API. """ - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): assert comm.peer_address.startswith("tls://" + host) check_tls_extra(comm.extra_info) - msg = yield comm.read() + msg = await comm.read() msg["op"] = "pong" - yield comm.write(msg) - yield comm.close() + await comm.write(msg) + await comm.close() server_ctx = get_server_ssl_context() client_ctx = get_client_ssl_context() listener = tcp.TLSListener("localhost", handle_comm, ssl_context=server_ctx) - listener.start() + await listener.start() host, port = listener.get_host_port() assert host in ("localhost", "127.0.0.1", "::1") assert port > 0 @@ -279,31 +275,30 @@ def handle_comm(comm): connector = tcp.TLSConnector() l = [] - @gen.coroutine - def client_communicate(key, delay=0): + async def client_communicate(key, delay=0): addr = "%s:%d" % (host, port) - comm = yield connector.connect(addr, ssl_context=client_ctx) + comm = await connector.connect(addr, ssl_context=client_ctx) assert comm.peer_address == "tls://" + addr check_tls_extra(comm.extra_info) - yield comm.write({"op": "ping", "data": key}) + await comm.write({"op": "ping", "data": key}) if delay: - yield gen.sleep(delay) - msg = yield comm.read() + await asyncio.sleep(delay) + msg = await comm.read() assert msg == {"op": "pong", "data": key} l.append(key) - yield comm.close() + await comm.close() - yield client_communicate(key=1234) + await client_communicate(key=1234) # Many clients at once N = 100 futures = [client_communicate(key=i, delay=0.05) for i in range(N)] - yield futures + await asyncio.gather(*futures) assert set(l) == {1234} | set(range(N)) -@gen_test() -def test_comm_failure_threading(): +@pytest.mark.asyncio +async def test_comm_failure_threading(): """ When we fail to connect, make sure we don't make a lot of threads. @@ -312,40 +307,38 @@ def test_comm_failure_threading(): set for python 3. See github PR #2403 discussion for info. 
""" - @gen.coroutine - def sleep_for_60ms(): + async def sleep_for_60ms(): max_thread_count = 0 for x in range(60): - yield gen.sleep(0.001) + await asyncio.sleep(0.001) thread_count = threading.active_count() if thread_count > max_thread_count: max_thread_count = thread_count - raise gen.Return(max_thread_count) + return max_thread_count original_thread_count = threading.active_count() # tcp.TCPConnector() sleep_future = sleep_for_60ms() with pytest.raises(IOError): - yield connect("tcp://localhost:28400", 0.052) - max_thread_count = yield sleep_future + await connect("tcp://localhost:28400", 0.052) + max_thread_count = await sleep_future # 2 is the number set by BaseTCPConnector.executor (ThreadPoolExecutor) assert max_thread_count <= 2 + original_thread_count # tcp.TLSConnector() sleep_future = sleep_for_60ms() with pytest.raises(IOError): - yield connect( + await connect( "tls://localhost:28400", 0.052, connection_args={"ssl_context": get_client_ssl_context()}, ) - max_thread_count = yield sleep_future + max_thread_count = await sleep_future assert max_thread_count <= 2 + original_thread_count -@gen.coroutine -def check_inproc_specific(run_client): +async def check_inproc_specific(run_client): """ Test concrete InProc API. """ @@ -356,18 +349,17 @@ def check_inproc_specific(run_client): N_MSGS = 3 - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): assert comm.peer_address.startswith("inproc://" + addr_head) client_addresses.add(comm.peer_address) for i in range(N_MSGS): - msg = yield comm.read() + msg = await comm.read() msg["op"] = "pong" - yield comm.write(msg) - yield comm.close() + await comm.write(msg) + await comm.close() listener = inproc.InProcListener(listener_addr, handle_comm) - listener.start() + await listener.start() assert ( listener.listen_address == listener.contact_address @@ -377,29 +369,28 @@ def handle_comm(comm): connector = inproc.InProcConnector(inproc.global_manager) l = [] - @gen.coroutine - def client_communicate(key, delay=0): - comm = yield connector.connect(listener_addr) + async def client_communicate(key, delay=0): + comm = await connector.connect(listener_addr) assert comm.peer_address == "inproc://" + listener_addr for i in range(N_MSGS): - yield comm.write({"op": "ping", "data": key}) + await comm.write({"op": "ping", "data": key}) if delay: - yield gen.sleep(delay) - msg = yield comm.read() + await asyncio.sleep(delay) + msg = await comm.read() assert msg == {"op": "pong", "data": key} l.append(key) with pytest.raises(CommClosedError): - yield comm.read() - yield comm.close() + await comm.read() + await comm.close() client_communicate = partial(run_client, client_communicate) - yield client_communicate(key=1234) + await client_communicate(key=1234) # Many clients at once N = 20 futures = [client_communicate(key=i, delay=0.001) for i in range(N)] - yield futures + await asyncio.gather(*futures) assert set(l) == {1234} | set(range(N)) assert len(client_addresses) == N + 1 @@ -430,14 +421,14 @@ def run(): return fut -@gen_test() -def test_inproc_specific_same_thread(): - yield check_inproc_specific(run_coro) +@pytest.mark.asyncio +async def test_inproc_specific_same_thread(): + await check_inproc_specific(run_coro) -@gen_test() -def test_inproc_specific_different_threads(): - yield check_inproc_specific(run_coro_in_thread) +@pytest.mark.asyncio +async def test_inproc_specific_different_threads(): + await check_inproc_specific(run_coro_in_thread) # @@ -445,8 +436,7 @@ def test_inproc_specific_different_threads(): # -@gen.coroutine 
-def check_client_server( +async def check_client_server( addr, check_listen_addr=None, check_contact_addr=None, @@ -457,27 +447,26 @@ def check_client_server( Abstract client / server test. """ - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): scheme, loc = parse_address(comm.peer_address) assert scheme == bound_scheme - msg = yield comm.read() + msg = await comm.read() assert msg["op"] == "ping" msg["op"] = "pong" - yield comm.write(msg) + await comm.write(msg) - msg = yield comm.read() + msg = await comm.read() assert msg["op"] == "foobar" - yield comm.close() + await comm.close() # Arbitrary connection args should be ignored listen_args = listen_args or {"xxx": "bar"} connect_args = connect_args or {"xxx": "foo"} listener = listen(addr, handle_comm, connection_args=listen_args) - listener.start() + await listener.start() # Check listener properties bound_addr = listener.listen_address @@ -500,37 +489,36 @@ def handle_comm(comm): # Check client <-> server comms l = [] - @gen.coroutine - def client_communicate(key, delay=0): - comm = yield connect(listener.contact_address, connection_args=connect_args) + async def client_communicate(key, delay=0): + comm = await connect(listener.contact_address, connection_args=connect_args) assert comm.peer_address == listener.contact_address - yield comm.write({"op": "ping", "data": key}) - yield comm.write({"op": "foobar"}) + await comm.write({"op": "ping", "data": key}) + await comm.write({"op": "foobar"}) if delay: - yield gen.sleep(delay) - msg = yield comm.read() + await asyncio.sleep(delay) + msg = await comm.read() assert msg == {"op": "pong", "data": key} l.append(key) - yield comm.close() + await comm.close() - yield client_communicate(key=1234) + await client_communicate(key=1234) # Many clients at once futures = [client_communicate(key=i, delay=0.05) for i in range(20)] - yield futures + await asyncio.gather(*futures) assert set(l) == {1234} | set(range(20)) listener.stop() -@gen_test() -def test_ucx_client_server(): +@pytest.mark.asyncio +async def test_ucx_client_server(): pytest.importorskip("distributed.comm.ucx") ucp = pytest.importorskip("ucp") addr = ucp.get_address() - yield check_client_server("ucx://" + addr) + await check_client_server("ucx://" + addr) def tcp_eq(expected_host, expected_port=None): @@ -560,79 +548,79 @@ def checker(loc): return checker -@gen_test() -def test_default_client_server_ipv4(): +@pytest.mark.asyncio +async def test_default_client_server_ipv4(): # Default scheme is (currently) TCP - yield check_client_server("127.0.0.1", tcp_eq("127.0.0.1")) - yield check_client_server("127.0.0.1:3201", tcp_eq("127.0.0.1", 3201)) - yield check_client_server("0.0.0.0", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) - yield check_client_server( + await check_client_server("127.0.0.1", tcp_eq("127.0.0.1")) + await check_client_server("127.0.0.1:3201", tcp_eq("127.0.0.1", 3201)) + await check_client_server("0.0.0.0", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + await check_client_server( "0.0.0.0:3202", tcp_eq("0.0.0.0", 3202), tcp_eq(EXTERNAL_IP4, 3202) ) # IPv4 is preferred for the bound address - yield check_client_server("", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) - yield check_client_server( + await check_client_server("", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + await check_client_server( ":3203", tcp_eq("0.0.0.0", 3203), tcp_eq(EXTERNAL_IP4, 3203) ) @requires_ipv6 -@gen_test() -def test_default_client_server_ipv6(): - yield check_client_server("[::1]", tcp_eq("::1")) - yield 
check_client_server("[::1]:3211", tcp_eq("::1", 3211)) - yield check_client_server("[::]", tcp_eq("::"), tcp_eq(EXTERNAL_IP6)) - yield check_client_server( +@pytest.mark.asyncio +async def test_default_client_server_ipv6(): + await check_client_server("[::1]", tcp_eq("::1")) + await check_client_server("[::1]:3211", tcp_eq("::1", 3211)) + await check_client_server("[::]", tcp_eq("::"), tcp_eq(EXTERNAL_IP6)) + await check_client_server( "[::]:3212", tcp_eq("::", 3212), tcp_eq(EXTERNAL_IP6, 3212) ) -@gen_test() -def test_tcp_client_server_ipv4(): - yield check_client_server("tcp://127.0.0.1", tcp_eq("127.0.0.1")) - yield check_client_server("tcp://127.0.0.1:3221", tcp_eq("127.0.0.1", 3221)) - yield check_client_server("tcp://0.0.0.0", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) - yield check_client_server( +@pytest.mark.asyncio +async def test_tcp_client_server_ipv4(): + await check_client_server("tcp://127.0.0.1", tcp_eq("127.0.0.1")) + await check_client_server("tcp://127.0.0.1:3221", tcp_eq("127.0.0.1", 3221)) + await check_client_server("tcp://0.0.0.0", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + await check_client_server( "tcp://0.0.0.0:3222", tcp_eq("0.0.0.0", 3222), tcp_eq(EXTERNAL_IP4, 3222) ) - yield check_client_server("tcp://", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) - yield check_client_server( + await check_client_server("tcp://", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) + await check_client_server( "tcp://:3223", tcp_eq("0.0.0.0", 3223), tcp_eq(EXTERNAL_IP4, 3223) ) @requires_ipv6 -@gen_test() -def test_tcp_client_server_ipv6(): - yield check_client_server("tcp://[::1]", tcp_eq("::1")) - yield check_client_server("tcp://[::1]:3231", tcp_eq("::1", 3231)) - yield check_client_server("tcp://[::]", tcp_eq("::"), tcp_eq(EXTERNAL_IP6)) - yield check_client_server( +@pytest.mark.asyncio +async def test_tcp_client_server_ipv6(): + await check_client_server("tcp://[::1]", tcp_eq("::1")) + await check_client_server("tcp://[::1]:3231", tcp_eq("::1", 3231)) + await check_client_server("tcp://[::]", tcp_eq("::"), tcp_eq(EXTERNAL_IP6)) + await check_client_server( "tcp://[::]:3232", tcp_eq("::", 3232), tcp_eq(EXTERNAL_IP6, 3232) ) -@gen_test() -def test_tls_client_server_ipv4(): - yield check_client_server("tls://127.0.0.1", tls_eq("127.0.0.1"), **tls_kwargs) - yield check_client_server( +@pytest.mark.asyncio +async def test_tls_client_server_ipv4(): + await check_client_server("tls://127.0.0.1", tls_eq("127.0.0.1"), **tls_kwargs) + await check_client_server( "tls://127.0.0.1:3221", tls_eq("127.0.0.1", 3221), **tls_kwargs ) - yield check_client_server( + await check_client_server( "tls://", tls_eq("0.0.0.0"), tls_eq(EXTERNAL_IP4), **tls_kwargs ) @requires_ipv6 -@gen_test() -def test_tls_client_server_ipv6(): - yield check_client_server("tls://[::1]", tls_eq("::1"), **tls_kwargs) +@pytest.mark.asyncio +async def test_tls_client_server_ipv6(): + await check_client_server("tls://[::1]", tls_eq("::1"), **tls_kwargs) -@gen_test() -def test_inproc_client_server(): - yield check_client_server("inproc://", inproc_check()) - yield check_client_server(inproc.new_address(), inproc_check()) +@pytest.mark.asyncio +async def test_inproc_client_server(): + await check_client_server("inproc://", inproc_check()) + await check_client_server(inproc.new_address(), inproc_check()) # @@ -640,8 +628,8 @@ def test_inproc_client_server(): # -@gen_test() -def test_tls_reject_certificate(): +@pytest.mark.asyncio +async def test_tls_reject_certificate(): cli_ctx = get_client_ssl_context() serv_ctx = 
get_server_ssl_context() @@ -650,23 +638,22 @@ def test_tls_reject_certificate(): bad_cli_ctx = get_client_ssl_context(*bad_cert_key) bad_serv_ctx = get_server_ssl_context(*bad_cert_key) - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): scheme, loc = parse_address(comm.peer_address) assert scheme == "tls" - yield comm.close() + await comm.close() # Listener refuses a connector not signed by the CA listener = listen("tls://", handle_comm, connection_args={"ssl_context": serv_ctx}) - listener.start() + await listener.start() with pytest.raises(EnvironmentError) as excinfo: - comm = yield connect( + comm = await connect( listener.contact_address, timeout=0.5, connection_args={"ssl_context": bad_cli_ctx}, ) - yield comm.write({"x": "foo"}) # TODO: why is this necessary in Tornado 6 ? + await comm.write({"x": "foo"}) # TODO: why is this necessary in Tornado 6 ? # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028 if sys.version_info >= (3,) and os.name != "nt": @@ -682,19 +669,19 @@ def handle_comm(comm): raise # Sanity check - comm = yield connect( + comm = await connect( listener.contact_address, timeout=2, connection_args={"ssl_context": cli_ctx} ) - yield comm.close() + await comm.close() # Connector refuses a listener not signed by the CA listener = listen( "tls://", handle_comm, connection_args={"ssl_context": bad_serv_ctx} ) - listener.start() + await listener.start() with pytest.raises(EnvironmentError) as excinfo: - yield connect( + await connect( listener.contact_address, timeout=2, connection_args={"ssl_context": cli_ctx}, @@ -709,130 +696,128 @@ def handle_comm(comm): # -@gen.coroutine -def check_comm_closed_implicit(addr, delay=None, listen_args=None, connect_args=None): - @gen.coroutine - def handle_comm(comm): - yield comm.close() +async def check_comm_closed_implicit( + addr, delay=None, listen_args=None, connect_args=None +): + async def handle_comm(comm): + await comm.close() listener = listen(addr, handle_comm, connection_args=listen_args) - listener.start() + await listener.start() contact_addr = listener.contact_address - comm = yield connect(contact_addr, connection_args=connect_args) + comm = await connect(contact_addr, connection_args=connect_args) with pytest.raises(CommClosedError): - yield comm.write({}) + await comm.write({}) - comm = yield connect(contact_addr, connection_args=connect_args) + comm = await connect(contact_addr, connection_args=connect_args) with pytest.raises(CommClosedError): - yield comm.read() + await comm.read() -@gen_test() -def test_tcp_comm_closed_implicit(): - yield check_comm_closed_implicit("tcp://127.0.0.1") +@pytest.mark.asyncio +async def test_tcp_comm_closed_implicit(): + await check_comm_closed_implicit("tcp://127.0.0.1") -@gen_test() -def test_tls_comm_closed_implicit(): - yield check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs) +@pytest.mark.asyncio +async def test_tls_comm_closed_implicit(): + await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs) -@gen_test() -def test_inproc_comm_closed_implicit(): - yield check_comm_closed_implicit(inproc.new_address()) +@pytest.mark.asyncio +async def test_inproc_comm_closed_implicit(): + await check_comm_closed_implicit(inproc.new_address()) -@gen.coroutine -def check_comm_closed_explicit(addr, listen_args=None, connect_args=None): - a, b = yield get_comm_pair(addr, listen_args=listen_args, connect_args=connect_args) +async def check_comm_closed_explicit(addr, listen_args=None, connect_args=None): + a, b = 
await get_comm_pair(addr, listen_args=listen_args, connect_args=connect_args) a_read = a.read() b_read = b.read() - yield a.close() + await a.close() # In-flight reads should abort with CommClosedError with pytest.raises(CommClosedError): - yield a_read + await a_read with pytest.raises(CommClosedError): - yield b_read + await b_read # New reads as well with pytest.raises(CommClosedError): - yield a.read() + await a.read() with pytest.raises(CommClosedError): - yield b.read() + await b.read() # And writes with pytest.raises(CommClosedError): - yield a.write({}) + await a.write({}) with pytest.raises(CommClosedError): - yield b.write({}) - yield b.close() + await b.write({}) + await b.close() -@gen_test() -def test_tcp_comm_closed_explicit(): - yield check_comm_closed_explicit("tcp://127.0.0.1") +@pytest.mark.asyncio +async def test_tcp_comm_closed_explicit(): + await check_comm_closed_explicit("tcp://127.0.0.1") -@gen_test() -def test_tls_comm_closed_explicit(): - yield check_comm_closed_explicit("tls://127.0.0.1", **tls_kwargs) +@pytest.mark.asyncio +async def test_tls_comm_closed_explicit(): + await check_comm_closed_explicit("tls://127.0.0.1", **tls_kwargs) -@gen_test() -def test_inproc_comm_closed_explicit(): - yield check_comm_closed_explicit(inproc.new_address()) +@pytest.mark.asyncio +async def test_inproc_comm_closed_explicit(): + await check_comm_closed_explicit(inproc.new_address()) -@gen_test() -def test_inproc_comm_closed_explicit_2(): +@pytest.mark.asyncio +async def test_inproc_comm_closed_explicit_2(): listener_errors = [] - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): # Wait try: - yield comm.read() + await comm.read() except CommClosedError: assert comm.closed() listener_errors.append(True) else: - yield comm.close() + await comm.close() listener = listen("inproc://", handle_comm) - listener.start() + await listener.start() contact_addr = listener.contact_address - comm = yield connect(contact_addr) - yield comm.close() + comm = await connect(contact_addr) + await comm.close() assert comm.closed() start = time() while len(listener_errors) < 1: assert time() < start + 1 - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(listener_errors) == 1 with pytest.raises(CommClosedError): - yield comm.read() + await comm.read() with pytest.raises(CommClosedError): - yield comm.write("foo") + await comm.write("foo") - comm = yield connect(contact_addr) - yield comm.write("foo") + comm = await connect(contact_addr) + await comm.write("foo") with pytest.raises(CommClosedError): - yield comm.read() + await comm.read() with pytest.raises(CommClosedError): - yield comm.write("foo") + await comm.write("foo") assert comm.closed() - comm = yield connect(contact_addr) - yield comm.write("foo") + comm = await connect(contact_addr) + await comm.write("foo") start = time() while not comm.closed(): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 - yield comm.close() - yield comm.close() + await comm.close() + await comm.close() # @@ -840,28 +825,26 @@ def handle_comm(comm): # -@gen.coroutine -def check_connect_timeout(addr): +async def check_connect_timeout(addr): t1 = time() with pytest.raises(IOError): - yield connect(addr, timeout=0.15) + await connect(addr, timeout=0.15) dt = time() - t1 assert 1 >= dt >= 0.1 -@gen_test() -def test_tcp_connect_timeout(): - yield check_connect_timeout("tcp://127.0.0.1:44444") +@pytest.mark.asyncio +async def test_tcp_connect_timeout(): + await check_connect_timeout("tcp://127.0.0.1:44444") 
-@gen_test() -def test_inproc_connect_timeout(): - yield check_connect_timeout(inproc.new_address()) +@pytest.mark.asyncio +async def test_inproc_connect_timeout(): + await check_connect_timeout(inproc.new_address()) -def check_many_listeners(addr): - @gen.coroutine - def handle_comm(comm): +async def check_many_listeners(addr): + async def handle_comm(comm): pass listeners = [] @@ -869,7 +852,7 @@ def handle_comm(comm): for i in range(N): listener = listen(addr, handle_comm) - listener.start() + await listener.start() listeners.append(listener) assert len(set(l.listen_address for l in listeners)) == N @@ -879,16 +862,16 @@ def handle_comm(comm): listener.stop() -@gen_test() -def test_tcp_many_listeners(): - check_many_listeners("tcp://127.0.0.1") - check_many_listeners("tcp://0.0.0.0") - check_many_listeners("tcp://") +@pytest.mark.asyncio +async def test_tcp_many_listeners(): + await check_many_listeners("tcp://127.0.0.1") + await check_many_listeners("tcp://0.0.0.0") + await check_many_listeners("tcp://") -@gen_test() -def test_inproc_many_listeners(): - check_many_listeners("inproc://") +@pytest.mark.asyncio +async def test_inproc_many_listeners(): + await check_many_listeners("inproc://") # @@ -896,47 +879,42 @@ def test_inproc_many_listeners(): # -@gen.coroutine -def check_listener_deserialize(addr, deserialize, in_value, check_out): +async def check_listener_deserialize(addr, deserialize, in_value, check_out): q = queues.Queue() - @gen.coroutine - def handle_comm(comm): - msg = yield comm.read() + async def handle_comm(comm): + msg = await comm.read() q.put_nowait(msg) - yield comm.close() + await comm.close() - with listen(addr, handle_comm, deserialize=deserialize) as listener: - comm = yield connect(listener.contact_address) + async with listen(addr, handle_comm, deserialize=deserialize) as listener: + comm = await connect(listener.contact_address) - yield comm.write(in_value) + await comm.write(in_value) - out_value = yield q.get() + out_value = await q.get() check_out(out_value) - yield comm.close() + await comm.close() -@gen.coroutine -def check_connector_deserialize(addr, deserialize, in_value, check_out): +async def check_connector_deserialize(addr, deserialize, in_value, check_out): done = locks.Event() - @gen.coroutine - def handle_comm(comm): - yield comm.write(in_value) - yield done.wait() - yield comm.close() + async def handle_comm(comm): + await comm.write(in_value) + await done.wait() + await comm.close() - with listen(addr, handle_comm) as listener: - comm = yield connect(listener.contact_address, deserialize=deserialize) + async with listen(addr, handle_comm) as listener: + comm = await connect(listener.contact_address, deserialize=deserialize) - out_value = yield comm.read() + out_value = await comm.read() done.set() - yield comm.close() + await comm.close() check_out(out_value) -@gen.coroutine -def check_deserialize(addr): +async def check_deserialize(addr): """ Check the "deserialize" flag on connect() and listen(). 
""" @@ -979,11 +957,11 @@ def check_out_true(out_value): expected_msg["to_ser"] = [123] assert out_value == expected_msg - yield check_listener_deserialize(addr, False, msg, check_out_false) - yield check_connector_deserialize(addr, False, msg, check_out_false) + await check_listener_deserialize(addr, False, msg, check_out_false) + await check_connector_deserialize(addr, False, msg, check_out_false) - yield check_listener_deserialize(addr, True, msg, check_out_true) - yield check_connector_deserialize(addr, True, msg, check_out_true) + await check_listener_deserialize(addr, True, msg, check_out_true) + await check_connector_deserialize(addr, True, msg, check_out_true) # Test with long bytestrings, large enough to be transferred # as a separate payload @@ -1024,26 +1002,25 @@ def check_out(deserialize_flag, out_value): else: assert to_ser == to_serialize(_uncompressible) - yield check_listener_deserialize(addr, False, msg, partial(check_out, False)) - yield check_connector_deserialize(addr, False, msg, partial(check_out, False)) + await check_listener_deserialize(addr, False, msg, partial(check_out, False)) + await check_connector_deserialize(addr, False, msg, partial(check_out, False)) - yield check_listener_deserialize(addr, True, msg, partial(check_out, True)) - yield check_connector_deserialize(addr, True, msg, partial(check_out, True)) + await check_listener_deserialize(addr, True, msg, partial(check_out, True)) + await check_connector_deserialize(addr, True, msg, partial(check_out, True)) @pytest.mark.xfail(reason="intermittent failure on windows") -@gen_test() -def test_tcp_deserialize(): - yield check_deserialize("tcp://") +@pytest.mark.asyncio +async def test_tcp_deserialize(): + await check_deserialize("tcp://") -@gen_test() -def test_inproc_deserialize(): - yield check_deserialize("inproc://") +@pytest.mark.asyncio +async def test_inproc_deserialize(): + await check_deserialize("inproc://") -@gen.coroutine -def check_deserialize_roundtrip(addr): +async def check_deserialize_roundtrip(addr): """ Sanity check round-tripping with "deserialize" on and off. """ @@ -1059,11 +1036,11 @@ def check_deserialize_roundtrip(addr): } for should_deserialize in (True, False): - a, b = yield get_comm_pair(addr, deserialize=should_deserialize) - yield a.write(msg) - got = yield b.read() - yield b.write(got) - got = yield a.read() + a, b = await get_comm_pair(addr, deserialize=should_deserialize) + await a.write(msg) + got = await b.read() + await b.write(got) + got = await a.read() assert sorted(got) == sorted(msg) for k in ("op", "x"): @@ -1076,14 +1053,14 @@ def check_deserialize_roundtrip(addr): assert isinstance(got["ser"], Serialized) -@gen_test() -def test_inproc_deserialize_roundtrip(): - yield check_deserialize_roundtrip("inproc://") +@pytest.mark.asyncio +async def test_inproc_deserialize_roundtrip(): + await check_deserialize_roundtrip("inproc://") -@gen_test() -def test_tcp_deserialize_roundtrip(): - yield check_deserialize_roundtrip("tcp://") +@pytest.mark.asyncio +async def test_tcp_deserialize_roundtrip(): + await check_deserialize_roundtrip("tcp://") def _raise_eoferror(): @@ -1095,27 +1072,25 @@ def __reduce__(self): return _raise_eoferror, () -@gen.coroutine -def check_deserialize_eoferror(addr): +async def check_deserialize_eoferror(addr): """ EOFError when deserializing should close the comm. 
""" - @gen.coroutine - def handle_comm(comm): - yield comm.write({"data": to_serialize(_EOFRaising())}) + async def handle_comm(comm): + await comm.write({"data": to_serialize(_EOFRaising())}) with pytest.raises(CommClosedError): - yield comm.read() + await comm.read() - with listen(addr, handle_comm) as listener: - comm = yield connect(listener.contact_address, deserialize=deserialize) + async with listen(addr, handle_comm) as listener: + comm = await connect(listener.contact_address, deserialize=deserialize) with pytest.raises(CommClosedError): - yield comm.read() + await comm.read() -@gen_test() -def test_tcp_deserialize_eoferror(): - yield check_deserialize_eoferror("tcp://") +@pytest.mark.asyncio +async def test_tcp_deserialize_eoferror(): + await check_deserialize_eoferror("tcp://") # @@ -1123,61 +1098,59 @@ def test_tcp_deserialize_eoferror(): # -@gen.coroutine -def check_repr(a, b): +async def check_repr(a, b): assert "closed" not in repr(a) assert "closed" not in repr(b) - yield a.close() + await a.close() assert "closed" in repr(a) - yield b.close() + await b.close() assert "closed" in repr(b) -@gen_test() -def test_tcp_repr(): - a, b = yield get_tcp_comm_pair() +@pytest.mark.asyncio +async def test_tcp_repr(): + a, b = await get_tcp_comm_pair() assert a.local_address in repr(b) assert b.local_address in repr(a) - yield check_repr(a, b) + await check_repr(a, b) -@gen_test() -def test_tls_repr(): - a, b = yield get_tls_comm_pair() +@pytest.mark.asyncio +async def test_tls_repr(): + a, b = await get_tls_comm_pair() assert a.local_address in repr(b) assert b.local_address in repr(a) - yield check_repr(a, b) + await check_repr(a, b) -@gen_test() -def test_inproc_repr(): - a, b = yield get_inproc_comm_pair() +@pytest.mark.asyncio +async def test_inproc_repr(): + a, b = await get_inproc_comm_pair() assert a.local_address in repr(b) assert b.local_address in repr(a) - yield check_repr(a, b) + await check_repr(a, b) -@gen.coroutine -def check_addresses(a, b): +async def check_addresses(a, b): assert a.peer_address == b.local_address assert a.local_address == b.peer_address a.abort() b.abort() -@gen_test() -def test_tcp_adresses(): - a, b = yield get_tcp_comm_pair() - yield check_addresses(a, b) +@pytest.mark.asyncio +async def test_tcp_adresses(): + a, b = await get_tcp_comm_pair() + await check_addresses(a, b) -@gen_test() -def test_tls_adresses(): - a, b = yield get_tls_comm_pair() - yield check_addresses(a, b) +@pytest.mark.asyncio +async def test_tls_adresses(): + a, b = await get_tls_comm_pair() + await check_addresses(a, b) -@gen_test() -def test_inproc_adresses(): - a, b = yield get_inproc_comm_pair() - yield check_addresses(a, b) +@pytest.mark.asyncio +async def test_inproc_adresses(): + a, b = await get_inproc_comm_pair() + await check_addresses(a, b) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 2bdcff5e958..7c783b605a1 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -261,7 +261,7 @@ def port(self): def address(self): return "ucx://" + self.ip + ":" + str(self.port) - def start(self): + async def start(self): async def serve_forever(client_ep): ucx = UCX( client_ep, diff --git a/distributed/core.py b/distributed/core.py index 3e1a3b47cc3..ec6e9969fea 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -295,7 +295,7 @@ def port(self): def identity(self, comm=None): return {"type": type(self).__name__, "id": self.id} - def listen(self, port_or_addr=None, listen_args=None): + async def listen(self, port_or_addr=None, 
listen_args=None): if port_or_addr is None: port_or_addr = self.default_port if isinstance(port_or_addr, int): @@ -311,7 +311,7 @@ def listen(self, port_or_addr=None, listen_args=None): deserialize=self.deserialize, connection_args=listen_args, ) - self.listener.start() + await self.listener.start() async def handle_comm(self, comm, shutting_down=shutting_down): """ Dispatch new communications to coroutine-handlers diff --git a/distributed/nanny.py b/distributed/nanny.py index 11cf0157c10..dc2e8a3ea48 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -244,7 +244,7 @@ def local_dir(self): async def start(self): """ Start nanny, start local process, start watching """ - self.listen(self._start_address, listen_args=self.listen_args) + await self.listen(self._start_address, listen_args=self.listen_args) self.ip = get_address_host(self.address) logger.info(" Start Nanny at: %r", self.address) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 26b6bffa970..b77c36477c1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1213,7 +1213,7 @@ async def start(self): c.cancel() if self.status != "running": - self.listen(self._start_address, listen_args=self.listen_args) + await self.listen(self._start_address, listen_args=self.listen_args) self.ip = get_address_host(self.listen_address) listen_ip = self.ip diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index af281aff8c3..a961157f948 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -1,4 +1,4 @@ -from contextlib import contextmanager +import asyncio from datetime import timedelta import random @@ -10,72 +10,67 @@ from distributed.core import listen, connect, CommClosedError from distributed.metrics import time from distributed.utils import All -from distributed.utils_test import gen_test, captured_logger +from distributed.utils_test import captured_logger from distributed.protocol import to_serialize class EchoServer(object): count = 0 - @gen.coroutine - def handle_comm(self, comm): + async def handle_comm(self, comm): while True: try: - msg = yield comm.read() + msg = await comm.read() self.count += 1 - yield comm.write(msg) + await comm.write(msg) except CommClosedError as e: return - def listen(self): + async def listen(self): listener = listen("", self.handle_comm) - listener.start() + await listener.start() self.address = listener.contact_address self.stop = listener.stop + async def __aenter__(self): + await self.listen() + return self -@contextmanager -def echo_server(): - server = EchoServer() - server.listen() + async def __aexit__(self, exc, typ, tb): + self.stop() - try: - yield server - finally: - server.stop() - -@gen_test() -def test_BatchedSend(): - with echo_server() as e: - comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_BatchedSend(): + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval=10) assert str(len(b.buffer)) in str(b) assert str(len(b.buffer)) in repr(b) b.start(comm) - yield gen.sleep(0.020) + await asyncio.sleep(0.020) b.send("hello") b.send("hello") b.send("world") - yield gen.sleep(0.020) + await asyncio.sleep(0.020) b.send("HELLO") b.send("HELLO") - result = yield comm.read() + result = await comm.read() assert result == ("hello", "hello", "world") - result = yield comm.read() + result = await comm.read() assert result == ("HELLO", "HELLO") assert b.byte_count > 1 -@gen_test() -def test_send_before_start(): - with echo_server() as e: - 
comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_send_before_start(): + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval=10) @@ -83,120 +78,117 @@ def test_send_before_start(): b.send("world") b.start(comm) - result = yield comm.read() + result = await comm.read() assert result == ("hello", "world") -@gen_test() -def test_send_after_stream_start(): - with echo_server() as e: - comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_send_after_stream_start(): + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval=10) b.start(comm) b.send("hello") b.send("world") - result = yield comm.read() + result = await comm.read() if len(result) < 2: - result += yield comm.read() + result += await comm.read() assert result == ("hello", "world") -@gen_test() -def test_send_before_close(): - with echo_server() as e: - comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_send_before_close(): + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval=10) b.start(comm) cnt = int(e.count) b.send("hello") - yield b.close() # close immediately after sending + await b.close() # close immediately after sending assert not b.buffer start = time() while e.count != cnt + 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 with pytest.raises(CommClosedError): b.send("123") -@gen_test() -def test_close_closed(): - with echo_server() as e: - comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_close_closed(): + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval=10) b.start(comm) b.send(123) - yield comm.close() # external closing + await comm.close() # external closing - yield b.close() + await b.close() assert "closed" in repr(b) assert "closed" in str(b) -@gen_test() -def test_close_not_started(): +@pytest.mark.asyncio +async def test_close_not_started(): b = BatchedSend(interval=10) - yield b.close() + await b.close() -@gen_test() -def test_close_twice(): - with echo_server() as e: - comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_close_twice(): + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval=10) b.start(comm) - yield b.close() - yield b.close() + await b.close() + await b.close() @pytest.mark.slow -@gen_test(timeout=50) -def test_stress(): - with echo_server() as e: - comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_stress(): + async with EchoServer() as e: + comm = await connect(e.address) L = [] - @gen.coroutine - def send(): + async def send(): b = BatchedSend(interval=3) b.start(comm) for i in range(0, 10000, 2): b.send(i) b.send(i + 1) - yield gen.sleep(0.00001 * random.randint(1, 10)) + await asyncio.sleep(0.00001 * random.randint(1, 10)) - @gen.coroutine - def recv(): + async def recv(): while True: - result = yield gen.with_timeout(timedelta(seconds=1), comm.read()) + result = await gen.with_timeout(timedelta(seconds=1), comm.read()) L.extend(result) if result[-1] == 9999: break - yield All([send(), recv()]) + await All([send(), recv()]) assert L == list(range(0, 10000, 1)) - yield comm.close() + await comm.close() -@gen.coroutine -def run_traffic_jam(nsends, nbytes): +async def run_traffic_jam(nsends, nbytes): # This test eats `nsends * nbytes` bytes in RAM np = pytest.importorskip("numpy") from distributed.protocol import to_serialize data = bytes(np.random.randint(0, 255, 
size=(nbytes,)).astype("u1").data) - with echo_server() as e: - comm = yield connect(e.address) + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval=0.01) b.start(comm) @@ -205,7 +197,7 @@ def run_traffic_jam(nsends, nbytes): for i in range(nsends): b.send(assoc(msg, "i", i)) if np.random.random() > 0.5: - yield gen.sleep(0.001) + await asyncio.sleep(0.001) results = [] count = 0 @@ -213,7 +205,7 @@ def run_traffic_jam(nsends, nbytes): # If this times out then I think it's a backpressure issue # Somehow we're able to flood the socket so that the receiving end # loses some of our messages - L = yield gen.with_timeout(timedelta(seconds=5), comm.read()) + L = await gen.with_timeout(timedelta(seconds=5), comm.read()) count += 1 results.extend(r["i"] for r in L) @@ -222,45 +214,45 @@ def run_traffic_jam(nsends, nbytes): assert results == list(range(nsends)) - yield comm.close() # external closing - yield b.close() + await comm.close() # external closing + await b.close() -@gen_test() -def test_sending_traffic_jam(): - yield run_traffic_jam(50, 300000) +@pytest.mark.asyncio +async def test_sending_traffic_jam(): + await run_traffic_jam(50, 300000) @pytest.mark.slow -@gen_test() -def test_large_traffic_jam(): - yield run_traffic_jam(500, 1500000) +@pytest.mark.asyncio +async def test_large_traffic_jam(): + await run_traffic_jam(500, 1500000) -@gen_test() -def test_serializers(): - with echo_server() as e: - comm = yield connect(e.address) +@pytest.mark.asyncio +async def test_serializers(): + async with EchoServer() as e: + comm = await connect(e.address) b = BatchedSend(interval="10ms", serializers=["msgpack"]) b.start(comm) b.send({"x": to_serialize(123)}) b.send({"x": to_serialize("hello")}) - yield gen.sleep(0.100) + await asyncio.sleep(0.100) b.send({"x": to_serialize(lambda x: x + 1)}) with captured_logger("distributed.protocol") as sio: - yield gen.sleep(0.100) + await asyncio.sleep(0.100) value = sio.getvalue() assert "serialize" in value assert "type" in value assert "function" in value - msg = yield comm.read() + msg = await comm.read() assert list(msg) == [{"x": 123}, {"x": "hello"}] with pytest.raises(gen.TimeoutError): - msg = yield gen.with_timeout(timedelta(milliseconds=100), comm.read()) + msg = await gen.with_timeout(timedelta(milliseconds=100), comm.read()) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 78fbd1211d7..d3bcdaf8987 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -1,7 +1,7 @@ import asyncio -from contextlib import contextmanager import os import socket +import sys import threading import weakref @@ -23,7 +23,6 @@ from distributed.protocol import to_serialize from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import ( - gen_test, gen_cluster, has_ipv6, assert_can_connect, @@ -86,7 +85,7 @@ async def f(): server = Server({"ping": pingpong}) with pytest.raises(ValueError): server.port - server.listen(8881) + await server.listen(8881) assert server.port == 8881 assert server.address == ("tcp://%s:8881" % get_ip()) @@ -114,7 +113,7 @@ async def f(): def test_server_raises_on_blocked_handlers(loop): async def f(): server = Server({"ping": pingpong}, blocked_handlers=["ping"]) - server.listen(8881) + await server.listen(8881) comm = await connect(server.address) await comm.write({"op": "ping"}) @@ -134,16 +133,21 @@ class MyServer(Server): default_port = 8756 -@gen_test() -def test_server_listen(): +@pytest.mark.skipif( + sys.version_info < 
(3, 7), + reason="asynccontextmanager not avaiable before Python 3.7", +) +@pytest.mark.asyncio +async def test_server_listen(): """ Test various Server.listen() arguments and their effect. """ + from contextlib import asynccontextmanager - @contextmanager - def listen_on(cls, *args, **kwargs): + @asynccontextmanager + async def listen_on(cls, *args, **kwargs): server = cls({}) - server.listen(*args, **kwargs) + await server.listen(*args, **kwargs) try: yield server finally: @@ -151,107 +155,107 @@ def listen_on(cls, *args, **kwargs): # Note server.address is the concrete, contactable address - with listen_on(Server, 7800) as server: + async with listen_on(Server, 7800) as server: assert server.port == 7800 assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) - yield assert_can_connect(server.address) - yield assert_can_connect_from_everywhere_4_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_from_everywhere_4_6(server.port) - with listen_on(Server) as server: + async with listen_on(Server) as server: assert server.port > 0 assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) - yield assert_can_connect(server.address) - yield assert_can_connect_from_everywhere_4_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_from_everywhere_4_6(server.port) - with listen_on(MyServer) as server: + async with listen_on(MyServer) as server: assert server.port == MyServer.default_port assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) - yield assert_can_connect(server.address) - yield assert_can_connect_from_everywhere_4_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_from_everywhere_4_6(server.port) - with listen_on(Server, ("", 7801)) as server: + async with listen_on(Server, ("", 7801)) as server: assert server.port == 7801 assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) - yield assert_can_connect(server.address) - yield assert_can_connect_from_everywhere_4_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_from_everywhere_4_6(server.port) - with listen_on(Server, "tcp://:7802") as server: + async with listen_on(Server, "tcp://:7802") as server: assert server.port == 7802 assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) - yield assert_can_connect(server.address) - yield assert_can_connect_from_everywhere_4_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_from_everywhere_4_6(server.port) # Only IPv4 - with listen_on(Server, ("0.0.0.0", 7810)) as server: + async with listen_on(Server, ("0.0.0.0", 7810)) as server: assert server.port == 7810 assert server.address == "tcp://%s:%d" % (EXTERNAL_IP4, server.port) - yield assert_can_connect(server.address) - yield assert_can_connect_from_everywhere_4(server.port) + await assert_can_connect(server.address) + await assert_can_connect_from_everywhere_4(server.port) - with listen_on(Server, ("127.0.0.1", 7811)) as server: + async with listen_on(Server, ("127.0.0.1", 7811)) as server: assert server.port == 7811 assert server.address == "tcp://127.0.0.1:%d" % server.port - yield assert_can_connect(server.address) - yield assert_can_connect_locally_4(server.port) + await assert_can_connect(server.address) + await assert_can_connect_locally_4(server.port) - with listen_on(Server, "tcp://127.0.0.1:7812") as server: + async with listen_on(Server, "tcp://127.0.0.1:7812") as server: assert server.port == 7812 
assert server.address == "tcp://127.0.0.1:%d" % server.port - yield assert_can_connect(server.address) - yield assert_can_connect_locally_4(server.port) + await assert_can_connect(server.address) + await assert_can_connect_locally_4(server.port) # Only IPv6 if has_ipv6(): - with listen_on(Server, ("::", 7813)) as server: + async with listen_on(Server, ("::", 7813)) as server: assert server.port == 7813 assert server.address == "tcp://[%s]:%d" % (EXTERNAL_IP6, server.port) - yield assert_can_connect(server.address) - yield assert_can_connect_from_everywhere_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_from_everywhere_6(server.port) - with listen_on(Server, ("::1", 7814)) as server: + async with listen_on(Server, ("::1", 7814)) as server: assert server.port == 7814 assert server.address == "tcp://[::1]:%d" % server.port - yield assert_can_connect(server.address) - yield assert_can_connect_locally_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_locally_6(server.port) - with listen_on(Server, "tcp://[::1]:7815") as server: + async with listen_on(Server, "tcp://[::1]:7815") as server: assert server.port == 7815 assert server.address == "tcp://[::1]:%d" % server.port - yield assert_can_connect(server.address) - yield assert_can_connect_locally_6(server.port) + await assert_can_connect(server.address) + await assert_can_connect_locally_6(server.port) # TLS sec = tls_security() - with listen_on( + async with listen_on( Server, "tls://", listen_args=sec.get_listen_args("scheduler") ) as server: assert server.address.startswith("tls://") - yield assert_can_connect( + await assert_can_connect( server.address, connection_args=sec.get_connection_args("client") ) # InProc - with listen_on(Server, "inproc://") as server: + async with listen_on(Server, "inproc://") as server: inproc_addr1 = server.address assert inproc_addr1.startswith("inproc://%s/%d/" % (get_ip(), os.getpid())) - yield assert_can_connect(inproc_addr1) + await assert_can_connect(inproc_addr1) - with listen_on(Server, "inproc://") as server2: + async with listen_on(Server, "inproc://") as server2: inproc_addr2 = server2.address assert inproc_addr2.startswith("inproc://%s/%d/" % (get_ip(), os.getpid())) - yield assert_can_connect(inproc_addr2) + await assert_can_connect(inproc_addr2) - yield assert_can_connect(inproc_addr1) - yield assert_cannot_connect(inproc_addr2) + await assert_can_connect(inproc_addr1) + await assert_cannot_connect(inproc_addr2) async def check_rpc(listen_addr, rpc_addr=None, listen_args=None, connection_args=None): server = Server({"ping": pingpong}) - server.listen(listen_addr, listen_args=listen_args) + await server.listen(listen_addr, listen_args=listen_args) if rpc_addr is None: rpc_addr = server.address @@ -269,24 +273,25 @@ async def check_rpc(listen_addr, rpc_addr=None, listen_args=None, connection_arg assert remote.status == "closed" server.stop() + await asyncio.sleep(0) -@gen_test() -def test_rpc_default(): - yield check_rpc(8883, "127.0.0.1:8883") - yield check_rpc(8883) +@pytest.mark.asyncio +async def test_rpc_default(): + await check_rpc(8883, "127.0.0.1:8883") + await check_rpc(8883) -@gen_test() -def test_rpc_tcp(): - yield check_rpc("tcp://:8883", "tcp://127.0.0.1:8883") - yield check_rpc("tcp://") +@pytest.mark.asyncio +async def test_rpc_tcp(): + await check_rpc("tcp://:8883", "tcp://127.0.0.1:8883") + await check_rpc("tcp://") -@gen_test() -def test_rpc_tls(): +@pytest.mark.asyncio +async def test_rpc_tls(): sec = 
tls_security() - yield check_rpc( + await check_rpc( "tcp://", None, sec.get_listen_args("scheduler"), @@ -294,9 +299,9 @@ def test_rpc_tls(): ) -@gen_test() -def test_rpc_inproc(): - yield check_rpc("inproc://", None) +@pytest.mark.asyncio +async def test_rpc_inproc(): + await check_rpc("inproc://", None) @pytest.mark.asyncio @@ -313,7 +318,7 @@ async def check_rpc_message_lifetime(*listen_args): # Issue #956: rpc arguments and result shouldn't be kept alive longer # than necessary server = Server({"echo": echo_serialize}) - server.listen(*listen_args) + await server.listen(*listen_args) # Sanity check obj = CountedObject() @@ -343,19 +348,19 @@ async def check_rpc_message_lifetime(*listen_args): server.stop() -@gen_test() -def test_rpc_message_lifetime_default(): - yield check_rpc_message_lifetime() +@pytest.mark.asyncio +async def test_rpc_message_lifetime_default(): + await check_rpc_message_lifetime() -@gen_test() -def test_rpc_message_lifetime_tcp(): - yield check_rpc_message_lifetime("tcp://") +@pytest.mark.asyncio +async def test_rpc_message_lifetime_tcp(): + await check_rpc_message_lifetime("tcp://") -@gen_test() -def test_rpc_message_lifetime_inproc(): - yield check_rpc_message_lifetime("inproc://") +@pytest.mark.asyncio +async def test_rpc_message_lifetime_inproc(): + await check_rpc_message_lifetime("inproc://") async def check_rpc_with_many_connections(listen_arg): @@ -364,7 +369,7 @@ async def g(): await remote.ping() server = Server({"ping": pingpong}) - server.listen(listen_arg) + await server.listen(listen_arg) async with rpc(server.address) as remote: for i in range(10): @@ -376,20 +381,20 @@ async def g(): assert all(comm.closed() for comm in remote.comms) -@gen_test() -def test_rpc_with_many_connections_tcp(): - yield check_rpc_with_many_connections("tcp://") +@pytest.mark.asyncio +async def test_rpc_with_many_connections_tcp(): + await check_rpc_with_many_connections("tcp://") -@gen_test() -def test_rpc_with_many_connections_inproc(): - yield check_rpc_with_many_connections("inproc://") +@pytest.mark.asyncio +async def test_rpc_with_many_connections_inproc(): + await check_rpc_with_many_connections("inproc://") async def check_large_packets(listen_arg): """ tornado has a 100MB cap by default """ server = Server({"echo": echo}) - server.listen(listen_arg) + await server.listen(listen_arg) data = b"0" * int(200e6) # slightly more than 100MB async with rpc(server.address) as conn: @@ -404,19 +409,19 @@ async def check_large_packets(listen_arg): @pytest.mark.slow -@gen_test() -def test_large_packets_tcp(): - yield check_large_packets("tcp://") +@pytest.mark.asyncio +async def test_large_packets_tcp(): + await check_large_packets("tcp://") -@gen_test() -def test_large_packets_inproc(): - yield check_large_packets("inproc://") +@pytest.mark.asyncio +async def test_large_packets_inproc(): + await check_large_packets("inproc://") async def check_identity(listen_arg): server = Server({}) - server.listen(listen_arg) + await server.listen(listen_arg) async with rpc(server.address) as remote: a = await remote.identity() @@ -427,21 +432,22 @@ async def check_identity(listen_arg): server.stop() -@gen_test() -def test_identity_tcp(): - yield check_identity("tcp://") +@pytest.mark.asyncio +async def test_identity_tcp(): + await check_identity("tcp://") -@gen_test() -def test_identity_inproc(): - yield check_identity("inproc://") +@pytest.mark.asyncio +async def test_identity_inproc(): + await check_identity("inproc://") -def test_ports(loop): +@pytest.mark.asyncio +async def 
test_ports(loop): for port in range(9877, 9887): server = Server({}, io_loop=loop) try: - server.listen(port) + await server.listen(port) except OSError: # port already taken? pass else: @@ -453,13 +459,13 @@ def test_ports(loop): with pytest.raises((OSError, socket.error)): server2 = Server({}, io_loop=loop) - server2.listen(port) + await server2.listen(port) finally: server.stop() try: server3 = Server({}, io_loop=loop) - server3.listen(0) + await server3.listen(0) assert isinstance(server3.port, int) assert server3.port > 1024 finally: @@ -470,35 +476,35 @@ def stream_div(stream=None, x=None, y=None): return x / y -@gen_test() -def test_errors(): +@pytest.mark.asyncio +async def test_errors(): server = Server({"div": stream_div}) - server.listen(0) + await server.listen(0) with rpc(("127.0.0.1", server.port)) as r: with pytest.raises(ZeroDivisionError): - yield r.div(x=1, y=0) + await r.div(x=1, y=0) -@gen_test() -def test_connect_raises(): +@pytest.mark.asyncio +async def test_connect_raises(): with pytest.raises(IOError): - yield connect("127.0.0.1:58259", timeout=0.01) + await connect("127.0.0.1:58259", timeout=0.01) -@gen_test() -def test_send_recv_args(): +@pytest.mark.asyncio +async def test_send_recv_args(): server = Server({"echo": echo}) - server.listen(0) + await server.listen(0) - comm = yield connect(server.address) - result = yield send_recv(comm, op="echo", x=b"1") + comm = await connect(server.address) + result = await send_recv(comm, op="echo", x=b"1") assert result == b"1" assert not comm.closed() - result = yield send_recv(comm, op="echo", x=b"2", reply=False) + result = await send_recv(comm, op="echo", x=b"2", reply=False) assert result is None assert not comm.closed() - result = yield send_recv(comm, op="echo", x=b"3", close=True) + result = await send_recv(comm, op="echo", x=b"3", close=True) assert result == b"3" assert comm.closed() @@ -510,49 +516,57 @@ def test_coerce_to_address(): assert coerce_to_address(arg) == "tcp://127.0.0.1:8786" -@gen_test() -def test_connection_pool(): +@pytest.mark.asyncio +async def test_connection_pool(): async def ping(comm, delay=0.1): await asyncio.sleep(delay) return "pong" servers = [Server({"ping": ping}) for i in range(10)] for server in servers: - server.listen(0) + await server.listen(0) rpc = ConnectionPool(limit=5) # Reuse connections - yield [rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5]] - yield [rpc(s.address).ping() for s in servers[:5]] - yield [rpc("127.0.0.1:%d" % s.port).ping() for s in servers[:5]] - yield [rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5]] + await asyncio.gather( + *[rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5]] + ) + await asyncio.gather(*[rpc(s.address).ping() for s in servers[:5]]) + await asyncio.gather(*[rpc("127.0.0.1:%d" % s.port).ping() for s in servers[:5]]) + await asyncio.gather( + *[rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5]] + ) assert sum(map(len, rpc.available.values())) == 5 assert sum(map(len, rpc.occupied.values())) == 0 assert rpc.active == 0 assert rpc.open == 5 # Clear out connections to make room for more - yield [rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[5:]] + await asyncio.gather( + *[rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[5:]] + ) assert rpc.active == 0 assert rpc.open == 5 s = servers[0] - yield [rpc(ip="127.0.0.1", port=s.port).ping(delay=0.1) for i in range(3)] + await asyncio.gather( + *[rpc(ip="127.0.0.1", port=s.port).ping(delay=0.1) for i in range(3)] + ) assert 
len(rpc.available["tcp://127.0.0.1:%d" % s.port]) == 3 # Explicitly clear out connections rpc.collect() start = time() while any(rpc.available.values()): - yield asyncio.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 rpc.close() -@gen_test() -def test_connection_pool_respects_limit(): +@pytest.mark.asyncio +async def test_connection_pool_respects_limit(): limit = 5 @@ -567,15 +581,15 @@ async def do_ping(pool, port): servers = [Server({"ping": ping}) for i in range(10)] for server in servers: - server.listen(0) + await server.listen(0) pool = ConnectionPool(limit=limit) - yield [do_ping(pool, s.port) for s in servers] + await asyncio.gather(*[do_ping(pool, s.port) for s in servers]) -@gen_test() -def test_connection_pool_tls(): +@pytest.mark.asyncio +async def test_connection_pool_tls(): """ Make sure connection args are supported. """ @@ -589,33 +603,33 @@ async def ping(comm, delay=0.01): servers = [Server({"ping": ping}) for i in range(10)] for server in servers: - server.listen("tls://", listen_args=listen_args) + await server.listen("tls://", listen_args=listen_args) rpc = ConnectionPool(limit=5, connection_args=connection_args) - yield [rpc(s.address).ping() for s in servers[:5]] - yield [rpc(s.address).ping() for s in servers[::2]] - yield [rpc(s.address).ping() for s in servers] + await asyncio.gather(*[rpc(s.address).ping() for s in servers[:5]]) + await asyncio.gather(*[rpc(s.address).ping() for s in servers[::2]]) + await asyncio.gather(*[rpc(s.address).ping() for s in servers]) assert rpc.active == 0 rpc.close() -@gen_test() -def test_connection_pool_remove(): +@pytest.mark.asyncio +async def test_connection_pool_remove(): async def ping(comm, delay=0.01): await asyncio.sleep(delay) return "pong" servers = [Server({"ping": ping}) for i in range(5)] for server in servers: - server.listen(0) + await server.listen(0) rpc = ConnectionPool(limit=10) serv = servers.pop() - yield [rpc(s.address).ping() for s in servers] - yield [rpc(serv.address).ping() for i in range(3)] - yield rpc.connect(serv.address) + await asyncio.gather(*[rpc(s.address).ping() for s in servers]) + await asyncio.gather(*[rpc(serv.address).ping() for i in range(3)]) + await rpc.connect(serv.address) assert sum(map(len, rpc.available.values())) == 6 assert sum(map(len, rpc.occupied.values())) == 1 assert rpc.active == 1 @@ -633,39 +647,39 @@ async def ping(comm, delay=0.01): # this pattern of calls (esp. 
`reuse` after `remove`) # can happen in case of worker failures: - comm = yield rpc.connect(serv.address) + comm = await rpc.connect(serv.address) rpc.remove(serv.address) rpc.reuse(serv.address, comm) rpc.close() -@gen_test() -def test_counters(): +@pytest.mark.asyncio +async def test_counters(): server = Server({"div": stream_div}) - server.listen("tcp://") + await server.listen("tcp://") - with rpc(server.address) as r: + async with rpc(server.address) as r: for i in range(2): - yield r.identity() + await r.identity() with pytest.raises(ZeroDivisionError): - yield r.div(x=1, y=0) + await r.div(x=1, y=0) c = server.counters assert c["op"].components[0] == {"identity": 2, "div": 1} @gen_cluster() -def test_ticks(s, a, b): +async def test_ticks(s, a, b): pytest.importorskip("crick") - yield asyncio.sleep(0.1) + await asyncio.sleep(0.1) c = s.digests["tick-duration"] assert c.size() assert 0.01 < c.components[0].quantile(0.5) < 0.5 @gen_cluster() -def test_tick_logging(s, a, b): +async def test_tick_logging(s, a, b): pytest.importorskip("crick") from distributed import core @@ -673,7 +687,7 @@ def test_tick_logging(s, a, b): core.tick_maximum_delay = 0.001 try: with captured_logger("distributed.core") as sio: - yield asyncio.sleep(0.1) + await asyncio.sleep(0.1) text = sio.getvalue() assert "unresponsive" in text @@ -689,7 +703,7 @@ def test_compression(compression, serialize, loop): async def f(): server = Server({"echo": serialize}) - server.listen("tcp://") + await server.listen("tcp://") with rpc(server.address) as r: data = b"1" * 1000000 @@ -704,7 +718,7 @@ async def f(): def test_rpc_serialization(loop): async def f(): server = Server({"echo": echo_serialize}) - server.listen("tcp://") + await server.listen("tcp://") async with rpc(server.address, serializers=["msgpack"]) as r: with pytest.raises(TypeError): @@ -724,14 +738,14 @@ def test_thread_id(s, a, b): assert s.thread_id == a.thread_id == b.thread_id == threading.get_ident() -@gen_test() -def test_deserialize_error(): +@pytest.mark.asyncio +async def test_deserialize_error(): server = Server({"throws": throws}) - server.listen(0) + await server.listen(0) - comm = yield connect(server.address, deserialize=False) + comm = await connect(server.address, deserialize=False) with pytest.raises(Exception) as info: - yield send_recv(comm, op="throws") + await send_recv(comm, op="throws") assert type(info.value) == Exception for c in str(info.value): diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 167abc762ae..8e9007f20f0 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -7,11 +7,10 @@ ssl = None import pytest -from tornado import gen from distributed.comm import connect, listen from distributed.security import Security -from distributed.utils_test import get_cert, gen_test +from distributed.utils_test import get_cert import dask @@ -256,18 +255,17 @@ def many_ciphers(ctx): assert len(tls_13_ciphers) == 3 -@gen_test() -def test_tls_listen_connect(): +@pytest.mark.asyncio +async def test_tls_listen_connect(): """ Functional test for TLS connection args. 
""" - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): peer_addr = comm.peer_address assert peer_addr.startswith("tls://") - yield comm.write("hello") - yield comm.close() + await comm.write("hello") + await comm.close() c = { "distributed.comm.tls.ca-file": ca_file, @@ -282,25 +280,25 @@ def handle_comm(comm): with dask.config.set(c): forced_cipher_sec = Security() - with listen( + async with listen( "tls://", handle_comm, connection_args=sec.get_listen_args("scheduler") ) as listener: - comm = yield connect( + comm = await connect( listener.contact_address, connection_args=sec.get_connection_args("worker") ) - msg = yield comm.read() + msg = await comm.read() assert msg == "hello" comm.abort() # No SSL context for client with pytest.raises(TypeError): - yield connect( + await connect( listener.contact_address, connection_args=sec.get_connection_args("client"), ) # Check forced cipher - comm = yield connect( + comm = await connect( listener.contact_address, connection_args=forced_cipher_sec.get_connection_args("worker"), ) @@ -309,14 +307,13 @@ def handle_comm(comm): comm.abort() -@gen_test() -def test_require_encryption(): +@pytest.mark.asyncio +async def test_require_encryption(): """ Functional test for "require_encryption" setting. """ - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): comm.abort() c = { @@ -333,19 +330,19 @@ def handle_comm(comm): sec2 = Security() for listen_addr in ["inproc://", "tls://"]: - with listen( + async with listen( listen_addr, handle_comm, connection_args=sec.get_listen_args("scheduler") ) as listener: - comm = yield connect( + comm = await connect( listener.contact_address, connection_args=sec2.get_connection_args("worker"), ) comm.abort() - with listen( + async with listen( listen_addr, handle_comm, connection_args=sec2.get_listen_args("scheduler") ) as listener: - comm = yield connect( + comm = await connect( listener.contact_address, connection_args=sec2.get_connection_args("worker"), ) @@ -358,17 +355,17 @@ def check_encryption_error(): assert "encryption required" in str(excinfo.value) for listen_addr in ["tcp://"]: - with listen( + async with listen( listen_addr, handle_comm, connection_args=sec.get_listen_args("scheduler") ) as listener: - comm = yield connect( + comm = await connect( listener.contact_address, connection_args=sec.get_connection_args("worker"), ) comm.abort() with pytest.raises(RuntimeError): - yield connect( + await connect( listener.contact_address, connection_args=sec2.get_connection_args("worker"), ) @@ -396,25 +393,24 @@ def test_temporary_credentials(): assert val not in sec_repr -@gen_test() -def test_tls_temporary_credentials_functional(): +@pytest.mark.asyncio +async def test_tls_temporary_credentials_functional(): pytest.importorskip("cryptography") - @gen.coroutine - def handle_comm(comm): + async def handle_comm(comm): peer_addr = comm.peer_address assert peer_addr.startswith("tls://") - yield comm.write("hello") - yield comm.close() + await comm.write("hello") + await comm.close() sec = Security.temporary() - with listen( + async with listen( "tls://", handle_comm, connection_args=sec.get_listen_args("scheduler") ) as listener: - comm = yield connect( + comm = await connect( listener.contact_address, connection_args=sec.get_connection_args("worker") ) - msg = yield comm.read() + msg = await comm.read() assert msg == "hello" comm.abort() diff --git a/distributed/worker.py b/distributed/worker.py index 1eb8252534d..b863e895f7d 100644 --- a/distributed/worker.py +++ 
b/distributed/worker.py @@ -986,7 +986,7 @@ async def start(self): enable_gc_diagnosis() thread_state.on_event_loop_thread = True - self.listen(self._start_address, listen_args=self.listen_args) + await self.listen(self._start_address, listen_args=self.listen_args) self.ip = get_address_host(self.address) if self.name is None: From 59ad42ed24fd1a0a010eb6c4620cf9c48c209f44 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 27 Nov 2019 20:04:00 -0600 Subject: [PATCH 0578/1550] Update function serialization caches with custom LRU class (#3260) --- distributed/tests/test_utils.py | 19 +++++++++++++++++++ distributed/utils.py | 22 +++++++++++++++++++++- distributed/worker.py | 33 ++++++++++++++++++--------------- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index ff733a1ad8c..ff2e42313ac 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -42,6 +42,7 @@ parse_timedelta, warn_on_duration, format_dashboard_link, + LRU, ) from distributed.utils_test import loop, loop_in_thread # noqa: F401 from distributed.utils_test import div, has_ipv6, inc, throws, gen_test, captured_logger @@ -598,3 +599,21 @@ def test_format_dashboard_link(): assert "hello" not in format_dashboard_link("host", 1234) finally: del os.environ["host"] + + +def test_lru(): + + l = LRU(maxsize=3) + l["a"] = 1 + l["b"] = 2 + l["c"] = 3 + assert list(l.keys()) == ["a", "b", "c"] + + # Use "a" and ensure it becomes the most recently used item + l["a"] + assert list(l.keys()) == ["b", "c", "a"] + + # Ensure maxsize is respected + l["d"] = 4 + assert len(l) == 3 + assert list(l.keys()) == ["c", "a", "d"] diff --git a/distributed/utils.py b/distributed/utils.py index 251e1110be8..978be4eae8a 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,6 +1,6 @@ import asyncio import atexit -from collections import deque +from collections import deque, OrderedDict, UserDict from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from datetime import timedelta @@ -1399,3 +1399,23 @@ def deserialize_for_cli(data): The de-serialized data """ return json.loads(base64.urlsafe_b64decode(data.encode()).decode()) + + +class LRU(UserDict): + """ Limited size mapping, evicting the least recently looked-up key when full + """ + + def __init__(self, maxsize): + super().__init__() + self.data = OrderedDict() + self.maxsize = maxsize + + def __getitem__(self, key): + value = super().__getitem__(key) + self.data.move_to_end(key) + return value + + def __setitem__(self, key, value): + if len(self) >= self.maxsize: + self.data.popitem(last=False) + super().__setitem__(key, value) diff --git a/distributed/worker.py b/distributed/worker.py index b863e895f7d..90059002e9a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -61,6 +61,7 @@ parse_timedelta, iscoroutinefunction, warn_on_duration, + LRU, ) from .utils_comm import pack_data, gather_from_workers from .utils_perf import ThrottledGC, enable_gc_diagnosis, disable_gc_diagnosis @@ -3176,18 +3177,26 @@ async def get_data_from_worker( job_counter = [0] -import functools +cache_loads = LRU(maxsize=100) -@functools.lru_cache(100) -def cached_function_deserialization(func): - return pickle.loads(func) + +def loads_function(bytes_object): + """ Load a function from bytes, cache bytes """ + if len(bytes_object) < 100000: + try: + result = cache_loads[bytes_object] + except KeyError: + result = pickle.loads(bytes_object) + 
cache_loads[bytes_object] = result + return result + return pickle.loads(bytes_object) def _deserialize(function=None, args=None, kwargs=None, task=no_value): """ Deserialize task inputs and regularize to func, args, kwargs """ if function is not None: - function = cached_function_deserialization(function) + function = loads_function(function) if args: args = pickle.loads(args) if kwargs: @@ -3219,24 +3228,18 @@ def execute_task(task): return task -try: - # a 10 MB cache of deserialized functions and their bytes - from zict import LRU - - cache = LRU(10000000, dict(), weight=lambda k, v: len(v)) -except ImportError: - cache = dict() +cache_dumps = LRU(maxsize=100) def dumps_function(func): """ Dump a function to bytes, cache functions """ try: - result = cache[func] + result = cache_dumps[func] except KeyError: result = pickle.dumps(func) if len(result) < 100000: - cache[func] = result - except TypeError: + cache_dumps[func] = result + except TypeError: # Unhashable function result = pickle.dumps(func) return result From e3731a6adbf70e92852cd9b71075d5607185cd9b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 28 Nov 2019 13:42:30 -0800 Subject: [PATCH 0579/1550] Add performance_report context manager for static report generation (#3282) This generates a static HTML file with many of the same plots as the dashboard. Example ------- ```python from dask.distributed import Client client = Client() import dask.array as da x = da.random.random((30000, 30000), chunks=(1000, 1000)) from dask.distributed import Client, performance_report with performance_report(): x = x.persist() (x + x.T).sum().compute() ``` --- distributed/__init__.py | 1 + distributed/client.py | 41 ++++++++++++++++-- distributed/scheduler.py | 73 ++++++++++++++++++++++++++++++++ distributed/tests/test_client.py | 17 ++++++++ 4 files changed, 129 insertions(+), 3 deletions(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index 1eadee32307..06136dd72a2 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -16,6 +16,7 @@ Future, futures_of, get_task_stream, + performance_report, ) from .lock import Lock from .nanny import Nanny diff --git a/distributed/client.py b/distributed/client.py index 3b2eb6e2863..5ca517ae219 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3291,9 +3291,6 @@ def profile( >>> client.profile() # call on collections >>> client.profile(filename='dask-profile.html') # save to html file """ - if isinstance(workers, (str, Number)): - workers = [workers] - return self.sync( self._profile, key=key, @@ -4560,6 +4557,44 @@ async def __aexit__(self, typ, value, traceback): self.data.extend(L) +class performance_report: + """ Gather performance report + + This creates a static HTML file that includes many of the same plots of the + dashboard for later viewing. + + The resulting file uses JavaScript, and so must be viewed with a web + browser. Locally we recommend using ``python -m http.server`` or hosting + the file live online. + + Examples + -------- + >>> with performance_report(filename="myfile.html"): + ... 
x.compute() + + $ python -m http.server + $ open myfile.html + """ + + def __init__(self, filename="dask-report.html"): + self.filename = filename + + async def __aenter__(self): + self.start = time() + await get_client().get_task_stream(start=0, stop=0) # ensure plugin + + async def __aexit__(self, typ, value, traceback): + data = await get_client().scheduler.performance_report(start=self.start) + with open(self.filename, "w") as f: + f.write(data) + + def __enter__(self): + get_client().sync(self.__aenter__) + + def __exit__(self, typ, value, traceback): + get_client().sync(self.__aexit__, type, value, traceback) + + @contextmanager def temp_default_client(c): """ Set the default client for the duration of the context diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b77c36477c1..bd627f608aa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -57,6 +57,7 @@ parse_bytes, PeriodicCallback, shutting_down, + tmpfile, ) from .utils_comm import scatter_to_workers, gather_from_workers from .utils_perf import enable_gc_diagnosis, disable_gc_diagnosis @@ -1073,6 +1074,7 @@ def __init__( "processing": self.get_processing, "call_stack": self.get_call_stack, "profile": self.get_profile, + "performance_report": self.performance_report, "logs": self.get_logs, "worker_logs": self.get_worker_logs, "nbytes": self.get_nbytes, @@ -4699,6 +4701,77 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} + async def performance_report(self, comm=None, start=None): + # Profiles + compute, scheduler, workers = await asyncio.gather( + *[ + self.get_profile(start=start), + self.get_profile(scheduler=True, start=start), + self.get_profile(server=True, start=start), + ] + ) + from . import profile + + def profile_to_figure(state): + data = profile.plot_data(state) + figure, source = profile.plot_figure(data, sizing_mode="stretch_both") + return figure + + compute, scheduler, workers = map( + profile_to_figure, (compute, scheduler, workers) + ) + + # Task stream + task_stream = self.get_task_stream(start=start) + from .diagnostics.task_stream import rectangles + from .dashboard.components.scheduler import task_stream_figure + + rects = rectangles(task_stream) + source, task_stream = task_stream_figure(sizing_mode="stretch_both") + source.data.update(rects) + + from distributed.dashboard.components.scheduler import ( + BandwidthWorkers, + BandwidthTypes, + ) + + bandwidth_workers = BandwidthWorkers(self, sizing_mode="stretch_both") + bandwidth_workers.update() + bandwidth_types = BandwidthTypes(self, sizing_mode="stretch_both") + bandwidth_types.update() + + from bokeh.models import Panel, Tabs + + compute = Panel(child=compute, title="Worker Profile (compute)") + workers = Panel(child=workers, title="Worker Profile (administrative)") + scheduler = Panel(child=scheduler, title="Scheduler Profile (administrative)") + task_stream = Panel(child=task_stream, title="Task Stream") + bandwidth_workers = Panel( + child=bandwidth_workers.fig, title="Bandwidth (Workers)" + ) + bandwidth_types = Panel(child=bandwidth_types.fig, title="Bandwidth (Types)") + + tabs = Tabs( + tabs=[ + task_stream, + compute, + workers, + scheduler, + bandwidth_workers, + bandwidth_types, + ] + ) + + from bokeh.plotting import save + + with tmpfile(extension=".html") as fn: + save(tabs, filename=fn) + + with open(fn) as f: + data = f.read() + + return data + async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False): results = await self.broadcast( msg={"op": "get_logs", 
"n": n}, workers=workers, nanny=nanny diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index e97cb023a48..b8d5be0b449 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -36,6 +36,7 @@ get_worker, Executor, profile, + performance_report, TimeoutError, ) from distributed.comm import CommClosedError @@ -5703,5 +5704,21 @@ async def test_profile_server(c, s, a, b): assert "slowdec" in str(p) +@gen_cluster(client=True) +async def test_performance_report(c, s, a, b): + da = pytest.importorskip("dask.array") + x = da.random.random((1000, 1000), chunks=(100, 100)) + + with tmpfile(extension="html") as fn: + async with performance_report(filename=fn): + await c.compute((x + x.T).sum()) + + with open(fn) as f: + data = f.read() + + assert "bokeh" in data + assert "random" in data + + if sys.version_info >= (3, 5): from distributed.tests.py3_test_client import * # noqa F401 From 894910034230649b23547bddf10bc2d550552bae Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 28 Nov 2019 19:04:43 -0800 Subject: [PATCH 0580/1550] xfail test_workspace_concurrency for Python 3.6 (#3283) This only seems to fail for this version. We actually intended to do this before, but the condition was imprecise. --- distributed/tests/test_diskutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index 0057f96fb36..86b472e184a 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -275,7 +275,7 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): def test_workspace_concurrency(tmpdir): if WINDOWS: raise pytest.xfail.Exception("TODO: unknown failure on windows") - if sys.version_info <= (3, 6): + if sys.version_info < (3, 7): raise pytest.xfail.Exception("TODO: unknown failure on Python 3.6") _test_workspace_concurrency(tmpdir, 2.0, 6) From 892c371d7ab06b4bd5f39afb0770659d87918b6e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 29 Nov 2019 16:53:44 -0800 Subject: [PATCH 0581/1550] Move Python 3 syntax tests into main tests (#3281) Previously these were separated to support Python 2. 
Now we can include them in normal tests --- distributed/deploy/tests/py3_test_deploy.py | 13 -- distributed/deploy/tests/test_local.py | 13 +- distributed/tests/py3_test_client.py | 210 -------------------- distributed/tests/py3_test_pubsub.py | 38 ---- distributed/tests/py3_test_utils_tst.py | 17 -- distributed/tests/test_client.py | 190 +++++++++++++++++- distributed/tests/test_locks.py | 12 +- distributed/tests/test_pubsub.py | 40 +++- distributed/tests/test_utils_test.py | 14 +- 9 files changed, 251 insertions(+), 296 deletions(-) delete mode 100644 distributed/deploy/tests/py3_test_deploy.py delete mode 100644 distributed/tests/py3_test_client.py delete mode 100644 distributed/tests/py3_test_pubsub.py delete mode 100644 distributed/tests/py3_test_utils_tst.py diff --git a/distributed/deploy/tests/py3_test_deploy.py b/distributed/deploy/tests/py3_test_deploy.py deleted file mode 100644 index 7a66ecf942c..00000000000 --- a/distributed/deploy/tests/py3_test_deploy.py +++ /dev/null @@ -1,13 +0,0 @@ -from distributed import LocalCluster -from distributed.utils_test import loop # noqa: F401 - -import pytest - - -@pytest.mark.asyncio -async def test_async_with(): - async with LocalCluster(processes=False, asynchronous=True) as cluster: - w = cluster.workers - assert w - - assert not w diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 9b39d8f81f5..0f8eb6d8901 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -898,10 +898,6 @@ class MyNanny(Nanny): assert all(isinstance(w, MyNanny) for w in cluster.workers.values()) -if sys.version_info >= (3, 5): - from distributed.deploy.tests.py3_test_deploy import * # noqa F401 - - def test_starts_up_sync(loop): cluster = LocalCluster( n_workers=2, @@ -1019,3 +1015,12 @@ async def test_no_danglng_asyncio_tasks(cleanup): tasks = asyncio.all_tasks() assert tasks == start + + +@pytest.mark.asyncio +async def test_async_with(): + async with LocalCluster(processes=False, asynchronous=True) as cluster: + w = cluster.workers + assert w + + assert not w diff --git a/distributed/tests/py3_test_client.py b/distributed/tests/py3_test_client.py deleted file mode 100644 index b5d10f8d553..00000000000 --- a/distributed/tests/py3_test_client.py +++ /dev/null @@ -1,210 +0,0 @@ -import gc -import sys -from time import sleep -import weakref - -import pytest -from tornado import gen - -from distributed.utils_test import div, gen_cluster, inc, loop, cluster # noqa F401 -from distributed import as_completed, Client, Lock -from distributed.metrics import time -from distributed.utils import sync - - -@gen_cluster(client=True) -def test_await_future(c, s, a, b): - future = c.submit(inc, 1) - - async def f(): # flake8: noqa - result = await future - assert result == 2 - - yield f() - - future = c.submit(div, 1, 0) - - async def f(): - with pytest.raises(ZeroDivisionError): - await future - - yield f() - - -@gen_cluster(client=True) -def test_as_completed_async_for(c, s, a, b): - futures = c.map(inc, range(10)) - ac = as_completed(futures) - results = [] - - async def f(): - async for future in ac: - result = await future - results.append(result) - - yield f() - - assert set(results) == set(range(1, 11)) - - -@gen_cluster(client=True) -def test_as_completed_async_for_results(c, s, a, b): - futures = c.map(inc, range(10)) - ac = as_completed(futures, with_results=True) - results = [] - - async def f(): - async for future, result in ac: - results.append(result) - - yield f() - - assert 
set(results) == set(range(1, 11)) - assert not s.counters["op"].components[0]["gather"] - - -@gen_cluster(client=True) -def test_as_completed_async_for_cancel(c, s, a, b): - x = c.submit(inc, 1) - y = c.submit(sleep, 0.3) - ac = as_completed([x, y]) - - async def _(): - await gen.sleep(0.1) - await y.cancel(asynchronous=True) - - c.loop.add_callback(_) - - L = [] - - async def f(): - async for future in ac: - L.append(future) - - yield f() - - assert L == [x, y] - - -def test_async_with(loop): - result = None - client = None - cluster = None - - async def f(): - async with Client(processes=False, asynchronous=True) as c: - nonlocal result, client, cluster - result = await c.submit(lambda x: x + 1, 10) - - client = c - cluster = c.cluster - - loop.run_sync(f) - - assert result == 11 - assert client.status == "closed" - assert cluster.status == "closed" - - -def test_locks(loop): - async def f(): - async with Client(processes=False, asynchronous=True) as c: - assert c.asynchronous - async with Lock("x"): - lock2 = Lock("x") - result = await lock2.acquire(timeout=0.1) - assert result is False - - loop.run_sync(f) - - -def test_client_sync_with_async_def(loop): - async def ff(): - await gen.sleep(0.01) - return 1 - - with cluster() as (s, [a, b]): - with Client(s["address"], loop=loop) as c: - assert sync(loop, ff) == 1 - assert c.sync(ff) == 1 - - -@pytest.mark.xfail(reason="known intermittent failure") -@gen_cluster(client=True) -async def test_dont_hold_on_to_large_messages(c, s, a, b): - np = pytest.importorskip("numpy") - da = pytest.importorskip("dask.array") - x = np.random.random(1000000) - xr = weakref.ref(x) - - d = da.from_array(x, chunks=(100000,)) - d = d.persist() - del x - - start = time() - while xr() is not None: - if time() > start + 5: - # Help diagnosing - from types import FrameType - - x = xr() - if x is not None: - del x - rc = sys.getrefcount(xr()) - refs = gc.get_referrers(xr()) - print("refs to x:", rc, refs, gc.isenabled()) - frames = [r for r in refs if isinstance(r, FrameType)] - for i, f in enumerate(frames): - print( - "frames #%d:" % i, - f.f_code.co_name, - f.f_code.co_filename, - sorted(f.f_locals), - ) - pytest.fail("array should have been destroyed") - - await gen.sleep(0.200) - - -@gen_cluster(client=True) -async def test_run_scheduler_async_def(c, s, a, b): - async def f(dask_scheduler): - await gen.sleep(0.01) - dask_scheduler.foo = "bar" - - await c.run_on_scheduler(f) - - assert s.foo == "bar" - - async def f(dask_worker): - await gen.sleep(0.01) - dask_worker.foo = "bar" - - await c.run(f) - assert a.foo == "bar" - assert b.foo == "bar" - - -@gen_cluster(client=True) -async def test_run_scheduler_async_def_wait(c, s, a, b): - async def f(dask_scheduler): - await gen.sleep(0.01) - dask_scheduler.foo = "bar" - - await c.run_on_scheduler(f, wait=False) - - while not hasattr(s, "foo"): - await gen.sleep(0.01) - assert s.foo == "bar" - - async def f(dask_worker): - await gen.sleep(0.01) - dask_worker.foo = "bar" - - await c.run(f, wait=False) - - while not hasattr(a, "foo") or not hasattr(b, "foo"): - await gen.sleep(0.01) - - assert a.foo == "bar" - assert b.foo == "bar" diff --git a/distributed/tests/py3_test_pubsub.py b/distributed/tests/py3_test_pubsub.py deleted file mode 100644 index 294ecfb90c8..00000000000 --- a/distributed/tests/py3_test_pubsub.py +++ /dev/null @@ -1,38 +0,0 @@ -from distributed import Pub, Sub -from distributed.utils_test import gen_cluster - -import asyncio -import toolz -from tornado import gen -import pytest - - 
-@pytest.mark.xfail(reason="out of order execution") -@gen_cluster(client=True) -def test_basic(c, s, a, b): - async def publish(): - pub = Pub("a") - - i = 0 - while True: - await gen.sleep(0.01) - pub._put(i) - i += 1 - - def f(_): - sub = Sub("a") - return list(toolz.take(5, sub)) - - asyncio.ensure_future(c.run(publish, workers=[a.address])) - - tasks = [c.submit(f, i) for i in range(4)] - results = yield c.gather(tasks) - - for r in results: - x = r[0] - # race conditions and unintended (but correct) messages - # can make this test not true - # assert r == [x, x + 1, x + 2, x + 3, x + 4] - - assert len(r) == 5 - assert all(r[i] < r[i + 1] for i in range(0, 4)), r diff --git a/distributed/tests/py3_test_utils_tst.py b/distributed/tests/py3_test_utils_tst.py deleted file mode 100644 index a4b1e242481..00000000000 --- a/distributed/tests/py3_test_utils_tst.py +++ /dev/null @@ -1,17 +0,0 @@ -from distributed.utils_test import gen_cluster, gen_test -from distributed import Client - -from tornado import gen - - -@gen_cluster() -async def test_gen_cluster_async(s, a, b): # flake8: noqa - async with Client(s.address, asynchronous=True) as c: - future = c.submit(lambda x: x + 1, 1) - result = await future - assert result == 2 - - -@gen_test() -async def test_gen_test_async(): # flake8: noqa - await gen.sleep(0.001) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index b8d5be0b449..02e2574a3ab 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5704,6 +5704,192 @@ async def test_profile_server(c, s, a, b): assert "slowdec" in str(p) +@gen_cluster(client=True) +def test_await_future(c, s, a, b): + future = c.submit(inc, 1) + + async def f(): # flake8: noqa + result = await future + assert result == 2 + + yield f() + + future = c.submit(div, 1, 0) + + async def f(): + with pytest.raises(ZeroDivisionError): + await future + + yield f() + + +@gen_cluster(client=True) +def test_as_completed_async_for(c, s, a, b): + futures = c.map(inc, range(10)) + ac = as_completed(futures) + results = [] + + async def f(): + async for future in ac: + result = await future + results.append(result) + + yield f() + + assert set(results) == set(range(1, 11)) + + +@gen_cluster(client=True) +def test_as_completed_async_for_results(c, s, a, b): + futures = c.map(inc, range(10)) + ac = as_completed(futures, with_results=True) + results = [] + + async def f(): + async for future, result in ac: + results.append(result) + + yield f() + + assert set(results) == set(range(1, 11)) + assert not s.counters["op"].components[0]["gather"] + + +@gen_cluster(client=True) +def test_as_completed_async_for_cancel(c, s, a, b): + x = c.submit(inc, 1) + y = c.submit(sleep, 0.3) + ac = as_completed([x, y]) + + async def _(): + await gen.sleep(0.1) + await y.cancel(asynchronous=True) + + c.loop.add_callback(_) + + L = [] + + async def f(): + async for future in ac: + L.append(future) + + yield f() + + assert L == [x, y] + + +def test_async_with(loop): + result = None + client = None + cluster = None + + async def f(): + async with Client(processes=False, asynchronous=True) as c: + nonlocal result, client, cluster + result = await c.submit(lambda x: x + 1, 10) + + client = c + cluster = c.cluster + + loop.run_sync(f) + + assert result == 11 + assert client.status == "closed" + assert cluster.status == "closed" + + +def test_client_sync_with_async_def(loop): + async def ff(): + await gen.sleep(0.01) + return 1 + + with cluster() as (s, [a, b]): + with Client(s["address"], 
loop=loop) as c: + assert sync(loop, ff) == 1 + assert c.sync(ff) == 1 + + +@pytest.mark.xfail(reason="known intermittent failure") +@gen_cluster(client=True) +async def test_dont_hold_on_to_large_messages(c, s, a, b): + np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") + x = np.random.random(1000000) + xr = weakref.ref(x) + + d = da.from_array(x, chunks=(100000,)) + d = d.persist() + del x + + start = time() + while xr() is not None: + if time() > start + 5: + # Help diagnosing + from types import FrameType + + x = xr() + if x is not None: + del x + rc = sys.getrefcount(xr()) + refs = gc.get_referrers(xr()) + print("refs to x:", rc, refs, gc.isenabled()) + frames = [r for r in refs if isinstance(r, FrameType)] + for i, f in enumerate(frames): + print( + "frames #%d:" % i, + f.f_code.co_name, + f.f_code.co_filename, + sorted(f.f_locals), + ) + pytest.fail("array should have been destroyed") + + await gen.sleep(0.200) + + +@gen_cluster(client=True) +async def test_run_scheduler_async_def(c, s, a, b): + async def f(dask_scheduler): + await gen.sleep(0.01) + dask_scheduler.foo = "bar" + + await c.run_on_scheduler(f) + + assert s.foo == "bar" + + async def f(dask_worker): + await gen.sleep(0.01) + dask_worker.foo = "bar" + + await c.run(f) + assert a.foo == "bar" + assert b.foo == "bar" + + +@gen_cluster(client=True) +async def test_run_scheduler_async_def_wait(c, s, a, b): + async def f(dask_scheduler): + await gen.sleep(0.01) + dask_scheduler.foo = "bar" + + await c.run_on_scheduler(f, wait=False) + + while not hasattr(s, "foo"): + await gen.sleep(0.01) + assert s.foo == "bar" + + async def f(dask_worker): + await gen.sleep(0.01) + dask_worker.foo = "bar" + + await c.run(f, wait=False) + + while not hasattr(a, "foo") or not hasattr(b, "foo"): + await gen.sleep(0.01) + + assert a.foo == "bar" + assert b.foo == "bar" + + @gen_cluster(client=True) async def test_performance_report(c, s, a, b): da = pytest.importorskip("dask.array") @@ -5718,7 +5904,3 @@ async def test_performance_report(c, s, a, b): assert "bokeh" in data assert "random" in data - - -if sys.version_info >= (3, 5): - from distributed.tests.py3_test_client import * # noqa F401 diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index 521a9b46114..4cf756ef178 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -3,7 +3,7 @@ import pytest -from distributed import Lock, get_client +from distributed import Lock, get_client, Client from distributed.metrics import time from distributed.utils_test import gen_cluster from distributed.utils_test import client, cluster_fixture, loop # noqa F401 @@ -128,3 +128,13 @@ def f(x, lock=None): lock2 = pickle.loads(pickle.dumps(lock)) assert lock2.name == lock.name assert lock2.client is lock.client + + +@pytest.mark.asyncio +async def test_locks(): + async with Client(processes=False, asynchronous=True) as c: + assert c.asynchronous + async with Lock("x"): + lock2 = Lock("x") + result = await lock2.acquire(timeout=0.1) + assert result is False diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 555afb71a73..847b0b88bf0 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -1,13 +1,14 @@ -import sys +import asyncio from time import sleep +import pytest +from tornado import gen +import toolz + from distributed import Pub, Sub, wait, get_worker, TimeoutError from distributed.utils_test import gen_cluster from distributed.metrics import time -import 
pytest -from tornado import gen - @gen_cluster(client=True, timeout=None) def test_speed(c, s, a, b): @@ -135,5 +136,32 @@ async def test_repr(c, s, a, b): assert "Sub" in str(sub) -if sys.version_info >= (3, 5): - from distributed.tests.py3_test_pubsub import * # noqa: F401, F403 +@pytest.mark.xfail(reason="out of order execution") +@gen_cluster(client=True) +def test_basic(c, s, a, b): + async def publish(): + pub = Pub("a") + + i = 0 + while True: + await gen.sleep(0.01) + pub._put(i) + i += 1 + + def f(_): + sub = Sub("a") + return list(toolz.take(5, sub)) + + asyncio.ensure_future(c.run(publish, workers=[a.address])) + + tasks = [c.submit(f, i) for i in range(4)] + results = yield c.gather(tasks) + + for r in results: + x = r[0] + # race conditions and unintended (but correct) messages + # can make this test not true + # assert r == [x, x + 1, x + 2, x + 3, x + 4] + + assert len(r) == 5 + assert all(r[i] < r[i + 1] for i in range(0, 4)), r diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 1c6802b5637..6a4a5ceaa5e 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -1,6 +1,5 @@ from contextlib import contextmanager import socket -import sys import threading from time import sleep @@ -180,5 +179,14 @@ async def test_tls_scheduler(security, cleanup): assert s.address.startswith("tls") -if sys.version_info >= (3, 5): - from distributed.tests.py3_test_utils_tst import * # noqa: F401, F403 +@gen_cluster() +async def test_gen_cluster_async(s, a, b): # flake8: noqa + async with Client(s.address, asynchronous=True) as c: + future = c.submit(lambda x: x + 1, 1) + result = await future + assert result == 2 + + +@gen_test() +async def test_gen_test_async(): # flake8: noqa + await gen.sleep(0.001) From f15abc58718fa4ebb457cbbe76f29765bb7b7bd9 Mon Sep 17 00:00:00 2001 From: Tom Rochette Date: Sat, 30 Nov 2019 10:59:21 -0500 Subject: [PATCH 0582/1550] Fix distributed.wait documentation (#3289) Properly display the return documentation. Add documentation for the return_when parameter. 
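
For reference, a minimal usage sketch of the documented parameter (illustrative
only, not part of this patch; it assumes a local client can be started):

```python
from dask.distributed import Client, wait

client = Client()  # assumes a local cluster can be started
futures = client.map(lambda x: x + 1, range(10))

# Return as soon as any future finishes instead of waiting for all of them
done, not_done = wait(futures, return_when="FIRST_COMPLETED")
```

``wait`` returns a named tuple, so ``done`` and ``not_done`` are sets of
futures that can be inspected or passed to ``wait`` again.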
--- distributed/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index 5ca517ae219..02337bb525f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4102,6 +4102,10 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): fs: list of futures timeout: number, optional Time in seconds after which to raise a ``dask.distributed.TimeoutError`` + return_when: str, optional + One of `ALL_COMPLETED` or `FIRST_COMPLETED` + + Returns ------- Named tuple of completed, not completed """ From 67007ba9f8f4a19e5bd520291a7756e7c1681a37 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Mon, 2 Dec 2019 11:32:50 -0500 Subject: [PATCH 0583/1550] xfail ucx empty object typed dataframe (#3279) --- distributed/comm/tests/test_ucx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index a9207e72e7a..42b5275f5b0 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -170,7 +170,10 @@ def test_ucx_deserialize(): lambda cudf: cudf.DataFrame([1]).head(0), lambda cudf: cudf.DataFrame([1.0]).head(0), lambda cudf: cudf.DataFrame({"a": []}), - lambda cudf: cudf.DataFrame({"a": ["a"]}).head(0), + pytest.param( + lambda cudf: cudf.DataFrame({"a": ["a"]}).head(0), + marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), + ), lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), lambda cudf: cudf.DataFrame({"a": [1]}).head(0), lambda cudf: cudf.DataFrame({"a": [1, 2, None], "b": [1.0, 2.0, None]}), From a6b8356918bfa9f832ebf858edee2d58b03392a2 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 2 Dec 2019 15:29:07 -0600 Subject: [PATCH 0584/1550] Avoid repeatedly adding deps to already in memory stack (#3293) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bd627f608aa..7e799715ed4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1608,7 +1608,7 @@ def update_graph( else: child_deps = self.dependencies[dep] if all(d in done for d in child_deps): - if dep in self.tasks: + if dep in self.tasks and dep not in done: done.add(dep) stack.append(dep) From e0f075eeedf0b9245ecd66702af4a35bce167f12 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Mon, 2 Dec 2019 18:40:52 -0500 Subject: [PATCH 0585/1550] Fix asynchronous listener in UCX (#3292) --- distributed/comm/tests/test_ucx.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 42b5275f5b0..6fc892176a9 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -36,12 +36,12 @@ async def handle_comm(comm): await q.put(comm) listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) - with listener: + async with listener: comm = await connect( listener.contact_address, connection_args=connect_args, **kwargs ) - serv_com = await q.get() - return comm, serv_com + serv_comm = await q.get() + return (comm, serv_comm) @pytest.mark.asyncio @@ -97,7 +97,7 @@ async def handle_comm(comm): assert comm.closed listener = ucx.UCXListener(address, handle_comm) - listener.start() + await listener.start() host, port = listener.get_host_port() assert host.count(".") == 3 assert port > 0 @@ -174,10 +174,19 @@ def test_ucx_deserialize(): lambda cudf: 
cudf.DataFrame({"a": ["a"]}).head(0), marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), ), - lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), - lambda cudf: cudf.DataFrame({"a": [1]}).head(0), + pytest.param( + lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), + marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), + ), + pytest.param( + lambda cudf: cudf.DataFrame({"a": [1]}).head(0), + marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), + ), lambda cudf: cudf.DataFrame({"a": [1, 2, None], "b": [1.0, 2.0, None]}), - lambda cudf: cudf.DataFrame({"a": ["Check", "str"], "b": ["Sup", "port"]}), + pytest.param( + lambda cudf: cudf.DataFrame({"a": ["Check", "str"], "b": ["Sup", "port"]}), + marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), + ), ], ) async def test_ping_pong_cudf(g): @@ -269,7 +278,8 @@ async def test_ping_pong_numba(cleanup): @pytest.mark.asyncio async def test_ucx_localcluster(processes, cleanup): async with LocalCluster( - protocol="ucx", + protocol="ucx:://", + host=HOST, dashboard_address=None, n_workers=2, threads_per_worker=1, From 0259acc066c2f3210077b9e19f1d647f58173dd0 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 3 Dec 2019 16:37:22 +0100 Subject: [PATCH 0586/1550] worker.close() awaits batched_stream.close() (#3291) --- distributed/batched.py | 9 ++++++--- distributed/tests/test_scheduler.py | 6 ++++-- distributed/worker.py | 3 ++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/distributed/batched.py b/distributed/batched.py index a3207b333ef..e066fcf7588 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -123,13 +123,16 @@ def send(self, msg): self.waker.set() @gen.coroutine - def close(self): - """ Flush existing messages and then close comm """ + def close(self, timeout=None): + """ Flush existing messages and then close comm + + If set, raises `tornado.util.TimeoutError` after a timeout. 
+ """ if self.comm is None: return self.please_stop = True self.waker.set() - yield self.stopped.wait() + yield self.stopped.wait(timeout=timeout) if not self.comm.closed(): try: if self.buffer: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4e0e9a8710c..859f56fef42 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1512,8 +1512,10 @@ def test_idle_timeout(c, s, a, b): yield gen.sleep(0.01) assert time() < start + 3 - assert a.status == "closed" - assert b.status == "closed" + start = time() + while not (a.status == "closed" and b.status == "closed"): + yield gen.sleep(0.01) + assert time() < start + 1 @gen_cluster(client=True, config={"distributed.scheduler.bandwidth": "100 GB"}) diff --git a/distributed/worker.py b/distributed/worker.py index 90059002e9a..ebddd042551 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1089,7 +1089,8 @@ async def close( self.batched_stream.send({"op": "close-stream"}) if self.batched_stream: - self.batched_stream.close() + with ignoring(gen.TimeoutError): + await self.batched_stream.close(timedelta(seconds=timeout)) self.actor_executor._work_queue.queue.clear() if isinstance(self.executor, ThreadPoolExecutor): From 54ce6785f82f9d1e180f303c4b2983f25c4db7ed Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 3 Dec 2019 12:58:29 -0600 Subject: [PATCH 0587/1550] Update SSHCluster docstring parameters (#3296) --- distributed/deploy/ssh.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 673cb7ba717..1f49f187a14 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -219,21 +219,16 @@ def SSHCluster( Parameters ---------- hosts: List[str] - List of hostnames or addresses on which to launch our cluster - The first will be used for the scheduler and the rest for workers - connect_options: - Keywords to pass through to asyncssh.connect - known_hosts: List[str] or None - The list of keys which will be used to validate the server host - key presented during the SSH handshake. If this is not specified, - the keys will be looked up in the file .ssh/known_hosts. If this - is explicitly set to None, server host key validation will be disabled. - worker_options: - Keywords to pass on to dask-worker - scheduler_options: - Keywords to pass on to dask-scheduler - worker_module: - Python module to call to start the worker + List of hostnames or addresses on which to launch our cluster. + The first will be used for the scheduler and the rest for workers. + connect_options: dict, optional + Keywords to pass through to ``asyncssh.connect``. + worker_options: dict, optional + Keywords to pass on to workers. + scheduler_options: dict, optional + Keywords to pass on to scheduler. + worker_module: str, optional + Python module to call to start the worker. 
Examples -------- From b60c4bfc5f46a9aa2f1aacbda21eac38abcc807e Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Tue, 3 Dec 2019 16:37:42 -0500 Subject: [PATCH 0588/1550] forgot to fix slow test (#3297) --- distributed/comm/tests/test_ucx.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 6fc892176a9..7725bfa2432 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -303,7 +303,11 @@ async def test_stress(cleanup): chunksize = "10 MB" async with LocalCluster( - protocol="ucx", dashboard_address=None, asynchronous=True, processes=False + protocol="ucx", + dashboard_address=None, + asynchronous=True, + processes=False, + host=HOST, ) as cluster: async with Client(cluster, asynchronous=True) as client: rs = da.random.RandomState() From 243fd0a8e665b5589982d647aed821578e413810 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 3 Dec 2019 15:51:29 -0800 Subject: [PATCH 0589/1550] Add title to performance_report (#3298) --- distributed/scheduler.py | 3 ++- distributed/tests/test_client.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 7e799715ed4..3ad82ece8c7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4762,9 +4762,10 @@ def profile_to_figure(state): ] ) - from bokeh.plotting import save + from bokeh.plotting import save, output_file with tmpfile(extension=".html") as fn: + output_file(filename=fn, title="Dask Performance Report") save(tabs, filename=fn) with open(fn) as f: diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 02e2574a3ab..7e5442d01a5 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5904,3 +5904,4 @@ async def test_performance_report(c, s, a, b): assert "bokeh" in data assert "random" in data + assert "Dask Performance Report" in data From 4a8a4f3bce378406e83a099e8a12fc9bc12ef25c Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 3 Dec 2019 22:05:56 -0500 Subject: [PATCH 0590/1550] Updates RMM comment to the correct release (#3299) This will land in RMM 0.11.0. So update this comment to be a bit more precise. --- distributed/comm/ucx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 7c783b605a1..175d628a0f6 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -48,7 +48,7 @@ def init_once(): if hasattr(rmm, "DeviceBuffer"): cuda_array = lambda n: rmm.DeviceBuffer(size=n) - else: # pre-0.12.0 + else: # pre-0.11.0 cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) except ImportError: try: From b1e7ab4e599b8f3e93ffc3c4c95a471e923ac396 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 4 Dec 2019 17:16:42 -0800 Subject: [PATCH 0591/1550] Support multiple listeners in the scheduler (#3288) This allows the scheduler to listen on multiple different listeners at once. This can be useful when the client and workers are on different interfaces, or have different security concerns. 
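A rough sketch of the new behaviour, mirroring the test added below (the surrounding asyncio boilerplate is just for illustration):

```python
import asyncio
from dask.distributed import Scheduler, Worker

async def main():
    # One scheduler, two listeners: in-process and TCP.
    async with Scheduler(port=0, protocol=["inproc", "tcp"]) as s:
        # Each worker connects through whichever listener suits its network.
        async with Worker(s.listeners[0].contact_address) as a:      # inproc
            async with Worker(s.listeners[1].contact_address) as b:  # tcp
                assert a.scheduler.address.startswith("inproc")
                assert b.scheduler.address.startswith("tcp")

asyncio.get_event_loop().run_until_complete(main())
```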
--- distributed/comm/addressing.py | 52 ++++++++++++++++++++++++++++- distributed/core.py | 20 +++++++---- distributed/scheduler.py | 17 +++++----- distributed/tests/test_scheduler.py | 19 +++++++++++ 4 files changed, 93 insertions(+), 15 deletions(-) diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index f0c18b9fbda..2b1c4717407 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -1,3 +1,4 @@ +import itertools import dask from . import registry @@ -206,6 +207,54 @@ def uri_from_host_port(host_arg, port_arg, default_port): return addr +def addresses_from_user_args( + host=None, + port=None, + interface=None, + protocol=None, + peer=None, + security=None, + default_port=0, +) -> list: + """ Get a list of addresses if the inputs are lists + + This is like ``address_from_user_args`` except that it also accepts lists + for some of the arguments. If these arguments are lists then it will map + over them accordingly. + + Examples + -------- + >>> addresses_from_user_args(host="127.0.0.1", protocol=["inproc", "tcp"]) + ["inproc://127.0.0.1:", "tcp://127.0.0.1:"] + """ + + def listify(obj): + if isinstance(obj, (tuple, list)): + return obj + else: + return itertools.repeat(obj) + + if any(isinstance(x, (tuple, list)) for x in (host, port, interface, protocol)): + return [ + address_from_user_args( + host=h, + port=p, + interface=i, + protocol=pr, + peer=peer, + security=security, + default_port=default_port, + ) + for h, p, i, pr in zip(*map(listify, (host, port, interface, protocol))) + ] + else: + return [ + address_from_user_args( + host, port, interface, protocol, peer, security, default_port + ) + ] + + def address_from_user_args( host=None, port=None, @@ -214,8 +263,9 @@ def address_from_user_args( peer=None, security=None, default_port=0, -): +) -> str: """ Get an address to listen on from common user provided arguments """ + if security and security.require_encryption and not protocol: protocol = "tls" diff --git a/distributed/core.py b/distributed/core.py index ec6e9969fea..bf734070248 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -136,7 +136,7 @@ def __init__( self._ongoing_coroutines = weakref.WeakSet() self._event_finished = Event() - self.listener = None + self.listeners = [] self.io_loop = io_loop or IOLoop.current() self.loop = self.io_loop @@ -221,7 +221,7 @@ def start_pcs(): def stop(self): if not self.__stopped: self.__stopped = True - if self.listener is not None: + for listener in self.listeners: # Delay closing the server socket until the next IO loop tick. # Otherwise race conditions can appear if an event handler # for an accept() call is already scheduled by the IO loop, @@ -229,7 +229,14 @@ def stop(self): # The demonstrator for this is Worker.terminate(), which # closes the server socket in response to an incoming message. 
# See https://github.com/tornadoweb/tornado/issues/2069 - self.io_loop.add_callback(self.listener.stop) + self.io_loop.add_callback(listener.stop) + + @property + def listener(self): + if self.listeners: + return self.listeners[0] + else: + return None def _measure_tick(self): now = time() @@ -305,13 +312,14 @@ async def listen(self, port_or_addr=None, listen_args=None): else: addr = port_or_addr assert isinstance(addr, str) - self.listener = listen( + listener = listen( addr, self.handle_comm, deserialize=self.deserialize, connection_args=listen_args, ) - await self.listener.start() + await listener.start() + self.listeners.append(listener) async def handle_comm(self, comm, shutting_down=shutting_down): """ Dispatch new communications to coroutine-handlers @@ -487,7 +495,7 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): def close(self): for pc in self.periodic_callbacks.values(): pc.stop() - if self.listener: + for listener in self.listeners: self.listener.stop() for i in range(20): # let comms close naturally for a second if not self._comms: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3ad82ece8c7..ab78ef05370 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -36,7 +36,7 @@ get_address_host, unparse_host_port, ) -from .comm.addressing import address_from_user_args +from .comm.addressing import addresses_from_user_args from .core import rpc, connect, send_recv, clean_exception, CommClosedError from .diagnostics.plugin import SchedulerPlugin from . import profile @@ -1118,7 +1118,7 @@ def __init__( connection_limit = get_fileno_limit() / 2 - self._start_address = address_from_user_args( + self._start_address = addresses_from_user_args( host=host, port=port, interface=interface, @@ -1215,14 +1215,15 @@ async def start(self): c.cancel() if self.status != "running": - await self.listen(self._start_address, listen_args=self.listen_args) - self.ip = get_address_host(self.listen_address) - listen_ip = self.ip + for addr in self._start_address: + await self.listen(addr, listen_args=self.listen_args) + self.ip = get_address_host(self.listen_address) + listen_ip = self.ip - if listen_ip == "0.0.0.0": - listen_ip = "" + if listen_ip == "0.0.0.0": + listen_ip = "" - if self._start_address.startswith("inproc://"): + if self.address.startswith("inproc://"): listen_ip = "localhost" # Services listen on all addresses diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 859f56fef42..ded056d8f96 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1849,3 +1849,22 @@ def reducer(x, y): if "reducer" in key and finish == "processing": finish_processing_transitions += 1 assert finish_processing_transitions == 1 + + +@pytest.mark.asyncio +async def test_multiple_listeners(cleanup): + async with Scheduler(port=0, protocol=["inproc", "tcp"]) as s: + async with Worker(s.listeners[0].contact_address) as a: + async with Worker(s.listeners[1].contact_address) as b: + assert a.address.startswith("inproc") + assert a.scheduler.address.startswith("inproc") + assert b.address.startswith("tcp") + assert b.scheduler.address.startswith("tcp") + + async with Client(s.address, asynchronous=True) as c: + futures = c.map(inc, range(20)) + await wait(futures) + + # Force inter-worker communication both ways + await c.submit(sum, futures, workers=[a.address]) + await c.submit(len, futures, workers=[b.address]) From 73b6bf989b3fbd5e164acb448bddf6334b7779e6 Mon Sep 17 00:00:00 2001 
From: James Bourbeau Date: Wed, 4 Dec 2019 20:46:47 -0600 Subject: [PATCH 0592/1550] Skip Security.temporary() tests if cryptography not installed (#3302) --- distributed/deploy/tests/test_local.py | 3 +++ distributed/tests/test_security.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 0f8eb6d8901..370423771a5 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -250,6 +250,7 @@ def test_Client_twice(loop): @pytest.mark.asyncio async def test_client_constructor_with_temporary_security(cleanup): + pytest.importorskip("cryptography") async with Client( security=True, silence_logs=False, dashboard_address=None, asynchronous=True ) as c: @@ -709,6 +710,7 @@ def test_adapt_then_manual(loop): @pytest.mark.parametrize("temporary", [True, False]) def test_local_tls(loop, temporary): if temporary: + pytest.importorskip("cryptography") security = True else: security = tls_only_security() @@ -989,6 +991,7 @@ async def test_repr(cleanup): @pytest.mark.parametrize("temporary", [True, False]) async def test_capture_security(cleanup, temporary): if temporary: + pytest.importorskip("cryptography") security = True else: security = tls_only_security() diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 8e9007f20f0..002e63d2855 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -379,6 +379,8 @@ def check_encryption_error(): def test_temporary_credentials(): + pytest.importorskip("cryptography") + sec = Security.temporary() sec_repr = repr(sec) fields = ["tls_ca_file"] From 4e9eb461df3d7ac4a693d094f488fbd4b44676ac Mon Sep 17 00:00:00 2001 From: byjott Date: Thu, 5 Dec 2019 16:33:29 +0100 Subject: [PATCH 0593/1550] Retry operations on network issues (#3294) We operate distributed in the cloud and see tcp connection aborts. Unfortunately, distributed often does not recover cleanly from such situations, although a simple re-try would have helped in most cases. This PR proposes to add a more generic retry to some operations. Notes: - only some operations are re-tried, as for some operations, triggering it twice may have undesired effects. There are probably more operations that can / should be re-tried, so this is just a start for operations where it's "obviously" safe to retry. - parameters for the re-tries (maximum number of retry attempts, delay between re-tries) is configurable. 
The default is to not re-try at all to not change the current behavior (some might rely on/prefer seeing all connection failures, fast) --- dev-requirements.txt | 1 + distributed/client.py | 7 ++- distributed/distributed.yaml | 5 ++ distributed/scheduler.py | 17 ++++-- distributed/tests/test_scheduler.py | 44 ++++++++------- distributed/tests/test_utils_comm.py | 75 ++++++++++++++++++++++++- distributed/utils_comm.py | 83 ++++++++++++++++++++++++++++ distributed/worker.py | 38 ++++--------- 8 files changed, 214 insertions(+), 56 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index cd79b3e4317..a367f706e76 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -10,3 +10,4 @@ ipykernel >= 4.5.2 pytest >= 3.2 prometheus_client >= 0.6.0 jupyter-server-proxy >= 1.1.0 +pytest-asyncio diff --git a/distributed/client.py b/distributed/client.py index 02337bb525f..027e24afeee 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -53,6 +53,7 @@ pack_data, scatter_to_workers, gather_from_workers, + retry_operation, ) from .cfexecutor import ClientExecutor from .core import connect, rpc, clean_exception, CommClosedError, PooledRPCCall @@ -1794,19 +1795,19 @@ async def _gather_remote(self, direct, local_worker): try: if direct or local_worker: # gather directly from workers - who_has = await self.scheduler.who_has(keys=keys) + who_has = await retry_operation(self.scheduler.who_has, keys=keys) data2, missing_keys, missing_workers = await gather_from_workers( who_has, rpc=self.rpc, close=False ) response = {"status": "OK", "data": data2} if missing_keys: keys2 = [key for key in keys if key not in data2] - response = await self.scheduler.gather(keys=keys2) + response = await retry_operation(self.scheduler.gather, keys=keys2) if response["status"] == "OK": response["data"].update(data2) else: # ask scheduler to gather data for us - response = await self.scheduler.gather(keys=keys) + response = await retry_operation(self.scheduler.gather, keys=keys) finally: self._gather_semaphore.release() diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index ae42162bb2f..e6c6a49b484 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -72,6 +72,11 @@ distributed: lost-worker-timeout: 15s # Interval after which to hard-close a lost worker job comm: + retry: # some operations (such as gathering data) are subject to re-tries with the below parameters + count: 0 # the maximum retry attempts. 0 disables re-trying. 
+ delay: + min: 1s # the first non-zero delay between re-tries + max: 20s # the maximum delay between re-tries compression: auto offload: 10MiB # Size after which we choose to offload serialization to another thread default-scheme: tcp diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ab78ef05370..eeb7eb49732 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -59,7 +59,7 @@ shutting_down, tmpfile, ) -from .utils_comm import scatter_to_workers, gather_from_workers +from .utils_comm import scatter_to_workers, gather_from_workers, retry_operation from .utils_perf import enable_gc_diagnosis, disable_gc_diagnosis from .publish import PublishExtension @@ -2786,7 +2786,10 @@ async def rebalance(self, comm=None, keys=None, workers=None): to_senders[sender.address].append(ts.key) result = await asyncio.gather( - *(self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items()) + *( + retry_operation(self.rpc(addr=r).gather, who_has=v) + for r, v in to_recipients.items() + ) ) for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) @@ -2819,7 +2822,7 @@ async def rebalance(self, comm=None, keys=None, workers=None): await asyncio.gather( *( - self.rpc(addr=r).delete_data(keys=v, report=False) + retry_operation(self.rpc(addr=r).delete_data, keys=v, report=False) for r, v in to_senders.items() ) ) @@ -2887,8 +2890,10 @@ async def replicate( await asyncio.gather( *( - self.rpc(addr=ws.address).delete_data( - keys=[ts.key for ts in tasks], report=False + retry_operation( + self.rpc(addr=ws.address).delete_data, + keys=[ts.key for ts in tasks], + report=False, ) for ws, tasks in del_worker_tasks.items() ) @@ -2922,7 +2927,7 @@ async def replicate( results = await asyncio.gather( *( - self.rpc(addr=w).gather(who_has=who_has) + retry_operation(self.rpc(addr=w).gather, who_has=who_has) for w, who_has in gathers.items() ) ) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index ded056d8f96..7bf3f456085 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -7,6 +7,7 @@ import operator import sys from time import sleep +from unittest import mock import logging import dask @@ -1749,7 +1750,8 @@ async def test_gather_failing_cnn_recover(c, s, a, b): x = await c.scatter({"x": 1}, workers=a.address) s.rpc = FlakyConnectionPool(failing_connections=1) - res = await s.gather(keys=["x"]) + with mock.patch("distributed.utils_comm.retry_count", 1): + res = await s.gather(keys=["x"]) assert res["status"] == "OK" @@ -1811,25 +1813,31 @@ def reducer(x, y): s.rpc = FlakyConnectionPool(failing_connections=4) - with captured_logger(logging.getLogger("distributed.scheduler")) as sched_logger: - with captured_logger(logging.getLogger("distributed.client")) as client_logger: - with captured_logger( - logging.getLogger("distributed.worker") - ) as worker_logger: - # Gather using the client (as an ordinary user would) - # Upon a missing key, the client will reschedule the computations - res = await c.gather(z) + with captured_logger( + logging.getLogger("distributed.scheduler") + ) as sched_logger, captured_logger( + logging.getLogger("distributed.client") + ) as client_logger, captured_logger( + logging.getLogger("distributed.utils_comm") + ) as utils_comm_logger, mock.patch( + "distributed.utils_comm.retry_count", 3 + ), mock.patch( + "distributed.utils_comm.retry_delay_min", 0.5 + ): + # Gather using the client (as an ordinary user would) + # Upon a missing key, the 
client will reschedule the computations + res = await c.gather(z) assert res == 5 sched_logger = sched_logger.getvalue() client_logger = client_logger.getvalue() - worker_logger = worker_logger.getvalue() + utils_comm_logger = utils_comm_logger.getvalue() # Ensure that the communication was done via the scheduler, i.e. we actually hit a bad connection assert s.rpc.cnn_count > 0 - assert "Encountered connection issue during data collection" in worker_logger + assert "Retrying get_data_from_worker after exception" in utils_comm_logger # The reducer task was actually not found upon first collection. The client will reschedule the graph assert "Couldn't gather 1 keys, rescheduling" in client_logger @@ -1841,14 +1849,12 @@ def reducer(x, y): # that the scheduler again knows about the result. # The final reduce step should then be used from the re-connected worker # instead of recomputing it. - - starts = [] - finish_processing_transitions = 0 - for transition in s.transition_log: - key, start, finish, recommendations, timestamp = transition - if "reducer" in key and finish == "processing": - finish_processing_transitions += 1 - assert finish_processing_transitions == 1 + transitions_to_processing = [ + (key, start, timestamp) + for key, start, finish, recommendations, timestamp in s.transition_log + if finish == "processing" and "reducer" in key + ] + assert len(transitions_to_processing) == 1 @pytest.mark.asyncio diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index f66d3ba62d5..3f26fae623a 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,7 +1,11 @@ from distributed.core import ConnectionPool from distributed.comm import Comm -from distributed.utils_test import gen_cluster -from distributed.utils_comm import pack_data, gather_from_workers +from distributed.utils_test import gen_cluster, loop # noqa: F401 +from distributed.utils_comm import pack_data, gather_from_workers, retry + +from unittest import mock + +import pytest def test_pack_data(): @@ -58,3 +62,70 @@ def test_gather_from_workers_permissive_flaky(c, s, a, b): assert missing == {"x": [a.address]} assert bad_workers == [a.address] + + +def test_retry_no_exception(loop): + n_calls = 0 + retval = object() + + async def coro(): + nonlocal n_calls + n_calls += 1 + return retval + + assert ( + loop.run_sync(lambda: retry(coro, count=0, delay_min=-1, delay_max=-1)) + is retval + ) + assert n_calls == 1 + + +def test_retry0_raises_immediately(loop): + # test that using max_reties=0 raises after 1 call + + n_calls = 0 + + async def coro(): + nonlocal n_calls + n_calls += 1 + raise RuntimeError(f"RT_ERROR {n_calls}") + + with pytest.raises(RuntimeError, match="RT_ERROR 1"): + loop.run_sync(lambda: retry(coro, count=0, delay_min=-1, delay_max=-1)) + + assert n_calls == 1 + + +def test_retry_does_retry_and_sleep(loop): + # test the retry and sleep pattern of `retry` + n_calls = 0 + + class MyEx(Exception): + pass + + async def coro(): + nonlocal n_calls + n_calls += 1 + raise MyEx(f"RT_ERROR {n_calls}") + + sleep_calls = [] + + async def my_sleep(amount): + sleep_calls.append(amount) + return + + with mock.patch("asyncio.sleep", my_sleep): + with pytest.raises(MyEx, match="RT_ERROR 6"): + loop.run_sync( + lambda: retry( + coro, + retry_on_exceptions=(MyEx,), + count=5, + delay_min=1.0, + delay_max=6.0, + jitter_fraction=0.0, + ) + ) + + assert n_calls == 6 + assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0] diff --git a/distributed/utils_comm.py 
b/distributed/utils_comm.py index e2072189be0..cb614602f7b 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,10 +1,13 @@ import asyncio from collections import defaultdict +from functools import partial from itertools import cycle import logging import random from dask.optimization import SubgraphCallable +import dask.config +from dask.utils import parse_timedelta from toolz import merge, concat, groupby, drop from .core import rpc @@ -275,3 +278,83 @@ def pack_data(o, d, key_types=object): return {k: pack_data(v, d, key_types=key_types) for k, v in o.items()} else: return o + + +retry_count = dask.config.get("distributed.comm.retry.count") +retry_delay_min = parse_timedelta( + dask.config.get("distributed.comm.retry.delay.min"), default="s" +) +retry_delay_max = parse_timedelta( + dask.config.get("distributed.comm.retry.delay.max"), default="s" +) + + +async def retry( + coro, + count, + delay_min, + delay_max, + jitter_fraction=0.1, + retry_on_exceptions=(EnvironmentError, IOError), + operation=None, +): + """ + Return the result of ``await coro()``, re-trying in case of exceptions + + The delay between attempts is ``delay_min * (2 ** i - 1)`` where ``i`` enumerates the attempt that just failed + (starting at 0), but never larger than ``delay_max``. + This yields no delay between the first and second attempt, then ``delay_min``, ``3 * delay_min``, etc. + (The reason to re-try with no delay is that in most cases this is sufficient and will thus recover faster + from a communication failure). + + Parameters + ---------- + coro + The coroutine function to call and await + count + The maximum number of re-tries before giving up. 0 means no re-try; must be >= 0. + delay_min + The base factor for the delay (in seconds); this is the first non-zero delay between re-tries. + delay_max + The maximum delay (in seconds) between consecutive re-tries (without jitter) + jitter_fraction + The maximum jitter to add to the delay, as fraction of the total delay. No jitter is added if this + value is <= 0. + Using a non-zero value here avoids "herd effects" of many operations re-tried at the same time + retry_on_exceptions + A tuple of exception classes to retry. Other exceptions are not caught and re-tried, but propagate immediately. 
+ operation + A human-readable description of the operation attempted; used only for logging failures + + Returns + ------- + Any + Whatever `await `coro()` returned + """ + # this loop is a no-op in case max_retries<=0 + for i_try in range(count): + try: + return await coro() + except retry_on_exceptions as ex: + operation = operation or str(coro) + logger.info( + f"Retrying {operation} after exception in attempt {i_try}/{count}: {ex}" + ) + delay = min(delay_min * (2 ** i_try - 1), delay_max) + if jitter_fraction > 0: + delay *= 1 + random.random() * jitter_fraction + await asyncio.sleep(delay) + return await coro() + + +async def retry_operation(coro, *args, operation=None, **kwargs): + """ + Retry an operation using the configuration values for the retry parameters + """ + return await retry( + partial(coro, *args, **kwargs), + count=retry_count, + delay_min=retry_delay_min, + delay_max=retry_delay_max, + operation=operation, + ) diff --git a/distributed/worker.py b/distributed/worker.py index ebddd042551..0326cbd157a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -63,7 +63,7 @@ warn_on_duration, LRU, ) -from .utils_comm import pack_data, gather_from_workers +from .utils_comm import pack_data, gather_from_workers, retry_operation from .utils_perf import ThrottledGC, enable_gc_diagnosis, disable_gc_diagnosis logger = logging.getLogger(__name__) @@ -868,7 +868,8 @@ async def heartbeat(self): logger.debug("Heartbeat: %s" % self.address) try: start = time() - response = await self.scheduler.heartbeat_worker( + response = await retry_operation( + self.scheduler.heartbeat_worker, address=self.contact_address, now=time(), metrics=await self.get_metrics(), @@ -1299,8 +1300,10 @@ async def set_resources(self, **resources): self.available_resources[r] = quantity self.total_resources[r] = quantity - await self.scheduler.set_resources( - resources=self.total_resources, worker=self.contact_address + await retry_operation( + self.scheduler.set_resources, + resources=self.total_resources, + worker=self.contact_address, ) ################### @@ -2047,7 +2050,7 @@ async def handle_missing_dep(self, *deps, **kwargs): self.suspicious_deps[dep], ) - who_has = await self.scheduler.who_has(keys=list(deps)) + who_has = await retry_operation(self.scheduler.who_has, keys=list(deps)) who_has = {k: v for k, v in who_has.items() if v} self.update_who_has(who_has) for dep in deps: @@ -2081,7 +2084,7 @@ async def handle_missing_dep(self, *deps, **kwargs): async def query_who_has(self, *deps): with log_errors(): - response = await self.scheduler.who_has(keys=deps) + response = await retry_operation(self.scheduler.who_has, keys=deps) self.update_who_has(response) return response @@ -3132,10 +3135,7 @@ async def get_data_from_worker( if deserializers is None: deserializers = rpc.deserializers - retry_count = 0 - max_retries = 3 - - while True: + async def _get_data(): comm = await rpc.connect(worker) comm.name = "Ephemeral Worker->Worker for gather" try: @@ -3155,25 +3155,11 @@ async def get_data_from_worker( else: if status == "OK": await comm.write("OK") - break - except (EnvironmentError, CommClosedError): - if retry_count < max_retries: - await asyncio.sleep(0.1 * (2 ** retry_count)) - retry_count += 1 - logger.info( - "Encountered connection issue during data collection of keys %s on worker %s. 
Retrying (%s / %s)", - keys, - worker, - retry_count, - max_retries, - ) - continue - else: - raise + return response finally: rpc.reuse(worker, comm) - return response + return await retry_operation(_get_data, operation="get_data_from_worker") job_counter = [0] From a70a080fec56acafc3d3a1510862265525633511 Mon Sep 17 00:00:00 2001 From: byjott Date: Fri, 6 Dec 2019 00:28:03 +0100 Subject: [PATCH 0594/1550] Connectionpool: don't hand out closed connections (#3301) Operating long-running dask clusters (sometimes, they run for many days without interruptions), we found that connection issues we observe are likely related to a behavior of the ConnectionPool of handing out connections that are already closed by the remote end (e.g. because a connection has been established some days ago, easily above connection timeouts). It also fixes another bug in the connecitonpool's bookkeeping of connections. --- distributed/comm/tcp.py | 14 ++++++++++++++ distributed/core.py | 14 ++++++++++++-- distributed/tests/test_client.py | 3 --- distributed/tests/test_core.py | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 40b6e8104b3..c2f3feeb704 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -132,6 +132,10 @@ def convert_stream_closed_error(obj, exc): raise CommClosedError("in %s: %s" % (obj, exc)) +def _do_nothing(): + pass + + class TCP(Comm): """ An established communication based on an underlying Tornado IOStream. @@ -154,6 +158,16 @@ def __init__(self, stream, local_addr, peer_addr, deserialize=True): stream.set_nodelay(True) set_tcp_timeout(stream) + # set a close callback, to make `self.stream.closed()` more reliable. + # Background: if `stream` is unused (e.g. because it's in `ConnectionPool.available`), + # the underlying fd is not watched for changes. In this case, even if the + # connection is actively closed by the remote end, `self.closed()` would still return `False`. + # Registering a closed callback will make tornado register the underlying fd + # for changes, and this would be reflected in `self.closed()` even without reading/writing. + # + # Use a global method (instead of a lambda) to avoid creating a reference + # to the local scope. 
+ stream.set_close_callback(_do_nothing) self._read_extra() def _read_extra(self): diff --git a/distributed/core.py b/distributed/core.py index bf734070248..effc96831b2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -840,6 +840,14 @@ def __init__( self._created = weakref.WeakSet() self._instances.add(self) + def _validate(self): + """ + Validate important invariants of this class + + Used only for testing / debugging + """ + assert self.semaphore._value == self.limit - self.open - self._n_connecting + @property def active(self): return sum(map(len, self.occupied.values())) @@ -868,9 +876,11 @@ async def connect(self, addr, timeout=None): """ available = self.available[addr] occupied = self.occupied[addr] - if available: + while available: comm = available.pop() - if not comm.closed(): + if comm.closed(): + self.semaphore.release() + else: occupied.add(comm) return comm diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7e5442d01a5..2ec3a9f79d1 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3532,9 +3532,6 @@ def test_reconnect(loop): assert time() < start + 5 sleep(0.01) - with pytest.raises(Exception): - c.nthreads() - assert x.status == "cancelled" with pytest.raises(CancelledError): x.result() diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index d3bcdaf8987..bda7bda2ad0 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -750,3 +750,35 @@ async def test_deserialize_error(): assert type(info.value) == Exception for c in str(info.value): assert c.isalpha() or c in "(',!)" # no crazy bytestrings + + +@pytest.mark.asyncio +async def test_connection_pool_detects_remote_close(): + server = Server({"ping": pingpong}) + await server.listen("tcp://") + + # open a connection, use it and give it back to the pool + p = ConnectionPool(limit=10) + conn = await p.connect(server.address) + await send_recv(conn, op="ping") + p.reuse(server.address, conn) + + # now close this connection on the *server* + assert len(server._comms) == 1 + server_conn = list(server._comms.keys())[0] + await server_conn.close() + + # give the ConnectionPool some time to realize that the connection is closed + await asyncio.sleep(0.1) + + # the connection pool should not hand out `conn` again + conn2 = await p.connect(server.address) + assert conn2 is not conn + p.reuse(server.address, conn2) + # check that `conn` has ben removed from the internal data structures + assert p.open == 1 and p.active == 0 + + # check connection pool invariants hold even after it detects a closed connection + # while creating conn2: + p._validate() + p.close() From f7a0d7a8e5a729e84f695b8efb593e590c0ec2f4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 6 Dec 2019 09:53:17 -0800 Subject: [PATCH 0595/1550] Add dask-spec CLI tool (#3090) This launches a fully configured spec, rather than command line arguments. This is helpful for systems like SSH, that are trying to route Python arguments through the CLI. 
## Example ```yaml # foo.yaml big: cls: dask.distributed.Worker opts: nanny: false, nthreads: 12 resources: {"FOO": 1} small: cls: dask.distributed.Worker opts: nanny: false, nthreads: 2 ``` ```bash $ dask-scheduler distributed.scheduler - INFO - Scheduler at: tcp://192.168.1.88:8786 $ python -m distributed.cli.dask_spec localhost:8786 --file foo.yaml distributed.worker - INFO - Start worker at: tcp://127.0.0.1:63400 distributed.worker - INFO - Listening to: tcp://127.0.0.1:63400 distributed.worker - INFO - Waiting to connect to: tcp://localhost:8786 distributed.worker - INFO - ------------------------------------------------- distributed.worker - INFO - Threads: 12 distributed.worker - INFO - Memory: 17.18 GB distributed.worker - INFO - Local Directory: /Users/mrocklin/workspace/distributed/dask-worker-space/worker-r762e6xl distributed.worker - INFO - ------------------------------------------------- distributed.worker - INFO - Start worker at: tcp://127.0.0.1:63401 distributed.worker - INFO - Listening to: tcp://127.0.0.1:63401 distributed.worker - INFO - Waiting to connect to: tcp://localhost:8786 distributed.worker - INFO - ------------------------------------------------- distributed.worker - INFO - Threads: 2 distributed.worker - INFO - Memory: 2.86 GB distributed.worker - INFO - Local Directory: /Users/mrocklin/workspace/distributed/dask-worker-space/worker-h8mjhnef distributed.worker - INFO - ------------------------------------------------- distributed.worker - INFO - Registered to: tcp://localhost:8786 distributed.worker - INFO - ------------------------------------------------- distributed.core - INFO - Starting established connection distributed.worker - INFO - Registered to: tcp://localhost:8786 distributed.worker - INFO - ------------------------------------------------- distributed.core - INFO - Starting established connection ``` As another example, here is starting a scheduler with two protocols ``` python -m distributed.cli.dask_spec --spec '{"cls": "dask.distributed.Scheduler", "opts": {"protocol": ["inproc", "tcp"]}}' ``` --- distributed/cli/dask_spec.py | 41 +++++++++ distributed/cli/tests/test_dask_spec.py | 89 +++++++++++++++++++ distributed/deploy/spec.py | 41 ++++++++- distributed/deploy/tests/test_spec_cluster.py | 20 ++++- distributed/utils.py | 17 ++++ 5 files changed, 202 insertions(+), 6 deletions(-) create mode 100644 distributed/cli/dask_spec.py create mode 100644 distributed/cli/tests/test_dask_spec.py diff --git a/distributed/cli/dask_spec.py b/distributed/cli/dask_spec.py new file mode 100644 index 00000000000..0a224e5b37c --- /dev/null +++ b/distributed/cli/dask_spec.py @@ -0,0 +1,41 @@ +import asyncio +import click +import json +import sys +import yaml + +from distributed.deploy.spec import run_spec + + +@click.command(context_settings=dict(ignore_unknown_options=True)) +@click.argument("args", nargs=-1) +@click.option("--spec", type=str, default="", help="") +@click.option("--spec-file", type=str, default=None, help="") +@click.version_option() +def main(args, spec: str, spec_file: str): + if spec and spec_file or not spec and not spec_file: + print("Must specify exactly one of --spec and --spec-file") + sys.exit(1) + _spec = {} + if spec_file: + with open(spec_file) as f: + _spec.update(yaml.safe_load(f)) + + if spec: + _spec.update(json.loads(spec)) + + if "cls" in _spec: # single worker spec + _spec = {_spec["opts"].get("name", 0): _spec} + + async def run(): + servers = await run_spec(_spec, *args) + try: + await asyncio.gather(*[w.finished() for w 
in servers.values()]) + except KeyboardInterrupt: + await asyncio.gather(*[w.close() for w in servers.values()]) + + asyncio.get_event_loop().run_until_complete(run()) + + +if __name__ == "__main__": + main() diff --git a/distributed/cli/tests/test_dask_spec.py b/distributed/cli/tests/test_dask_spec.py new file mode 100644 index 00000000000..a18b9fb383a --- /dev/null +++ b/distributed/cli/tests/test_dask_spec.py @@ -0,0 +1,89 @@ +import pytest +import sys +import yaml + +from distributed import Client +from distributed.utils_test import popen +from distributed.utils_test import cleanup # noqa: F401 + + +@pytest.mark.asyncio +async def test_text(cleanup): + with popen( + [ + sys.executable, + "-m", + "distributed.cli.dask_spec", + "--spec", + '{"cls": "dask.distributed.Scheduler", "opts": {"port": 9373}}', + ] + ) as sched: + with popen( + [ + sys.executable, + "-m", + "distributed.cli.dask_spec", + "tcp://localhost:9373", + "--spec", + '{"cls": "dask.distributed.Worker", "opts": {"nanny": false, "nthreads": 3, "name": "foo"}}', + ] + ) as w: + async with Client("tcp://localhost:9373", asynchronous=True) as client: + await client.wait_for_workers(1) + info = await client.scheduler.identity() + [w] = info["workers"].values() + assert w["name"] == "foo" + assert w["nthreads"] == 3 + + +@pytest.mark.asyncio +async def test_file(cleanup, tmp_path): + fn = str(tmp_path / "foo.yaml") + with open(fn, "w") as f: + yaml.dump( + { + "cls": "dask.distributed.Worker", + "opts": {"nanny": False, "nthreads": 3, "name": "foo"}, + }, + f, + ) + + with popen(["dask-scheduler", "--port", "9373", "--no-dashboard"]) as sched: + with popen( + [ + sys.executable, + "-m", + "distributed.cli.dask_spec", + "tcp://localhost:9373", + "--spec-file", + fn, + ] + ) as w: + async with Client("tcp://localhost:9373", asynchronous=True) as client: + await client.wait_for_workers(1) + info = await client.scheduler.identity() + [w] = info["workers"].values() + assert w["name"] == "foo" + assert w["nthreads"] == 3 + + +def test_errors(): + with popen( + [ + sys.executable, + "-m", + "distributed.cli.dask_spec", + "--spec", + '{"foo": "bar"}', + "--spec-file", + "foo.yaml", + ] + ) as proc: + line = proc.stdout.readline().decode() + assert "exactly one" in line + assert "--spec" in line and "--spec-file" in line + + with popen([sys.executable, "-m", "distributed.cli.dask_spec"]) as proc: + line = proc.stdout.readline().decode() + assert "exactly one" in line + assert "--spec" in line and "--spec-file" in line diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 72cae01e85c..fb06057bb64 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -7,11 +7,19 @@ import dask from tornado import gen +from tornado.locks import Event from .adaptive import Adaptive from .cluster import Cluster from ..core import rpc, CommClosedError -from ..utils import LoopRunner, silence_logging, ignoring, parse_bytes, parse_timedelta +from ..utils import ( + LoopRunner, + silence_logging, + ignoring, + parse_bytes, + parse_timedelta, + import_term, +) from ..scheduler import Scheduler from ..security import Security @@ -33,6 +41,7 @@ def __init__(self, scheduler=None, name=None): self.external_address = None self.lock = asyncio.Lock() self.status = "created" + self._event_finished = Event() def __await__(self): async def _(): @@ -65,6 +74,11 @@ async def close(self): need to worry about shutting down gracefully """ self.status = "closed" + self._event_finished.set() + + async def finished(self): + """ Wait until 
the server has finished """ + await self._event_finished.wait() def __repr__(self): return "<%s: status=%s>" % (type(self).__name__, self.status) @@ -260,9 +274,11 @@ async def _start(self): else: services = {("dashboard", 8787): BokehScheduler} self.scheduler_spec = {"cls": Scheduler, "options": {"services": services}} - self.scheduler = self.scheduler_spec["cls"]( - **self.scheduler_spec.get("options", {}) - ) + + cls = self.scheduler_spec["cls"] + if isinstance(cls, str): + cls = import_term(cls) + self.scheduler = cls(**self.scheduler_spec.get("options", {})) self.status = "starting" self.scheduler = await self.scheduler @@ -307,6 +323,8 @@ async def _correct_state_internal(self): if "name" not in opts: opts = opts.copy() opts["name"] = name + if isinstance(cls, str): + cls = import_term(cls) worker = cls(self.scheduler.address, **opts) self._created.add(worker) workers.append(worker) @@ -566,6 +584,21 @@ def adapt( return super().adapt(*args, minimum=minimum, maximum=maximum, **kwargs) +async def run_spec(spec: dict, *args): + workers = {} + for k, d in spec.items(): + cls = d["cls"] + if isinstance(cls, str): + cls = import_term(cls) + workers[k] = cls(*args, **d.get("opts", {})) + + if workers: + await asyncio.gather(*workers.values()) + for w in workers.values(): + await w # for tornado gen.coroutine support + return workers + + @atexit.register def close_clusters(): for cluster in list(SpecCluster._instances): diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 19e162ca67b..68642cda9d2 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -4,7 +4,7 @@ import dask from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny -from distributed.deploy.spec import close_clusters, ProcessInterface +from distributed.deploy.spec import close_clusters, ProcessInterface, run_spec from distributed.metrics import time from distributed.utils_test import loop, cleanup # noqa: F401 from distributed.utils import is_valid_xml @@ -25,7 +25,7 @@ async def _(): worker_spec = { - 0: {"cls": Worker, "options": {"nthreads": 1}}, + 0: {"cls": "dask.distributed.Worker", "options": {"nthreads": 1}}, 1: {"cls": Worker, "options": {"nthreads": 2}}, "my-worker": {"cls": MyWorker, "options": {"nthreads": 3}}, } @@ -429,6 +429,8 @@ async def test_MultiWorker(cleanup): await cluster assert len(cluster.worker_spec) == 2 await client.wait_for_workers(4) + while len(cluster.scheduler_info["workers"]) < 4: + await asyncio.sleep(0.01) while "workers=4" not in repr(cluster): await asyncio.sleep(0.1) @@ -460,3 +462,17 @@ async def test_MultiWorker(cleanup): future = client.submit(lambda x: x + 1, 10) await future assert len(cluster.workers) == 1 + + +@pytest.mark.asyncio +async def test_run_spec(cleanup): + async with Scheduler(port=0) as s: + workers = await run_spec(worker_spec, s.address) + async with Client(s.address, asynchronous=True) as c: + await c.wait_for_workers(len(worker_spec)) + + await asyncio.gather(*[w.close() for w in workers.values()]) + + assert not s.workers + + await asyncio.gather(*[w.finished() for w in workers.values()]) diff --git a/distributed/utils.py b/distributed/utils.py index 978be4eae8a..22fae745d39 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1366,6 +1366,23 @@ def is_valid_xml(text): weakref.finalize(_offload_executor, _offload_executor.shutdown) +def import_term(name: str): + """ Return the fully qualified term + + Examples + 
-------- + >>> import_term("math.sin") + + """ + try: + module_name, attr_name = name.rsplit(".", 1) + except ValueError: + return importlib.import_module(name) + + module = importlib.import_module(module_name) + return getattr(module, attr_name) + + async def offload(fn, *args, **kwargs): loop = asyncio.get_event_loop() return await loop.run_in_executor(_offload_executor, fn, *args, **kwargs) From 6c3bc6ef230354ffa7b7b65bb5b72e4bfc6f4f97 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 6 Dec 2019 16:07:23 -0600 Subject: [PATCH 0596/1550] bump version to 2.9.0 --- docs/source/changelog.rst | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 21a57806533..3cff2b92ab5 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,44 @@ Changelog ========= +2.9.0 - 2019-12-06 +------------------ + +- Add ``dask-spec`` CLI tool (:pr:`3090`) `Matthew Rocklin`_ +- Connectionpool: don't hand out closed connections (:pr:`3301`) `byjott`_ +- Retry operations on network issues (:pr:`3294`) `byjott`_ +- Skip ``Security.temporary()`` tests if cryptography not installed (:pr:`3302`) `James Bourbeau`_ +- Support multiple listeners in the scheduler (:pr:`3288`) `Matthew Rocklin`_ +- Updates RMM comment to the correct release (:pr:`3299`) `John Kirkham`_ +- Add title to ``performance_report`` (:pr:`3298`) `Matthew Rocklin`_ +- Forgot to fix slow test (:pr:`3297`) `Benjamin Zaitlen`_ +- Update ``SSHCluster`` docstring parameters (:pr:`3296`) `James Bourbeau`_ +- ``worker.close()`` awaits ``batched_stream.close()`` (:pr:`3291`) `Mads R. B. Kristensen`_ +- Fix asynchronous listener in UCX (:pr:`3292`) `Benjamin Zaitlen`_ +- Avoid repeatedly adding deps to already in memory stack (:pr:`3293`) `James Bourbeau`_ +- xfail ucx empty object typed dataframe (:pr:`3279`) `Benjamin Zaitlen`_ +- Fix ``distributed.wait`` documentation (:pr:`3289`) `Tom Rochette`_ +- Move Python 3 syntax tests into main tests (:pr:`3281`) `Matthew Rocklin`_ +- xfail ``test_workspace_concurrency`` for Python 3.6 (:pr:`3283`) `Matthew Rocklin`_ +- Add ``performance_report`` context manager for static report generation (:pr:`3282`) `Matthew Rocklin`_ +- Update function serialization caches with custom LRU class (:pr:`3260`) `James Bourbeau`_ +- Make ``Listener.start`` asynchronous (:pr:`3278`) `Matthew Rocklin`_ +- Remove ``dask-submit`` and ``dask-remote`` (:pr:`3280`) `Matthew Rocklin`_ +- Worker profile server (:pr:`3274`) `Matthew Rocklin`_ +- Improve bandwidth workers plot (:pr:`3273`) `Matthew Rocklin`_ +- Make profile coroutines consistent between ``Scheduler`` and ``Worker`` (:pr:`3277`) `Matthew Rocklin`_ +- Enable saving profile information from server threads (:pr:`3271`) `Matthew Rocklin`_ +- Remove memory use plot (:pr:`3269`) `Matthew Rocklin`_ +- Add offload size to configuration (:pr:`3270`) `Matthew Rocklin`_ +- Fix layout scaling on profile plots (:pr:`3268`) `Jacob Tomlinson`_ +- Set ``x_range`` in CPU plot based on the number of threads (:pr:`3266`) `Matthew Rocklin`_ +- Use base-2 values for byte-valued axes in dashboard (:pr:`3267`) `Matthew Rocklin`_ +- Robust gather in case of connection failures (:pr:`3246`) `fjetter`_ +- Use ``DeviceBuffer`` from newer RMM releases (:pr:`3261`) `John Kirkham`_ +- Fix dev requirements for pytest (:pr:`3264`) `Elliott Sales de Andrade`_ +- Add validate options to configuration (:pr:`3258`) `Matthew Rocklin`_ + + 2.8.1 - 2019-11-22 ------------------ @@ -1408,3 
+1446,4 @@ significantly without many new features. .. _`Jed Brown`: https://github.com/jedbrown .. _`He Jia`: https://github.com/HerculesJack .. _`Jim Crist-Harif`: https://github.com/jcrist +.. _`fjetter`: https://github.com/fjetter From 07fba32f1f29136478edf676d608f9ef7e08bab2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 6 Dec 2019 20:51:41 -0800 Subject: [PATCH 0597/1550] Make ConnectionPool.close asynchronous (#3304) Previously we would call a hard abort rather than waiting for comms to close more gracefully. --- distributed/client.py | 2 +- distributed/core.py | 17 ++++++++--------- distributed/nanny.py | 2 +- distributed/scheduler.py | 2 +- distributed/tests/test_core.py | 6 +++--- distributed/worker.py | 2 +- 6 files changed, 15 insertions(+), 16 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 027e24afeee..0cb9594f975 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1278,7 +1278,7 @@ async def _close(self, fast=False): if self._start_arg is None: with ignoring(AttributeError): await self.cluster.close() - self.rpc.close() + await self.rpc.close() self.status = "closed" if _get_global_client() is self: _set_global_client(None) diff --git a/distributed/core.py b/distributed/core.py index effc96831b2..81cd7adf8e4 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -959,18 +959,17 @@ def remove(self, addr): IOLoop.current().add_callback(comm.close) self.semaphore.release() - def close(self): + async def close(self): """ - Close all communications abruptly. + Close all communications """ - for comms in self.available.values(): - for comm in comms: - comm.abort() - self.semaphore.release() - for comms in self.occupied.values(): - for comm in comms: + for d in [self.available, self.occupied]: + comms = [comm for comms in d.values() for comm in comms] + await asyncio.gather( + *[comm.close() for comm in comms], return_exceptions=True + ) + for _ in comms: self.semaphore.release() - comm.abort() for comm in self._created: IOLoop.current().add_callback(comm.abort) diff --git a/distributed/nanny.py b/distributed/nanny.py index dc2e8a3ea48..7cf3c2cbbaf 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -451,7 +451,7 @@ async def close(self, comm=None, timeout=5, report=None): except Exception: pass self.process = None - self.rpc.close() + await self.rpc.close() self.status = "closed" if comm: await comm.write("OK") diff --git a/distributed/scheduler.py b/distributed/scheduler.py index eeb7eb49732..b81b2545298 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1304,7 +1304,7 @@ async def close(self, comm=None, fast=False, close_workers=False): for comm in self.client_comms.values(): comm.abort() - self.rpc.close() + await self.rpc.close() self.status = "closed" self.stop() diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index bda7bda2ad0..d423c6ab6c3 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -562,7 +562,7 @@ async def ping(comm, delay=0.1): await asyncio.sleep(0.01) assert time() < start + 2 - rpc.close() + await rpc.close() @pytest.mark.asyncio @@ -612,7 +612,7 @@ async def ping(comm, delay=0.01): await asyncio.gather(*[rpc(s.address).ping() for s in servers]) assert rpc.active == 0 - rpc.close() + await rpc.close() @pytest.mark.asyncio @@ -651,7 +651,7 @@ async def ping(comm, delay=0.01): rpc.remove(serv.address) rpc.reuse(serv.address, comm) - rpc.close() + await rpc.close() @pytest.mark.asyncio diff --git 
a/distributed/worker.py b/distributed/worker.py index 0326cbd157a..1e36728ca78 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1102,7 +1102,7 @@ async def close( self.actor_executor.shutdown(wait=executor_wait, timeout=timeout) self.stop() - self.rpc.close() + await self.rpc.close() self.status = "closed" await ServerNode.close(self) From cf0767536b2b2e17afd6c849c687695c83200979 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 8 Dec 2019 14:45:49 -0800 Subject: [PATCH 0598/1550] Log address for each of the Scheduler listerners (#3306) --- distributed/scheduler.py | 3 ++- distributed/tests/test_scheduler.py | 36 +++++++++++++++++------------ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b81b2545298..20d0c4ed239 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1230,7 +1230,8 @@ async def start(self): self.start_services(listen_ip) self.status = "running" - logger.info(" Scheduler at: %25s", self.address) + for listener in self.listeners: + logger.info(" Scheduler at: %25s", listener.contact_address) for k, v in self.services.items(): logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7bf3f456085..f9087a3029a 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -5,6 +5,7 @@ from datetime import timedelta import json import operator +import re import sys from time import sleep from unittest import mock @@ -1859,18 +1860,23 @@ def reducer(x, y): @pytest.mark.asyncio async def test_multiple_listeners(cleanup): - async with Scheduler(port=0, protocol=["inproc", "tcp"]) as s: - async with Worker(s.listeners[0].contact_address) as a: - async with Worker(s.listeners[1].contact_address) as b: - assert a.address.startswith("inproc") - assert a.scheduler.address.startswith("inproc") - assert b.address.startswith("tcp") - assert b.scheduler.address.startswith("tcp") - - async with Client(s.address, asynchronous=True) as c: - futures = c.map(inc, range(20)) - await wait(futures) - - # Force inter-worker communication both ways - await c.submit(sum, futures, workers=[a.address]) - await c.submit(len, futures, workers=[b.address]) + with captured_logger(logging.getLogger("distributed.scheduler")) as log: + async with Scheduler(port=0, protocol=["inproc", "tcp"]) as s: + async with Worker(s.listeners[0].contact_address) as a: + async with Worker(s.listeners[1].contact_address) as b: + assert a.address.startswith("inproc") + assert a.scheduler.address.startswith("inproc") + assert b.address.startswith("tcp") + assert b.scheduler.address.startswith("tcp") + + async with Client(s.address, asynchronous=True) as c: + futures = c.map(inc, range(20)) + await wait(futures) + + # Force inter-worker communication both ways + await c.submit(sum, futures, workers=[a.address]) + await c.submit(len, futures, workers=[b.address]) + + log = log.getvalue() + assert re.search(r"Scheduler at:\s*tcp://", log) + assert re.search(r"Scheduler at:\s*inproc://", log) From e591f322cfef90663e3b45b7154e96f9cf5adf1e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 9 Dec 2019 07:50:31 -0800 Subject: [PATCH 0599/1550] Add lock to scheduler for sensitive operations (#3259) Some operations like retiring workers or rebalancing data shouldn't happen concurrently. Here we add an asynchronous lock around these operations in order to protect them from each other. 
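In a stripped-down sketch (a toy class, not the real Scheduler), the pattern added here is a single shared asyncio.Lock guarding the sensitive coroutines:

```python
import asyncio

class MiniScheduler:
    """Toy stand-in illustrating the locking pattern only."""

    def __init__(self):
        self._lock = asyncio.Lock()

    async def rebalance(self):
        async with self._lock:        # only one sensitive operation runs at a time
            await asyncio.sleep(0.1)  # placeholder for moving keys between workers

    async def replicate(self):
        async with self._lock:        # blocks until any in-flight rebalance finishes
            await asyncio.sleep(0.1)  # placeholder for copying keys

async def demo():
    s = MiniScheduler()
    # Launched concurrently, but the two bodies never interleave.
    await asyncio.gather(s.rebalance(), s.replicate())

asyncio.get_event_loop().run_until_complete(demo())
```

In the diff below, ``rebalance`` and ``replicate`` acquire ``Scheduler._lock`` in the same way, so these operations cannot run concurrently with each other.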
--- distributed/scheduler.py | 470 ++++++++++++++++++++------------------- distributed/utils.py | 17 ++ 2 files changed, 263 insertions(+), 224 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 20d0c4ed239..275058d5221 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -57,6 +57,7 @@ parse_bytes, PeriodicCallback, shutting_down, + empty_context, tmpfile, ) from .utils_comm import scatter_to_workers, gather_from_workers, retry_operation @@ -885,6 +886,7 @@ def __init__( else: self.idle_timeout = None self.time_started = time() + self._lock = asyncio.Lock() self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) self.bandwidth_workers = defaultdict(float) self.bandwidth_types = defaultdict(float) @@ -2722,118 +2724,124 @@ async def rebalance(self, comm=None, keys=None, workers=None): average expected load. """ with log_errors(): - if keys: - tasks = {self.tasks[k] for k in keys} - missing_data = [ts.key for ts in tasks if not ts.who_has] - if missing_data: - return {"status": "missing-data", "keys": missing_data} - else: - tasks = set(self.tasks.values()) - - if workers: - workers = {self.workers[w] for w in workers} - workers_by_task = {ts: ts.who_has & workers for ts in tasks} - else: - workers = set(self.workers.values()) - workers_by_task = {ts: ts.who_has for ts in tasks} + async with self._lock: + if keys: + tasks = {self.tasks[k] for k in keys} + missing_data = [ts.key for ts in tasks if not ts.who_has] + if missing_data: + return {"status": "missing-data", "keys": missing_data} + else: + tasks = set(self.tasks.values()) - tasks_by_worker = {ws: set() for ws in workers} + if workers: + workers = {self.workers[w] for w in workers} + workers_by_task = {ts: ts.who_has & workers for ts in tasks} + else: + workers = set(self.workers.values()) + workers_by_task = {ts: ts.who_has for ts in tasks} - for k, v in workers_by_task.items(): - for vv in v: - tasks_by_worker[vv].add(k) + tasks_by_worker = {ws: set() for ws in workers} - worker_bytes = { - ws: sum(ts.get_nbytes() for ts in v) - for ws, v in tasks_by_worker.items() - } + for k, v in workers_by_task.items(): + for vv in v: + tasks_by_worker[vv].add(k) - avg = sum(worker_bytes.values()) / len(worker_bytes) + worker_bytes = { + ws: sum(ts.get_nbytes() for ts in v) + for ws, v in tasks_by_worker.items() + } - sorted_workers = list( - map(first, sorted(worker_bytes.items(), key=second, reverse=True)) - ) + avg = sum(worker_bytes.values()) / len(worker_bytes) - recipients = iter(reversed(sorted_workers)) - recipient = next(recipients) - msgs = [] # (sender, recipient, key) - for sender in sorted_workers[: len(workers) // 2]: - sender_keys = {ts: ts.get_nbytes() for ts in tasks_by_worker[sender]} - sender_keys = iter( - sorted(sender_keys.items(), key=second, reverse=True) + sorted_workers = list( + map(first, sorted(worker_bytes.items(), key=second, reverse=True)) ) - try: - while worker_bytes[sender] > avg: - while ( - worker_bytes[recipient] < avg and worker_bytes[sender] > avg - ): - ts, nb = next(sender_keys) - if ts not in tasks_by_worker[recipient]: - tasks_by_worker[recipient].add(ts) - # tasks_by_worker[sender].remove(ts) - msgs.append((sender, recipient, ts)) - worker_bytes[sender] -= nb - worker_bytes[recipient] += nb - if worker_bytes[sender] > avg: - recipient = next(recipients) - except StopIteration: - break + recipients = iter(reversed(sorted_workers)) + recipient = next(recipients) + msgs = [] # (sender, recipient, key) + for sender in 
sorted_workers[: len(workers) // 2]: + sender_keys = { + ts: ts.get_nbytes() for ts in tasks_by_worker[sender] + } + sender_keys = iter( + sorted(sender_keys.items(), key=second, reverse=True) + ) + + try: + while worker_bytes[sender] > avg: + while ( + worker_bytes[recipient] < avg + and worker_bytes[sender] > avg + ): + ts, nb = next(sender_keys) + if ts not in tasks_by_worker[recipient]: + tasks_by_worker[recipient].add(ts) + # tasks_by_worker[sender].remove(ts) + msgs.append((sender, recipient, ts)) + worker_bytes[sender] -= nb + worker_bytes[recipient] += nb + if worker_bytes[sender] > avg: + recipient = next(recipients) + except StopIteration: + break - to_recipients = defaultdict(lambda: defaultdict(list)) - to_senders = defaultdict(list) - for sender, recipient, ts in msgs: - to_recipients[recipient.address][ts.key].append(sender.address) - to_senders[sender.address].append(ts.key) + to_recipients = defaultdict(lambda: defaultdict(list)) + to_senders = defaultdict(list) + for sender, recipient, ts in msgs: + to_recipients[recipient.address][ts.key].append(sender.address) + to_senders[sender.address].append(ts.key) - result = await asyncio.gather( - *( - retry_operation(self.rpc(addr=r).gather, who_has=v) - for r, v in to_recipients.items() + result = await asyncio.gather( + *( + retry_operation(self.rpc(addr=r).gather, who_has=v) + for r, v in to_recipients.items() + ) ) - ) - for r, v in to_recipients.items(): - self.log_event(r, {"action": "rebalance", "who_has": v}) + for r, v in to_recipients.items(): + self.log_event(r, {"action": "rebalance", "who_has": v}) - self.log_event( - "all", - { - "action": "rebalance", - "total-keys": len(tasks), - "senders": valmap(len, to_senders), - "recipients": valmap(len, to_recipients), - "moved_keys": len(msgs), - }, - ) + self.log_event( + "all", + { + "action": "rebalance", + "total-keys": len(tasks), + "senders": valmap(len, to_senders), + "recipients": valmap(len, to_recipients), + "moved_keys": len(msgs), + }, + ) - if not all(r["status"] == "OK" for r in result): - return { - "status": "missing-data", - "keys": sum([r["keys"] for r in result if "keys" in r], []), - } + if not all(r["status"] == "OK" for r in result): + return { + "status": "missing-data", + "keys": sum([r["keys"] for r in result if "keys" in r], []), + } - for sender, recipient, ts in msgs: - assert ts.state == "memory" - ts.who_has.add(recipient) - recipient.has_what.add(ts) - recipient.nbytes += ts.get_nbytes() - self.log.append( - ("rebalance", ts.key, time(), sender.address, recipient.address) - ) + for sender, recipient, ts in msgs: + assert ts.state == "memory" + ts.who_has.add(recipient) + recipient.has_what.add(ts) + recipient.nbytes += ts.get_nbytes() + self.log.append( + ("rebalance", ts.key, time(), sender.address, recipient.address) + ) - await asyncio.gather( - *( - retry_operation(self.rpc(addr=r).delete_data, keys=v, report=False) - for r, v in to_senders.items() + await asyncio.gather( + *( + retry_operation( + self.rpc(addr=r).delete_data, keys=v, report=False + ) + for r, v in to_senders.items() + ) ) - ) - for sender, recipient, ts in msgs: - ts.who_has.remove(sender) - sender.has_what.remove(ts) - sender.nbytes -= ts.get_nbytes() + for sender, recipient, ts in msgs: + ts.who_has.remove(sender) + sender.has_what.remove(ts) + sender.nbytes -= ts.get_nbytes() - return {"status": "OK"} + return {"status": "OK"} async def replicate( self, @@ -2843,6 +2851,7 @@ async def replicate( workers=None, branching_factor=2, delete=True, + lock=True, ): """ 
Replicate data throughout cluster @@ -2866,89 +2875,96 @@ async def replicate( Scheduler.rebalance """ assert branching_factor > 0 + async with self._lock if lock else empty_context: + workers = {self.workers[w] for w in self.workers_list(workers)} + if n is None: + n = len(workers) + else: + n = min(n, len(workers)) + if n == 0: + raise ValueError("Can not use replicate to delete data") + + tasks = {self.tasks[k] for k in keys} + missing_data = [ts.key for ts in tasks if not ts.who_has] + if missing_data: + return {"status": "missing-data", "keys": missing_data} + + # Delete extraneous data + if delete: + del_worker_tasks = defaultdict(set) + for ts in tasks: + del_candidates = ts.who_has & workers + if len(del_candidates) > n: + for ws in random.sample( + del_candidates, len(del_candidates) - n + ): + del_worker_tasks[ws].add(ts) - workers = {self.workers[w] for w in self.workers_list(workers)} - if n is None: - n = len(workers) - else: - n = min(n, len(workers)) - if n == 0: - raise ValueError("Can not use replicate to delete data") - - tasks = {self.tasks[k] for k in keys} - missing_data = [ts.key for ts in tasks if not ts.who_has] - if missing_data: - return {"status": "missing-data", "keys": missing_data} - - # Delete extraneous data - if delete: - del_worker_tasks = defaultdict(set) - for ts in tasks: - del_candidates = ts.who_has & workers - if len(del_candidates) > n: - for ws in random.sample(del_candidates, len(del_candidates) - n): - del_worker_tasks[ws].add(ts) - - await asyncio.gather( - *( - retry_operation( - self.rpc(addr=ws.address).delete_data, - keys=[ts.key for ts in tasks], - report=False, + await asyncio.gather( + *( + retry_operation( + self.rpc(addr=ws.address).delete_data, + keys=[ts.key for ts in tasks], + report=False, + ) + for ws, tasks in del_worker_tasks.items() ) - for ws, tasks in del_worker_tasks.items() ) - ) - for ws, tasks in del_worker_tasks.items(): - ws.has_what -= tasks - for ts in tasks: - ts.who_has.remove(ws) - ws.nbytes -= ts.get_nbytes() - self.log_event( - ws.address, - {"action": "replicate-remove", "keys": [ts.key for ts in tasks]}, - ) + for ws, tasks in del_worker_tasks.items(): + ws.has_what -= tasks + for ts in tasks: + ts.who_has.remove(ws) + ws.nbytes -= ts.get_nbytes() + self.log_event( + ws.address, + { + "action": "replicate-remove", + "keys": [ts.key for ts in tasks], + }, + ) - # Copy not-yet-filled data - while tasks: - gathers = defaultdict(dict) - for ts in list(tasks): - n_missing = n - len(ts.who_has & workers) - if n_missing <= 0: - # Already replicated enough - tasks.remove(ts) - continue + # Copy not-yet-filled data + while tasks: + gathers = defaultdict(dict) + for ts in list(tasks): + n_missing = n - len(ts.who_has & workers) + if n_missing <= 0: + # Already replicated enough + tasks.remove(ts) + continue - count = min(n_missing, branching_factor * len(ts.who_has)) - assert count > 0 + count = min(n_missing, branching_factor * len(ts.who_has)) + assert count > 0 - for ws in random.sample(workers - ts.who_has, count): - gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has] + for ws in random.sample(workers - ts.who_has, count): + gathers[ws.address][ts.key] = [ + wws.address for wws in ts.who_has + ] - results = await asyncio.gather( - *( - retry_operation(self.rpc(addr=w).gather, who_has=who_has) - for w, who_has in gathers.items() + results = await asyncio.gather( + *( + retry_operation(self.rpc(addr=w).gather, who_has=who_has) + for w, who_has in gathers.items() + ) ) - ) - for w, v in zip(gathers, 
results): - if v["status"] == "OK": - self.add_keys(worker=w, keys=list(gathers[w])) - else: - logger.warning("Communication failed during replication: %s", v) + for w, v in zip(gathers, results): + if v["status"] == "OK": + self.add_keys(worker=w, keys=list(gathers[w])) + else: + logger.warning("Communication failed during replication: %s", v) - self.log_event(w, {"action": "replicate-add", "keys": gathers[w]}) + self.log_event(w, {"action": "replicate-add", "keys": gathers[w]}) - self.log_event( - "all", - { - "action": "replicate", - "workers": list(workers), - "key-count": len(keys), - "branching-factor": branching_factor, - }, - ) + self.log_event( + "all", + { + "action": "replicate", + "workers": list(workers), + "key-count": len(keys), + "branching-factor": branching_factor, + }, + ) def workers_to_close( self, @@ -3090,6 +3106,7 @@ async def retire_workers( remove=True, close_workers=False, names=None, + lock=True, **kwargs ): """ Gracefully retire workers from cluster @@ -3122,68 +3139,73 @@ async def retire_workers( Scheduler.workers_to_close """ with log_errors(): - if names is not None: - if names: - logger.info("Retire worker names %s", names) - names = set(map(str, names)) - workers = [ - ws.address for ws in self.workers.values() if str(ws.name) in names - ] - if workers is None: - while True: - try: - workers = self.workers_to_close(**kwargs) - if workers: - workers = await self.retire_workers( - workers=workers, - remove=remove, - close_workers=close_workers, - ) - return workers - except KeyError: # keys left during replicate - pass - workers = {self.workers[w] for w in workers if w in self.workers} - if not workers: - return [] - logger.info("Retire workers %s", workers) - - # Keys orphaned by retiring those workers - keys = set.union(*[w.has_what for w in workers]) - keys = {ts.key for ts in keys if ts.who_has.issubset(workers)} - - other_workers = set(self.workers.values()) - workers - if keys: - if other_workers: - logger.info("Moving %d keys to other workers", len(keys)) - await self.replicate( - keys=keys, - workers=[ws.address for ws in other_workers], - n=1, - delete=False, - ) - else: + async with self._lock if lock else empty_context: + if names is not None: + if names: + logger.info("Retire worker names %s", names) + names = set(map(str, names)) + workers = [ + ws.address + for ws in self.workers.values() + if str(ws.name) in names + ] + if workers is None: + while True: + try: + workers = self.workers_to_close(**kwargs) + if workers: + workers = await self.retire_workers( + workers=workers, + remove=remove, + close_workers=close_workers, + lock=False, + ) + return workers + except KeyError: # keys left during replicate + pass + workers = {self.workers[w] for w in workers if w in self.workers} + if not workers: return [] + logger.info("Retire workers %s", workers) + + # Keys orphaned by retiring those workers + keys = set.union(*[w.has_what for w in workers]) + keys = {ts.key for ts in keys if ts.who_has.issubset(workers)} + + other_workers = set(self.workers.values()) - workers + if keys: + if other_workers: + logger.info("Moving %d keys to other workers", len(keys)) + await self.replicate( + keys=keys, + workers=[ws.address for ws in other_workers], + n=1, + delete=False, + lock=False, + ) + else: + return [] - worker_keys = {ws.address: ws.identity() for ws in workers} - if close_workers and worker_keys: - await asyncio.gather( - *[self.close_worker(worker=w, safe=True) for w in worker_keys] - ) - if remove: - for w in worker_keys: - 
self.remove_worker(address=w, safe=True) + worker_keys = {ws.address: ws.identity() for ws in workers} + if close_workers and worker_keys: + await asyncio.gather( + *[self.close_worker(worker=w, safe=True) for w in worker_keys] + ) + if remove: + for w in worker_keys: + self.remove_worker(address=w, safe=True) - self.log_event( - "all", - { - "action": "retire-workers", - "workers": worker_keys, - "moved-keys": len(keys), - }, - ) - self.log_event(list(worker_keys), {"action": "retired"}) + self.log_event( + "all", + { + "action": "retire-workers", + "workers": worker_keys, + "moved-keys": len(keys), + }, + ) + self.log_event(list(worker_keys), {"action": "retired"}) - return worker_keys + return worker_keys def add_keys(self, comm=None, worker=None, keys=()): """ diff --git a/distributed/utils.py b/distributed/utils.py index 22fae745d39..cfb0ee921d6 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1418,6 +1418,23 @@ def deserialize_for_cli(data): return json.loads(base64.urlsafe_b64decode(data.encode()).decode()) +class EmptyContext: + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + async def __aenter__(self): + pass + + async def __aexit__(self, *args): + pass + + +empty_context = EmptyContext() + + class LRU(UserDict): """ Limited size mapping, evicting the least recently looked-up key when full """ From 3116655cd551fe84b9759220a6f34c299a8c36a6 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 9 Dec 2019 16:14:17 -0600 Subject: [PATCH 0600/1550] Clean up flaky test_nanny_throttle (#3295) --- distributed/tests/test_nanny.py | 56 +++++++++++---------------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index cacd98477e0..68f207a51ce 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -305,54 +305,34 @@ def leak(): @gen_cluster( nthreads=[("127.0.0.1", 1)] * 8, client=True, - Worker=Nanny, - worker_kwargs={"memory_limit": 2e8}, - timeout=20, + Worker=Worker, clean_kwargs={"threads": False}, ) -async def test_nanny_throttle(c, s, *workers): - # Verify that get_data requests are throttled when the worker - # with the data is at high-memory by - # 1. Allocation some data on a worker - # 2. Pausing that worker - # 3. Requesting data from that worker from many other workers - a = workers[0] - proc = a.process.pid - size = 1000 - - def data(size): - return b"0" * size +async def test_throttle_outgoing_connections(c, s, a, *workers): + # But a bunch of small data on worker a + await c.run(lambda: logging.getLogger("distributed.worker").setLevel(logging.DEBUG)) + remote_data = c.map( + lambda x: b"0" * 10000, range(10), pure=False, workers=[a.address] + ) + await wait(remote_data) - def patch(dask_worker): + def pause(dask_worker): # Patch paused and memory_monitor on the one worker # This is is very fragile, since a refactor of memory_monitor to # remove _memory_monitoring will break this test. 
dask_worker._memory_monitoring = True dask_worker.paused = True + dask_worker.outgoing_current_count = 2 - def check(dask_worker): - return dask_worker.paused - - futures = [ - c.submit(data, size, workers=[a.worker_address], pure=False) for i in range(4) + await c.run(pause, workers=[a.address]) + requests = [ + await a.get_data(await w.rpc.connect(w.address), keys=[f.key], who=w.address) + for w in workers + for f in remote_data ] - await wait(futures) - await c.run(patch, workers=[a.worker_address]) - paused = await c.run(check, workers=[a.worker_address]) - assert paused[a.worker_address] - - await c.run(lambda: logging.getLogger("distributed.worker").setLevel(logging.DEBUG)) - # Cluster is in the correct state, now for the test. - n = len(workers) - result = c.map( - lambda x, i: x[i], - [futures[0]] * n, - range(n), - workers=[w.worker_address for w in workers[1:]], - ) - await result[0] - wlogs = await c.get_worker_logs(workers=[a.worker_address]) - wlogs = "\n".join(x[1] for x in wlogs[a.worker_address]) + await wait(requests) + wlogs = await c.get_worker_logs(workers=[a.address]) + wlogs = "\n".join(x[1] for x in wlogs[a.address]) assert "throttling" in wlogs.lower() From 3151f09e883e9e93f8e06dd0ed9db266172ba479 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 10 Dec 2019 08:27:22 -0800 Subject: [PATCH 0601/1550] Use hostname as default IP address rather than localhost (#3308) Previously if we couldn't connect to the scheduler we used localhost this made sense for testing, but probably doesn't make sense for operations. --- distributed/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index cfb0ee921d6..26c503205aa 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -118,7 +118,7 @@ def get_fileno_limit(): @toolz.memoize -def _get_ip(host, port, family, default): +def _get_ip(host, port, family): # By using a UDP socket, we don't actually try to connect but # simply select the local address through which *host* is reachable. sock = socket.socket(family, socket.SOCK_DGRAM) @@ -130,10 +130,10 @@ def _get_ip(host, port, family, default): # XXX Should first try getaddrinfo() on socket.gethostname() and getfqdn() warnings.warn( "Couldn't detect a suitable IP address for " - "reaching %r, defaulting to %r: %s" % (host, default, e), + "reaching %r, defaulting to hostname: %s" % (host, e), RuntimeWarning, ) - return default + return socket.gethostname() finally: sock.close() @@ -145,14 +145,14 @@ def get_ip(host="8.8.8.8", port=80): *host* defaults to a well-known Internet host (one of Google's public DNS servers). """ - return _get_ip(host, port, family=socket.AF_INET, default="127.0.0.1") + return _get_ip(host, port, family=socket.AF_INET) def get_ipv6(host="2001:4860:4860::8888", port=80): """ The same as get_ip(), but for IPv6. 
""" - return _get_ip(host, port, family=socket.AF_INET6, default="::1") + return _get_ip(host, port, family=socket.AF_INET6) def get_ip_interface(ifname): From b92782d0eb8fecedeaf650f188e65ea4ef40de32 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 10 Dec 2019 10:28:44 -0600 Subject: [PATCH 0602/1550] Update inlining Futures in task graph in Client._graph_to_futures (#3303) * Use pack_data to inline Futures * Add subs_mutliple * Check key mapping for keys to substitue * Avoid unnecessary hash attempts --- distributed/client.py | 7 ++++-- distributed/tests/test_utils_comm.py | 16 +++++++++++++- distributed/utils_comm.py | 32 ++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 0cb9594f975..5ff715281ef 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -51,6 +51,7 @@ WrappedKey, unpack_remotedata, pack_data, + subs_multiple, scatter_to_workers, gather_from_workers, retry_operation, @@ -2435,10 +2436,12 @@ def _graph_to_futures( futures = {key: Future(key, self, inform=False) for key in keyset} values = { - k for k, v in dsk.items() if isinstance(v, Future) and k not in keyset + k: v + for k, v in dsk.items() + if isinstance(v, Future) and k not in keyset } if values: - dsk = dask.optimization.inline(dsk, keys=values) + dsk = subs_multiple(dsk, values) d = {k: unpack_remotedata(v, byte_keys=True) for k, v in dsk.items()} extra_futures = set.union(*[v[1] for v in d.values()]) if d else set() diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 3f26fae623a..2d0159a2d3d 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,7 +1,7 @@ from distributed.core import ConnectionPool from distributed.comm import Comm from distributed.utils_test import gen_cluster, loop # noqa: F401 -from distributed.utils_comm import pack_data, gather_from_workers, retry +from distributed.utils_comm import pack_data, subs_multiple, gather_from_workers, retry from unittest import mock @@ -15,6 +15,20 @@ def test_pack_data(): assert pack_data({"a": ["x"], "b": "y"}, data) == {"a": [1], "b": "y"} +def test_subs_multiple(): + data = {"x": 1, "y": 2} + assert subs_multiple((sum, [0, "x", "y", "z"]), data) == (sum, [0, 1, 2, "z"]) + assert subs_multiple((sum, [0, ["x", "y", "z"]]), data) == (sum, [0, [1, 2, "z"]]) + + dsk = {"a": (sum, ["x", "y"])} + assert subs_multiple(dsk, data) == {"a": (sum, [1, 2])} + + # Tuple key + data = {"x": 1, ("y", 0): 2} + dsk = {"a": (sum, ["x", ("y", 0)])} + assert subs_multiple(dsk, data) == {"a": (sum, [1, 2])} + + @gen_cluster(client=True) def test_gather_from_workers_permissive(c, s, a, b): rpc = ConnectionPool() diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index cb614602f7b..792e73227a9 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -280,6 +280,38 @@ def pack_data(o, d, key_types=object): return o +def subs_multiple(o, d): + """ Perform substitutions on a tasks + + Parameters + ---------- + o: + Core data structures containing literals and keys + d: dict + Mapping of keys to values + + Examples + -------- + >>> dsk = {"a": (sum, ["x", 2])} + >>> data = {"x": 1} + >>> subs_multiple(dsk, data) # doctest: +SKIP + {'a': (sum, [1, 2])} + + """ + typ = type(o) + if typ is tuple and o and callable(o[0]): # istask(o) + return (o[0],) + tuple(subs_multiple(i, d) for i in o[1:]) + elif typ is list: + return [subs_multiple(i, d) for i in o] + elif typ 
is dict: + return {k: subs_multiple(v, d) for (k, v) in o.items()} + else: + try: + return d.get(o, o) + except TypeError: + return o + + retry_count = dask.config.get("distributed.comm.retry.count") retry_delay_min = parse_timedelta( dask.config.get("distributed.comm.retry.delay.min"), default="s" From 246eb9b4cf62430b3b9bc9cc1a64534bbd726730 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 10 Dec 2019 18:55:41 +0100 Subject: [PATCH 0603/1550] Update latencies with heartbeats (#3310) --- distributed/tests/test_worker.py | 12 ++++++++++++ distributed/worker.py | 9 ++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index f67701f671a..2cf316ccbca 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1586,3 +1586,15 @@ def bad_startup(w): w = await Worker(s.address, startup_information={"bad": bad_startup}) except Exception: pytest.fail("Startup exception was raised") + + +@pytest.mark.asyncio +async def test_update_latency(cleanup): + async with await Scheduler() as s: + async with await Worker(s.address) as w: + original = w.latency + await w.heartbeat() + assert original != w.latency + + if w.digests is not None: + assert w.digests["latency"].size() > 0 diff --git a/distributed/worker.py b/distributed/worker.py index 1e36728ca78..751fabce2b1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -828,7 +828,7 @@ async def _register_with_scheduler(self): response = await future _end = time() middle = (_start + _end) / 2 - self.latency = (_end - start) * 0.05 + self.latency * 0.95 + self._update_latency(_end - start) self.scheduler_delay = response["time"] - middle self.status = "running" break @@ -862,6 +862,11 @@ async def _register_with_scheduler(self): self.periodic_callbacks["heartbeat"].start() self.loop.add_callback(self.handle_scheduler, comm) + def _update_latency(self, latency): + self.latency = latency * 0.05 + self.latency * 0.95 + if self.digests is not None: + self.digests["latency"].add(latency) + async def heartbeat(self): if not self.heartbeat_active: self.heartbeat_active = True @@ -877,6 +882,8 @@ async def heartbeat(self): end = time() middle = (start + end) / 2 + self._update_latency(end - start) + if response["status"] == "missing": await self._register_with_scheduler() return From 61238adb12d21985299fccd0f026bd42366d163f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 11 Dec 2019 08:49:46 -0800 Subject: [PATCH 0604/1550] Add TaskGroup and TaskPrefix scheduler state (#3262) This aggregates task information into hierarchies. This should be helpful in improving both diagnostics and dashboards, particularly when the number of tasks increases. 
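As a rough illustration of the hierarchy this patch introduces (ahead of the diff below): every task key rolls up into a TaskGroup, keyed by something like key_split_group (e.g. "x-123" for the key ("x-123", 0)), and each group rolls up into a TaskPrefix, keyed by key_split (e.g. "x"), with the prefix aggregating per-state counters over its groups. The group_key and prefix_key helpers in this sketch are simplified, hypothetical stand-ins for the real key_split / key_split_group utilities, and the classes are trimmed down to the counting idea only.

from collections import defaultdict


def group_key(key):
    # ("x-123", 0) -> "x-123"   (simplified stand-in for key_split_group)
    return key[0] if isinstance(key, tuple) else key


def prefix_key(key):
    # "x-123" -> "x"            (simplified stand-in for key_split)
    return group_key(key).rsplit("-", 1)[0]


class TaskGroup:
    def __init__(self, name):
        self.name = name
        self.states = defaultdict(int)  # e.g. {"memory": 3, "released": 2}


class TaskPrefix:
    def __init__(self, name):
        self.name = name
        self.groups = []

    @property
    def states(self):
        # Aggregate per-group counters, analogous to merge_with(sum, ...) in the patch
        out = defaultdict(int)
        for g in self.groups:
            for state, count in g.states.items():
                out[state] += count
        return dict(out)


groups, prefixes = {}, {}
for key, state in [(("x-123", 0), "memory"),
                   (("x-123", 1), "processing"),
                   (("x-456", 0), "memory")]:
    g = groups.setdefault(group_key(key), TaskGroup(group_key(key)))
    g.states[state] += 1
    p = prefixes.setdefault(prefix_key(key), TaskPrefix(prefix_key(key)))
    if g not in p.groups:
        p.groups.append(g)

print(prefixes["x"].states)  # {'memory': 2, 'processing': 1}

The dashboard change in this patch relies on exactly this aggregation: it reads scheduler.task_prefixes[...].active_states instead of walking individual tasks.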
--- distributed/dashboard/components/scheduler.py | 33 ++- distributed/deploy/tests/test_adaptive.py | 126 ++++---- distributed/diagnostics/progress.py | 4 +- distributed/distributed.yaml | 4 +- distributed/scheduler.py | 276 +++++++++++++++--- distributed/tests/test_client.py | 24 +- distributed/tests/test_nanny.py | 17 +- distributed/tests/test_scheduler.py | 175 +++++++---- distributed/tests/test_steal.py | 40 ++- distributed/utils.py | 20 +- 10 files changed, 506 insertions(+), 213 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 2b27c708111..ad7cecea024 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -60,14 +60,13 @@ from distributed.metrics import time from distributed.utils import log_errors, format_time, parse_timedelta from distributed.diagnostics.progress_stream import color_of, progress_quads -from distributed.diagnostics.progress import AllProgress from distributed.diagnostics.graph_layout import GraphLayout from distributed.diagnostics.task_stream import TaskStreamPlugin try: - from cytoolz.curried import map, concat, groupby, valmap + from cytoolz.curried import map, concat, groupby except ImportError: - from toolz.curried import map, concat, groupby, valmap + from toolz.curried import map, concat, groupby if dask.config.get("distributed.dashboard.export-tool"): from distributed.dashboard.export_tool import ExportTool @@ -1283,11 +1282,6 @@ class TaskProgress(DashboardComponent): def __init__(self, scheduler, **kwargs): self.scheduler = scheduler - ps = [p for p in scheduler.plugins if isinstance(p, AllProgress)] - if ps: - self.plugin = ps[0] - else: - self.plugin = AllProgress(scheduler) data = progress_quads( dict(all={}, memory={}, erred={}, released={}, processing={}) @@ -1415,9 +1409,26 @@ def __init__(self, scheduler, **kwargs): @without_property_validation def update(self): with log_errors(): - state = {"all": valmap(len, self.plugin.all), "nbytes": self.plugin.nbytes} - for k in ["memory", "erred", "released", "processing", "waiting"]: - state[k] = valmap(len, self.plugin.state[k]) + state = { + "memory": {}, + "erred": {}, + "released": {}, + "processing": {}, + "waiting": {}, + } + + for tp in self.scheduler.task_prefixes.values(): + if any(tp.active_states.values()): + state["memory"][tp.name] = tp.active_states["memory"] + state["erred"][tp.name] = tp.active_states["erred"] + state["released"][tp.name] = tp.active_states["released"] + state["processing"][tp.name] = tp.active_states["processing"] + state["waiting"][tp.name] = tp.active_states["waiting"] + + state["all"] = { + k: sum(v[k] for v in state.values()) for k in state["memory"] + } + if not state["all"] and not len(self.source.data["all"]): return diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index af198747822..90f56c4bfde 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -1,8 +1,8 @@ from time import sleep +import dask import pytest from tornado import gen -from tornado.ioloop import IOLoop from distributed import Client, wait, Adaptive, LocalCluster, SpecCluster, Worker from distributed.utils_test import gen_test, slowinc, clean @@ -26,21 +26,25 @@ def scale_up(self, n, **kwargs): def scale_down(self, workers): assert False - async with TestCluster(n_workers=4, processes=False, asynchronous=True) as cluster: - async with Client(cluster, asynchronous=True) 
as c: - s = cluster.scheduler - s.task_duration["a"] = 4 - s.task_duration["b"] = 4 - s.task_duration["c"] = 1 + with dask.config.set( + {"distributed.scheduler.default-task-durations": {"a": 4, "b": 4, "c": 1}} + ): + async with TestCluster( + n_workers=4, processes=False, asynchronous=True + ) as cluster: + async with Client(cluster, asynchronous=True) as c: + s = cluster.scheduler - future = c.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) + future = c.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) - while len(s.rprocessing) < 3: - await gen.sleep(0.001) + while len(s.rprocessing) < 3: + await gen.sleep(0.001) - ta = cluster.adapt(interval="100 ms", scale_factor=2, Adaptive=TestAdaptive) + ta = cluster.adapt( + interval="100 ms", scale_factor=2, Adaptive=TestAdaptive + ) - await gen.sleep(0.3) + await gen.sleep(0.3) def test_adaptive_local_cluster(loop): @@ -298,32 +302,25 @@ def test_adapt_down(): @gen_test(timeout=30) -def test_no_more_workers_than_tasks(): - loop = IOLoop.current() - cluster = yield LocalCluster( - 0, - scheduler_port=0, - silence_logs=False, - processes=False, - dashboard_address=None, - loop=loop, - asynchronous=True, - ) - yield cluster._start() - try: - adapt = cluster.adapt(minimum=0, maximum=4, interval="10 ms") - client = yield Client(cluster, asynchronous=True, loop=loop) - cluster.scheduler.task_duration["slowinc"] = 1000 - - yield client.submit(slowinc, 1, delay=0.100) - - assert len(cluster.scheduler.workers) <= 1 - finally: - yield client.close() - yield cluster.close() - - -def test_basic_no_loop(): +async def test_no_more_workers_than_tasks(): + with dask.config.set( + {"distributed.scheduler.default-task-durations": {"slowinc": 1000}} + ): + async with LocalCluster( + 0, + scheduler_port=0, + silence_logs=False, + processes=False, + dashboard_address=None, + asynchronous=True, + ) as cluster: + adapt = cluster.adapt(minimum=0, maximum=4, interval="10 ms") + async with Client(cluster, asynchronous=True) as client: + await client.submit(slowinc, 1, delay=0.100) + assert len(cluster.scheduler.workers) <= 1 + + +def test_basic_no_loop(loop): with clean(threads=False): try: with LocalCluster( @@ -339,36 +336,31 @@ def test_basic_no_loop(): @gen_test(timeout=None) -def test_target_duration(): +async def test_target_duration(): """ Ensure that redefining adapt with a lower maximum removes workers """ - cluster = yield LocalCluster( - 0, - asynchronous=True, - processes=False, - scheduler_port=0, - silence_logs=False, - dashboard_address=None, - ) - client = yield Client(cluster, asynchronous=True) - adapt = cluster.adapt(interval="20ms", minimum=2, target_duration="5s") - - cluster.scheduler.task_duration["slowinc"] = 1 - - try: - while len(cluster.scheduler.workers) < 2: - yield gen.sleep(0.01) - - futures = client.map(slowinc, range(100), delay=0.3) - - while len(adapt.log) < 2: - yield gen.sleep(0.01) - - assert adapt.log[0][1] == {"status": "up", "n": 2} - assert adapt.log[1][1] == {"status": "up", "n": 20} - - finally: - yield client.close() - yield cluster.close() + with dask.config.set( + {"distributed.scheduler.default-task-durations": {"slowinc": 1}} + ): + async with LocalCluster( + 0, + asynchronous=True, + processes=False, + scheduler_port=0, + silence_logs=False, + dashboard_address=None, + ) as cluster: + adapt = cluster.adapt(interval="20ms", minimum=2, target_duration="5s") + async with Client(cluster, asynchronous=True) as client: + while len(cluster.scheduler.workers) < 2: + await gen.sleep(0.01) + + futures = client.map(slowinc, 
range(100), delay=0.3) + + while len(adapt.log) < 2: + await gen.sleep(0.01) + + assert adapt.log[0][1] == {"status": "up", "n": 2} + assert adapt.log[1][1] == {"status": "up", "n": 20} @pytest.mark.asyncio diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 48f26570980..1dcab0dc9e9 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -246,7 +246,7 @@ def __init__(self, scheduler): for ts in self.scheduler.tasks.values(): key = ts.key - prefix = ts.prefix + prefix = ts.prefix.name self.all[prefix].add(key) self.state[ts.state][prefix].add(key) if ts.nbytes is not None: @@ -256,7 +256,7 @@ def __init__(self, scheduler): def transition(self, key, start, finish, *args, **kwargs): ts = self.scheduler.tasks[key] - prefix = ts.prefix + prefix = ts.prefix.name self.all[prefix].add(key) try: self.state[start][prefix].remove(key) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index e6c6a49b484..ee38750f8ee 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -23,7 +23,9 @@ distributed: pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] preload-argv: [] - default-task-durations: {} # How long we expect function names to run ("1h", "1s") (helps for long tasks) + default-task-durations: # How long we expect function names to run ("1h", "1s") (helps for long tasks) + rechunk-split: 1us + shuffle-split: 1us validate: False # Check scheduler state at every step for debugging dashboard: status: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 275058d5221..98faba466e6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -20,9 +20,9 @@ import sortedcontainers try: - from cytoolz import frequencies, merge, pluck, merge_sorted, first + from cytoolz import frequencies, merge, pluck, merge_sorted, first, merge_with except ImportError: - from toolz import frequencies, merge, pluck, merge_sorted, first + from toolz import frequencies, merge, pluck, merge_sorted, first, merge_with from toolz import valmap, second, compose, groupby from tornado import gen from tornado.ioloop import IOLoop @@ -57,6 +57,7 @@ parse_bytes, PeriodicCallback, shutting_down, + key_split_group, empty_context, tmpfile, ) @@ -333,11 +334,10 @@ class TaskState(object): from the name of the function, followed by a hash of the function and arguments, like ``'inc-ab31c010444977004d656610d2d421ec'``. - .. attribute:: prefix: str + .. attribute:: prefix: TaskPrefix - The key prefix, used in certain calculations to get an estimate - of the task's duration based on the duration of other tasks in the - same "family" (for example ``'inc'``). + The broad class of tasks to which this task belongs like "inc" or + "read_csv" .. attribute:: run_spec: object @@ -551,6 +551,10 @@ class TaskState(object): .. attribute: actor: bool Whether or not this task is an Actor. + + .. attribute: group: TaskGroup + +: The group of tasks to which this one belongs. 
""" __slots__ = ( @@ -573,7 +577,7 @@ class TaskState(object): "resource_restrictions", "loose_restrictions", # === Task state === - "state", + "_state", # Whether some dependencies were forgotten "has_lost_dependencies", # If in 'waiting' state, which tasks need to complete @@ -595,13 +599,14 @@ class TaskState(object): "retries", "nbytes", "type", + "group_key", + "group", ) def __init__(self, key, run_spec): self.key = key - self.prefix = key_split(key) self.run_spec = run_spec - self.state = None + self._state = None self.exception = self.traceback = self.exception_blame = None self.suspicious = self.retries = 0 self.nbytes = None @@ -620,14 +625,38 @@ def __init__(self, key, run_spec): self.loose_restrictions = False self.actor = None self.type = None + self.group_key = key_split_group(key) + self.group = None + + @property + def state(self) -> str: + return self._state - def get_nbytes(self): + @property + def prefix_key(self): + return self.prefix.name + + @state.setter + def state(self, value: str): + self.group.states[self._state] -= 1 + self.group.states[value] += 1 + self._state = value + + def add_dependency(self, other: "TaskState"): + """ Add another task as a dependency of this task """ + self.dependencies.add(other) + self.group.dependencies.add(other.group) + other.dependents.add(self) + + def get_nbytes(self) -> int: nbytes = self.nbytes return nbytes if nbytes is not None else DEFAULT_DATA_SIZE - def set_nbytes(self, nbytes): + def set_nbytes(self, nbytes: int): old_nbytes = self.nbytes diff = nbytes - (old_nbytes or 0) + self.group.nbytes_total += diff + self.group.nbytes_in_memory += diff for ws in self.who_has: ws.nbytes += diff self.nbytes = nbytes @@ -654,6 +683,161 @@ def validate(self): pdb.set_trace() +class TaskGroup(object): + """ Collection tracking all tasks within a group + + Keys often have a structure like ``("x-123", 0)`` + A group takes the first section, like ``"x-123"`` + + .. attribute:: name: str + + The name of a group of tasks. + For a task like ``("x-123", 0)`` this is the text ``"x-123"`` + + .. attribute:: states: Dict[str, int] + + The number of tasks in each state, + like ``{"memory": 10, "processing": 3, "released": 4, ...}`` + + .. attribute:: dependencies: Set[TaskGroup] + + The other TaskGroups on which this one depends + + .. attribute:: nbytes_total: int + + The total number of bytes that this task group has produced + + .. attribute:: nbytes_in_memory: int + + The number of bytes currently stored by this TaskGroup + + .. attribute:: duration: float + + The total amount of time spent on all tasks in this TaskGroup + + .. attribute:: types: Set[str] + + The result types of this TaskGroup + + See also + -------- + TaskPrefix + """ + + def __init__(self, name): + self.name = name + self.states = {state: 0 for state in ALL_TASK_STATES} + self.states["forgotten"] = 0 + self.dependencies = set() + self.nbytes_total = 0 + self.nbytes_in_memory = 0 + self.duration = 0 + self.types = set() + + def add(self, ts): + # self.tasks.add(ts) + self.states[ts.state] += 1 + ts.group = self + + def __repr__(self): + return ( + "<" + + (self.name or "no-group") + + ": " + + ", ".join( + "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v + ) + + ">" + ) + + def __len__(self): + return sum(self.states.values()) + + +class TaskPrefix(object): + """ Collection tracking all tasks within a group + + Keys often have a structure like ``("x-123", 0)`` + A group takes the first section, like ``"x"`` + + .. 
attribute:: name: str + + The name of a group of tasks. + For a task like ``("x-123", 0)`` this is the text ``"x"`` + + .. attribute:: states: Dict[str, int] + + The number of tasks in each state, + like ``{"memory": 10, "processing": 3, "released": 4, ...}`` + + .. attribute:: duration_average: float + + An exponentially weighted moving average duration of all tasks with this prefix + + See Also + -------- + TaskGroup + """ + + def __init__(self, name): + self.name = name + self.groups = [] + if self.name in dask.config.get("distributed.scheduler.default-task-durations"): + self.duration_average = parse_timedelta( + dask.config.get("distributed.scheduler.default-task-durations")[ + self.name + ] + ) + else: + self.duration_average = None + + @property + def states(self): + return merge_with(sum, [g.states for g in self.groups]) + + @property + def active(self): + return [ + g + for g in self.groups + if any(v != 0 for k, v in g.states.items() if k != "forgotten") + ] + + @property + def active_states(self): + return merge_with(sum, [g.states for g in self.active]) + + def __repr__(self): + return ( + "<" + + self.name + + ": " + + ", ".join( + "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v + ) + + ">" + ) + + @property + def nbytes_in_memory(self): + return sum(tg.nbytes_in_memory for tg in self.groups) + + @property + def nbytes_total(self): + return sum(tg.nbytes_total for tg in self.groups) + + def __len__(self): + return sum(map(len, self.groups)) + + @property + def duration(self): + return sum(tg.duration for tg in self.groups) + + @property + def types(self): + return set.union(*[tg.types for tg in self.groups]) + + class _StateLegacyMapping(Mapping): """ A mapping interface mimicking the former Scheduler state dictionaries. @@ -923,6 +1107,8 @@ def __init__( # Task state self.tasks = dict() + self.task_groups = dict() + self.task_prefixes = dict() for old_attr, new_attr, wrap in [ ("priority", "priority", None), ("dependencies", "dependencies", _legacy_task_key_set), @@ -972,11 +1158,6 @@ def __init__( self.datasets = dict() # Prefix-keyed containers - self.task_duration = {prefix: 0.00001 for prefix in fast_tasks} - for k, v in dask.config.get( - "distributed.scheduler.default-task-durations", {} - ).items(): - self.task_duration[k] = parse_timedelta(v) self.unknown_durations = defaultdict(set) # Client state @@ -1631,8 +1812,7 @@ def update_graph( # XXX Have a method get_task_state(self, k) ? 
ts = self.tasks.get(k) if ts is None: - ts = self.tasks[k] = TaskState(k, tasks.get(k)) - ts.state = "released" + ts = self.new_task(k, tasks.get(k), "released") elif not ts.run_spec: ts.run_spec = tasks.get(k) @@ -1649,8 +1829,7 @@ def update_graph( continue for dep in deps: dts = self.tasks[dep] - ts.dependencies.add(dts) - dts.dependents.add(ts) + ts.add_dependency(dts) # Compute priorities if isinstance(user_priority, Number): @@ -1775,6 +1954,27 @@ def update_graph( # TODO: balance workers + def new_task(self, key, spec, state): + """ Create a new task, and associated states """ + ts = TaskState(key, spec) + ts._state = state + try: + tg = self.task_groups[ts.group_key] + except KeyError: + tg = self.task_groups[ts.group_key] = TaskGroup(ts.group_key) + tg.add(ts) + prefix_key = key_split(key) + try: + tp = self.task_prefixes[prefix_key] + except KeyError: + tp = TaskPrefix(prefix_key) + tp.groups.append(tg) + self.task_prefixes[prefix_key] = tp + ts.prefix = tp + tg.prefix = tp + self.tasks[key] = ts + return ts + def stimulus_task_finished(self, key=None, worker=None, **kwargs): """ Mark that a task has finished execution on a particular worker """ logger.debug("Stimulus task finished %s, %s", key, worker) @@ -2037,8 +2237,7 @@ def client_desires_keys(self, keys=None, client=None): ts = self.tasks.get(k) if ts is None: # For publish, queues etc. - ts = self.tasks[k] = TaskState(k, None) - ts.state = "released" + ts = self.new_task(k, None, "released") ts.who_wants.add(cs) cs.wants_what.add(ts) @@ -2416,15 +2615,14 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): return if compute_duration: - prefix = ts.prefix - old_duration = self.task_duration.get(prefix, 0) + old_duration = ts.prefix.duration_average or 0 new_duration = compute_duration if not old_duration: avg_duration = new_duration else: avg_duration = 0.5 * old_duration + 0.5 * new_duration - self.task_duration[prefix] = avg_duration + ts.prefix.duration_average = avg_duration ws.occupancy -= ws.processing[ts] self.total_occupancy -= ws.processing[ts] @@ -3250,7 +3448,7 @@ def update_data( for key, workers in who_has.items(): ts = self.tasks.get(key) if ts is None: - ts = self.tasks[key] = TaskState(key, None) + ts = self.new_task(key, None, "memory") ts.state = "memory" if key in nbytes: ts.set_nbytes(nbytes[key]) @@ -3446,13 +3644,13 @@ def get_task_duration(self, ts, default=0.5): Get the estimated computation cost of the given task (not including any communication cost). 
""" - prefix = ts.prefix - try: - return self.task_duration[prefix] - except KeyError: - self.unknown_durations[prefix].add(ts) + duration = ts.prefix.duration_average + if duration is None: + self.unknown_durations[ts.prefix.name].add(ts) return default + return duration + def run_function(self, stream, function, args=(), kwargs={}, wait=True): """ Run a function within this process @@ -3574,6 +3772,7 @@ def _add_to_memory( ts.state = "memory" ts.type = typename + ts.group.types.add(typename) cs = self.clients["fire-and-forget"] if ts in cs.wants_what: @@ -3847,17 +4046,17 @@ def transition_processing_memory( ############################# if compute_start and ws.processing.get(ts, True): # Update average task duration for worker - prefix = ts.prefix - old_duration = self.task_duration.get(prefix, 0) + old_duration = ts.prefix.duration_average or 0 new_duration = compute_stop - compute_start if not old_duration: avg_duration = new_duration else: avg_duration = 0.5 * old_duration + 0.5 * new_duration - self.task_duration[prefix] = avg_duration + ts.prefix.duration_average = avg_duration + ts.group.duration += new_duration - for tts in self.unknown_durations.pop(prefix, ()): + for tts in self.unknown_durations.pop(ts.prefix.name, ()): if tts.processing_on: wws = tts.processing_on old = wws.processing[tts] @@ -3921,6 +4120,7 @@ def transition_memory_released(self, key, safe=False): for ws in ts.who_has: ws.has_what.remove(ts) ws.nbytes -= ts.get_nbytes() + ts.group.nbytes_in_memory -= ts.get_nbytes() self.worker_send( ws.address, {"op": "delete-data", "keys": [key], "report": False} ) @@ -4236,6 +4436,9 @@ def _propagate_forgotten(self, ts, recommendations): ts.dependencies.clear() ts.waiting_on.clear() + if ts.who_has: + ts.group.nbytes_in_memory -= ts.get_nbytes() + for ws in ts.who_has: ws.has_what.remove(ts) ws.nbytes -= ts.get_nbytes() @@ -5142,9 +5345,6 @@ def validate_state(tasks, workers, clients): _round_robin = [0] -fast_tasks = {"rechunk-split", "shuffle-split"} - - def heartbeat_interval(n): """ Interval in seconds that we desire heartbeats based on number of workers diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2ec3a9f79d1..7aba29040be 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1262,7 +1262,9 @@ def test_if_intermediates_clear_on_error(c, s, a, b): assert not any(ts.who_has for ts in s.tasks.values()) -@gen_cluster(client=True) +@gen_cluster( + client=True, config={"distributed.scheduler.default-task-durations": {"f": "1ms"}} +) def test_pragmatic_move_small_data_to_large_data(c, s, a, b): np = pytest.importorskip("numpy") lists = c.map(np.ones, [10000] * 10, pure=False) @@ -1272,7 +1274,6 @@ def test_pragmatic_move_small_data_to_large_data(c, s, a, b): def f(x, y): return None - s.task_duration["f"] = 0.001 results = c.map(f, lists, [total] * 10) yield wait([total]) @@ -3102,12 +3103,12 @@ def test_client_replicate_sync(c): def test_task_load_adapts_quickly(c, s, a): future = c.submit(slowinc, 1, delay=0.2) # slow yield wait(future) - assert 0.15 < s.task_duration["slowinc"] < 0.4 + assert 0.15 < s.task_prefixes["slowinc"].duration_average < 0.4 futures = c.map(slowinc, range(10), delay=0) # very fast yield wait(futures) - assert 0 < s.task_duration["slowinc"] < 0.1 + assert 0 < s.task_prefixes["slowinc"].duration_average < 0.1 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) @@ -4013,10 +4014,13 @@ def test_retire_many_workers(c, s, *workers): assert 15 < len(keys) < 50 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 3)] * 2) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 3)] * 2, + config={"distributed.scheduler.default-task-durations": {"f": "10ms"}}, +) def test_weight_occupancy_against_data_movement(c, s, a, b): s.extensions["stealing"]._pc.callback_time = 1000000 - s.task_duration["f"] = 0.01 def f(x, y=0, z=0): sleep(0.01) @@ -4033,9 +4037,12 @@ def f(x, y=0, z=0): assert sum(f.key in b.data for f in futures) >= 1 -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 10)]) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1), ("127.0.0.1", 10)], + config={"distributed.scheduler.default-task-durations": {"f": "10ms"}}, +) def test_distribute_tasks_by_nthreads(c, s, a, b): - s.task_duration["f"] = 0.01 s.extensions["stealing"]._pc.callback_time = 1000000 def f(x, y=0): @@ -4748,7 +4755,6 @@ def f(x): yield gen.sleep(0.01) assert threading.active_count() < count + 50 - # assert 0.005 < s.task_duration['f'] < 0.1 assert len(a.log) < 2 * len(b.log) assert len(b.log) < 2 * len(a.log) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 68f207a51ce..88910c87069 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -35,23 +35,26 @@ async def test_nanny(s): async with Nanny(s.address, nthreads=2, loop=s.loop) as n: async with rpc(n.address) as nn: assert n.is_alive() - assert s.nthreads[n.worker_address] == 2 - assert s.workers[n.worker_address].nanny == n.address + [ws] = s.workers.values() + assert ws.nthreads == 2 + assert ws.nanny == n.address await nn.kill() assert not n.is_alive() - assert n.worker_address not in s.nthreads - assert n.worker_address not in s.workers + start = time() + while n.worker_address in s.workers: + assert time() < start + 1 + await asyncio.sleep(0.01) await nn.kill() assert not n.is_alive() - assert n.worker_address not in s.nthreads assert n.worker_address not in s.workers await nn.instantiate() assert n.is_alive() - assert s.nthreads[n.worker_address] == 2 - assert s.workers[n.worker_address].nanny == n.address + [ws] = s.workers.values() + assert ws.nthreads == 2 + assert ws.nanny == n.address await nn.terminate() assert not n.is_alive() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f9087a3029a..077c5530260 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -21,7 +21,7 @@ from distributed import Nanny, Worker, Client, wait, fire_and_forget from distributed.comm import Comm from distributed.core import connect, rpc, ConnectionPool -from distributed.scheduler import Scheduler, TaskState +from distributed.scheduler import Scheduler from distributed.client import wait from distributed.metrics import time from distributed.protocol.pickle import dumps @@ -710,18 +710,17 @@ def test_retire_workers_n(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_workers_to_close(cl, s, *workers): - s.task_duration["a"] = 4 - s.task_duration["b"] = 4 - s.task_duration["c"] = 1 - - futures = cl.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) - while sum(len(w.processing) for w in s.workers.values()) < 3: - yield gen.sleep(0.001) +async def test_workers_to_close(cl, s, *workers): + with dask.config.set( + {"distributed.scheduler.default-task-durations": {"a": 4, "b": 4, "c": 1}} + ): + futures = cl.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) + while sum(len(w.processing) for w in s.workers.values()) < 3: + await 
gen.sleep(0.001) - wtc = s.workers_to_close() - assert all(not s.workers[w].processing for w in wtc) - assert len(wtc) == 1 + wtc = s.workers_to_close() + assert all(not s.workers[w].processing for w in wtc) + assert len(wtc) == 1 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) @@ -897,19 +896,19 @@ def test_learn_occupancy_multiple_workers(c, s, a, b): @gen_cluster(client=True) -def test_include_communication_in_occupancy(c, s, a, b): - s.task_duration["slowadd"] = 0.001 +async def test_include_communication_in_occupancy(c, s, a, b): + await c.submit(slowadd, 1, 2, delay=0) x = c.submit(operator.mul, b"0", int(s.bandwidth), workers=a.address) y = c.submit(operator.mul, b"1", int(s.bandwidth * 1.5), workers=b.address) z = c.submit(slowadd, x, y, delay=1) while z.key not in s.tasks or not s.tasks[z.key].processing_on: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) ts = s.tasks[z.key] assert ts.processing_on == s.workers[b.address] assert s.workers[b.address].processing[ts] > 1 - yield wait(z) + await wait(z) del z @@ -1602,27 +1601,28 @@ def test_dashboard_address(): @gen_cluster(client=True) async def test_adaptive_target(c, s, a, b): - assert s.adaptive_target() == 0 - x = c.submit(inc, 1) - await x - assert s.adaptive_target() == 1 - - # Long task - s.task_duration["slowinc"] = 10 - x = c.submit(slowinc, 1, delay=0.5) - while x.key not in s.tasks: - await gen.sleep(0.01) - assert s.adaptive_target(target_duration=".1s") == 1 # still one - - s.task_duration["slowinc"] = 10 - L = c.map(slowinc, range(100), delay=0.5) - while len(s.tasks) < 100: - await gen.sleep(0.01) - assert 10 < s.adaptive_target(target_duration=".1s") <= 100 - del x, L - while s.tasks: - await gen.sleep(0.01) - assert s.adaptive_target(target_duration=".1s") == 0 + with dask.config.set( + {"distributed.scheduler.default-task-durations": {"slowinc": 10}} + ): + assert s.adaptive_target() == 0 + x = c.submit(inc, 1) + await x + assert s.adaptive_target() == 1 + + # Long task + x = c.submit(slowinc, 1, delay=0.5) + while x.key not in s.tasks: + await gen.sleep(0.01) + assert s.adaptive_target(target_duration=".1s") == 1 # still one + + L = c.map(slowinc, range(100), delay=0.5) + while len(s.tasks) < 100: + await gen.sleep(0.01) + assert 10 < s.adaptive_target(target_duration=".1s") <= 100 + del x, L + while s.tasks: + await gen.sleep(0.01) + assert s.adaptive_target(target_duration=".1s") == 0 @pytest.mark.asyncio @@ -1673,27 +1673,28 @@ async def test_retire_names_str(cleanup): assert len(b.data) == 10 -def test_get_task_duration(): +@gen_cluster(client=True) +async def test_get_task_duration(c, s, a, b): with dask.config.set( - {"distributed.scheduler.default-task-durations": {"prefix_1": 100}} + {"distributed.scheduler.default-task-durations": {"inc": 100}} ): - s = Scheduler(port=0) - assert "prefix_1" in s.task_duration - assert s.task_duration["prefix_1"] == 100 + future = c.submit(inc, 1) + await future + assert 10 < s.task_prefixes["inc"].duration_average < 100 - ts_pref1 = TaskState("prefix_1-abcdefab", None) - assert s.get_task_duration(ts_pref1) == 100 + ts_pref1 = s.new_task("inc-abcdefab", None, "released") + assert 10 < s.get_task_duration(ts_pref1) < 100 # make sure get_task_duration adds TaskStates to unknown dict assert len(s.unknown_durations) == 0 - ts_pref2 = TaskState("prefix_2-abcdefab", None) - assert s.get_task_duration(ts_pref2) == 0.5 # default - assert len(s.unknown_durations) == 1 - assert len(s.unknown_durations["prefix_2"]) == 1 - ts_pref2_2 = TaskState("prefix_2-accdefab", 
None) - assert s.get_task_duration(ts_pref2_2) == 0.5 # default + x = c.submit(slowinc, 1, delay=0.5) + while len(s.tasks) < 3: + await asyncio.sleep(0.01) + + ts = s.tasks[x.key] + assert s.get_task_duration(ts) == 0.5 # default assert len(s.unknown_durations) == 1 - assert len(s.unknown_durations["prefix_2"]) == 2 + assert len(s.unknown_durations["slowinc"]) == 1 @pytest.mark.asyncio @@ -1711,6 +1712,56 @@ async def test_no_danglng_asyncio_tasks(cleanup): assert tasks == start +@gen_cluster(client=True) +async def test_task_groups(c, s, a, b): + da = pytest.importorskip("dask.array") + x = da.arange(100, chunks=(20,)) + y = (x + 1).persist(optimize_graph=False) + y = await y + + tg = s.task_groups[x.name] + tp = s.task_prefixes["arange"] + repr(tg) + repr(tp) + assert tg.states["memory"] == 0 + assert tg.states["released"] == 5 + assert tp.states["memory"] == 0 + assert tp.states["released"] == 5 + assert tg.prefix is tp + assert tg in tp.groups + assert tg.duration == tp.duration + assert tg.nbytes_in_memory == tp.nbytes_in_memory + assert tg.nbytes_total == tp.nbytes_total + + tg = s.task_groups[y.name] + assert tg.states["memory"] == 5 + + assert s.task_groups[y.name].dependencies == {s.task_groups[x.name]} + + await c.replicate(y) + assert tg.nbytes_in_memory == y.nbytes + + del y + + while s.tasks: + await asyncio.sleep(0.01) + + assert tg.nbytes_in_memory == 0 + assert tg.states["forgotten"] == 5 + assert "array" in str(tg.types) + assert "array" in str(tp.types) + + +@gen_cluster(client=True) +async def test_task_prefix(c, s, a, b): + da = pytest.importorskip("dask.array") + x = da.arange(100, chunks=(20,)) + y = (x + 1).sum().persist() + y = await y + + assert s.task_prefixes["sum-aggregate"].states["memory"] == 1 + + class BrokenComm(Comm): peer_address = None local_address = None @@ -1857,6 +1908,28 @@ def reducer(x, y): ] assert len(transitions_to_processing) == 1 + starts = [] + finish_processing_transitions = 0 + for transition in s.transition_log: + key, start, finish, recommendations, timestamp = transition + if "reducer" in key and finish == "processing": + finish_processing_transitions += 1 + assert finish_processing_transitions == 1 + + +@gen_cluster(client=True) +async def test_too_many_groups(c, s, a, b): + x = dask.delayed(inc)(1) + y = dask.delayed(dec)(2) + z = dask.delayed(operator.add)(x, y) + + await c.compute(z) + + while s.tasks: + await asyncio.sleep(0.01) + + assert len(s.task_groups) < 3 + @pytest.mark.asyncio async def test_multiple_listeners(cleanup): diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 6d98e662034..de63e542807 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -317,12 +317,15 @@ def test_dont_steal_executing_tasks(c, s, a, b): assert len(b.data) == 0 -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 10, + config={"distributed.scheduler.default-task-durations": {"slowidentity": 0.2}}, +) def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): s.extensions["stealing"]._pc.callback_time = 20 x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB yield wait(x) - s.task_duration["slowidentity"] = 0.2 futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(2)] @@ -336,12 +339,12 @@ def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): client=True, nthreads=[("127.0.0.1", 1)] * 10, worker_kwargs={"memory_limit": MEMORY_LIMIT}, + 
config={"distributed.scheduler.default-task-durations": {"slowidentity": 0.2}}, ) def test_steal_when_more_tasks(c, s, a, *rest): s.extensions["stealing"]._pc.callback_time = 20 x = c.submit(mul, b"0", 50000000, workers=a.address) # 50 MB yield wait(x) - s.task_duration["slowidentity"] = 0.2 futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(20)] @@ -351,7 +354,16 @@ def test_steal_when_more_tasks(c, s, a, *rest): assert time() < start + 1 -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 10, + config={ + "distributed.scheduler.default-task-durations": { + "slowidentity": 0.2, + "slow2": 1, + } + }, +) def test_steal_more_attractive_tasks(c, s, a, *rest): def slow2(x): sleep(1) @@ -361,9 +373,6 @@ def slow2(x): x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB yield wait(x) - s.task_duration["slowidentity"] = 0.2 - s.task_duration["slow2"] = 1 - futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(10)] future = c.submit(slow2, x, priority=-1) @@ -399,7 +408,6 @@ def assert_balanced(inp, expected, c, s, *workers): ws.nbytes += ts.nbytes - old_nbytes else: dat = 123 - s.task_duration[str(int(t))] = 1 i = next(counter) f = c.submit( func, @@ -473,7 +481,15 @@ def assert_balanced(inp, expected, c, s, *workers): ) def test_balance(inp, expected): test = lambda *args, **kwargs: assert_balanced(inp, expected, *args, **kwargs) - test = gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * len(inp))(test) + test = gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * len(inp), + config={ + "distributed.scheduler.default-task-durations": { + str(i): 1 for i in range(10) + } + }, + )(test) test() @@ -495,10 +511,12 @@ def test_restart(c, s, a, b): assert not any(x for L in steal.stealable.values() for x in L) -@gen_cluster(client=True) +@gen_cluster( + client=True, + config={"distributed.scheduler.default-task-durations": {"slowadd": 0.001}}, +) def test_steal_communication_heavy_tasks(c, s, a, b): steal = s.extensions["stealing"] - s.task_duration["slowadd"] = 0.001 x = c.submit(mul, b"0", int(s.bandwidth), workers=a.address) y = c.submit(mul, b"1", int(s.bandwidth), workers=b.address) diff --git a/distributed/utils.py b/distributed/utils.py index 26c503205aa..f7009ece83b 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -614,28 +614,16 @@ def key_split(s): def key_split_group(x): """A more fine-grained version of key_split - >>> key_split_group('x') - 'x' - >>> key_split_group('x-1') - 'x-1' - >>> key_split_group('x-1-2-3') - 'x-1-2-3' >>> key_split_group(('x-2', 1)) 'x-2' >>> key_split_group("('x-2', 1)") 'x-2' - >>> key_split_group('hello-world-1') - 'hello-world-1' - >>> key_split_group(b'hello-world-1') - 'hello-world-1' >>> key_split_group('ae05086432ca935f6eba409a8ecd4896') 'data' >>> key_split_group('>> key_split_group(None) - 'Other' - >>> key_split_group('x-abcdefab') # ignores hex - 'x-abcdefab' + >>> key_split_group('x') + >>> key_split_group('x-1') """ typ = type(x) if typ is tuple: @@ -648,11 +636,11 @@ def key_split_group(x): elif x[0] == "<": return x.strip("<>").split()[0].split(".")[-1] else: - return x + return "" elif typ is bytes: return key_split_group(x.decode()) else: - return "Other" + return "" @contextmanager From 74b6e1aa4980df76d72441ed72145e662b0211fd Mon Sep 17 00:00:00 2001 From: Stephan Erb Date: Thu, 12 Dec 2019 02:11:56 +0100 Subject: [PATCH 0605/1550] Use worker name in logs (#3309) --- distributed/scheduler.py | 
14 ++++++-------- distributed/utils.py | 5 ++++- distributed/worker.py | 3 ++- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 98faba466e6..46851c2c059 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -292,15 +292,13 @@ def clean(self): return ws def __repr__(self): - return "" % ( + return "" % ( self.address, + self.name, len(self.has_what), len(self.processing), ) - def __str__(self): - return self.address - def identity(self): return { "type": "Worker", @@ -1616,7 +1614,7 @@ async def add_worker( ws = self.workers.get(address) if ws is not None: - raise ValueError("Worker already exists %s" % address) + raise ValueError("Worker already exists %s" % ws) self.workers[address] = ws = WorkerState( address=address, @@ -1699,7 +1697,7 @@ async def add_worker( self.log_event(address, {"action": "add-worker"}) self.log_event("all", {"action": "add-worker", "worker": address}) - logger.info("Register %s", str(address)) + logger.info("Register worker %s", ws) if comm: await comm.write( @@ -2120,7 +2118,7 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): "processing-tasks": dict(ws.processing), }, ) - logger.info("Remove worker %s", address) + logger.info("Remove worker %s", ws) if close: with ignoring(AttributeError, CommClosedError): self.stream_comms[address].send({"op": "close", "report": False}) @@ -2191,7 +2189,7 @@ def remove_worker_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) self.loop.call_later(cleanup_delay, remove_worker_from_events) - logger.debug("Removed worker %s", address) + logger.debug("Removed worker %s", ws) return "OK" diff --git a/distributed/utils.py b/distributed/utils.py index f7009ece83b..d268e0b54b0 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -6,6 +6,7 @@ from datetime import timedelta import functools from hashlib import md5 +import html import inspect import json import logging @@ -1284,7 +1285,9 @@ class Log(str): """ A container for logs """ def _repr_html_(self): - return "
<pre><code>\n{log}\n</code></pre>".format(log=self.rstrip()) + return "<pre><code>\n{log}\n</code></pre>
          ".format( + log=html.escape(self.rstrip()) + ) class Logs(dict): diff --git a/distributed/worker.py b/distributed/worker.py index 751fabce2b1..639299fd477 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -709,10 +709,11 @@ def __init__( def __repr__(self): return ( - "<%s: %s, %s, stored: %d, running: %d/%d, ready: %d, comm: %d, waiting: %d>" + "<%s: %r, %s, %s, stored: %d, running: %d/%d, ready: %d, comm: %d, waiting: %d>" % ( self.__class__.__name__, self.address, + self.name, self.status, len(self.data), len(self.executing), From 5c1520a94b15c992044d5e70e2cfb91b22506f6f Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 13 Dec 2019 16:36:08 +0100 Subject: [PATCH 0606/1550] All scheduler task states in prometheus (#3307) --- distributed/dashboard/scheduler.py | 62 +++++++----- .../tests/test_scheduler_bokeh_html.py | 94 +++++++++++++++---- 2 files changed, 117 insertions(+), 39 deletions(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 2c1ec38c4f3..cb96389344e 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -4,8 +4,12 @@ import dask from dask.utils import format_bytes -import toolz -from toolz import merge + +try: + from cytoolz import merge, merge_with +except ImportError: + from toolz import merge, merge_with + from tornado import escape try: @@ -42,6 +46,7 @@ from .proxy import GlobalProxyHandler from .utils import RequestHandler, redirect from ..utils import log_errors, format_time +from ..scheduler import ALL_TASK_STATES ns = { @@ -65,7 +70,7 @@ def get(self): "workers.html", title="Workers", scheduler=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) @@ -81,7 +86,7 @@ def get(self, worker): title="Worker: " + worker, scheduler=self.server, Worker=worker, - **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) @@ -97,7 +102,7 @@ def get(self, task): title="Task: " + task, Task=task, scheduler=self.server, - **toolz.merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge(self.server.__dict__, ns, self.extra, rel_path_statics), ) @@ -109,7 +114,7 @@ def get(self): "logs.html", title="Logs", logs=logs, - **toolz.merge(self.extra, rel_path_statics), + **merge(self.extra, rel_path_statics), ) @@ -123,7 +128,7 @@ async def get(self, worker): "logs.html", title="Logs: " + worker, logs=logs, - **toolz.merge(self.extra, rel_path_statics), + **merge(self.extra, rel_path_statics), ) @@ -137,7 +142,7 @@ async def get(self, worker): "call-stack.html", title="Call Stacks: " + worker, call_stack=call_stack, - **toolz.merge(self.extra, rel_path_statics), + **merge(self.extra, rel_path_statics), ) @@ -156,7 +161,7 @@ async def get(self, key): "call-stack.html", title="Call Stack: " + key, call_stack=call_stack, - **toolz.merge(self.extra, rel_path_statics), + **merge(self.extra, rel_path_statics), ) @@ -239,7 +244,7 @@ def __init__(self, server): self.server = server def collect(self): - from prometheus_client.core import GaugeMetricFamily + from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily yield GaugeMetricFamily( "dask_scheduler_clients", @@ -253,40 +258,53 @@ def collect(self): value=self.server.adaptive_target(), ) - tasks = GaugeMetricFamily( + worker_states = GaugeMetricFamily( "dask_scheduler_workers", "Number of workers known by scheduler.", 
labels=["state"], ) - tasks.add_metric(["connected"], len(self.server.workers)) - tasks.add_metric(["saturated"], len(self.server.saturated)) - tasks.add_metric(["idle"], len(self.server.idle)) - yield tasks + worker_states.add_metric(["connected"], len(self.server.workers)) + worker_states.add_metric(["saturated"], len(self.server.saturated)) + worker_states.add_metric(["idle"], len(self.server.idle)) + yield worker_states tasks = GaugeMetricFamily( "dask_scheduler_tasks", "Number of tasks known by scheduler.", labels=["state"], ) - tasks.add_metric(["received"], len(self.server.tasks)) - tasks.add_metric(["unrunnable"], len(self.server.unrunnable)) + + task_counter = merge_with( + sum, (tp.states for tp in self.server.task_prefixes.values()) + ) + + yield CounterMetricFamily( + "dask_scheduler_tasks_forgotten", + "Total number of processed tasks no longer in memory and already removed from the scheduler job queue.", + value=task_counter.get("forgotten", 0.0), + ) + + for state in ALL_TASK_STATES: + tasks.add_metric([state], task_counter.get(state, 0.0)) yield tasks class PrometheusHandler(RequestHandler): - _initialized = False + _collector = None def __init__(self, *args, **kwargs): import prometheus_client super(PrometheusHandler, self).__init__(*args, **kwargs) - if PrometheusHandler._initialized: + if PrometheusHandler._collector: + # Especially during testing, multiple schedulers are started + # sequentially in the same python process + PrometheusHandler._collector.server = self.server return - prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) - - PrometheusHandler._initialized = True + PrometheusHandler._collector = _PrometheusCollector(self.server) + prometheus_client.REGISTRY.register(PrometheusHandler._collector) def get(self): import prometheus_client diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py index 55e4b797b4e..39da4730a28 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -20,10 +20,10 @@ scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, worker_kwargs={"services": {"dashboard": BokehWorker}}, ) -def test_connect(c, s, a, b): +async def test_connect(c, s, a, b): future = c.submit(lambda x: x + 1, 1) x = c.submit(slowinc, 1, delay=1, retries=5) - yield future + await future http_client = AsyncHTTPClient() for suffix in [ "info/main/workers.html", @@ -38,7 +38,7 @@ def test_connect(c, s, a, b): "json/index.html", "individual-plots.json", ]: - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/%s" % (s.services["dashboard"].port, suffix) ) assert response.code == 200 @@ -55,15 +55,15 @@ def test_connect(c, s, a, b): nthreads=[], scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, ) -def test_worker_404(c, s): +async def test_worker_404(c, s): http_client = AsyncHTTPClient() with pytest.raises(HTTPClientError) as err: - yield http_client.fetch( + await http_client.fetch( "http://localhost:%d/info/worker/unknown" % s.services["dashboard"].port ) assert err.value.code == 404 with pytest.raises(HTTPClientError) as err: - yield http_client.fetch( + await http_client.fetch( "http://localhost:%d/info/task/unknown" % s.services["dashboard"].port ) assert err.value.code == 404 @@ -75,10 +75,10 @@ def test_worker_404(c, s): "services": {("dashboard", 0): (BokehScheduler, {"prefix": "/foo"})} }, ) -def test_prefix(c, s, a, b): 
+async def test_prefix(c, s, a, b): http_client = AsyncHTTPClient() for suffix in ["foo/info/main/workers.html", "foo/json/index.html", "foo/system"]: - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/%s" % (s.services["dashboard"].port, suffix) ) assert response.code == 200 @@ -94,7 +94,7 @@ def test_prefix(c, s, a, b): clean_kwargs={"threads": False}, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, ) -def test_prometheus(c, s, a, b): +async def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families @@ -103,7 +103,7 @@ def test_prometheus(c, s, a, b): # request data twice since there once was a case where metrics got registered multiple times resulting in # prometheus_client errors for _ in range(2): - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/metrics" % s.services["dashboard"].port ) assert response.code == 200 @@ -119,10 +119,70 @@ def test_prometheus(c, s, a, b): clean_kwargs={"threads": False}, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, ) -def test_health(c, s, a, b): +async def test_prometheus_collect_task_states(c, s, a, b): + pytest.importorskip("prometheus_client") + from prometheus_client.parser import text_string_to_metric_families + + http_client = AsyncHTTPClient() + + async def fetch_metrics(): + bokeh_scheduler = s.services["dashboard"] + assert s.services["dashboard"].scheduler is s + response = await http_client.fetch( + f"http://{bokeh_scheduler.server.address}:{bokeh_scheduler.port}/metrics" + ) + txt = response.body.decode("utf8") + families = { + family.name: family for family in text_string_to_metric_families(txt) + } + + active_metrics = { + sample.labels["state"]: sample.value + for sample in families["dask_scheduler_tasks"].samples + } + forgotten_tasks = [ + sample.value + for sample in families["dask_scheduler_tasks_forgotten"].samples + ] + return active_metrics, forgotten_tasks + + expected = {"memory", "released", "processing", "waiting", "no-worker", "erred"} + + # Ensure that we get full zero metrics for all states even though the + # scheduler did nothing, yet + assert not s.tasks + active_metrics, forgotten_tasks = await fetch_metrics() + assert active_metrics.keys() == expected + assert sum(active_metrics.values()) == 0.0 + assert sum(forgotten_tasks) == 0.0 + + # submit a task which should show up in the prometheus scraping + future = c.submit(slowinc, 1, delay=0.5) + + active_metrics, forgotten_tasks = await fetch_metrics() + assert active_metrics.keys() == expected + assert sum(active_metrics.values()) == 1.0 + assert sum(forgotten_tasks) == 0.0 + + res = await c.gather(future) + assert res == 2 + + del future + active_metrics, forgotten_tasks = await fetch_metrics() + assert active_metrics.keys() == expected + assert sum(active_metrics.values()) == 0.0 + assert sum(forgotten_tasks) == 1.0 + + +@gen_cluster( + client=True, + clean_kwargs={"threads": False}, + scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, +) +async def test_health(c, s, a, b): http_client = AsyncHTTPClient() - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/health" % s.services["dashboard"].port ) assert response.code == 200 @@ -135,14 +195,14 @@ def test_health(c, s, a, b): @gen_cluster( client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} ) -def test_task_page(c, s, a, 
b): +async def test_task_page(c, s, a, b): future = c.submit(lambda x: x + 1, 1, workers=a.address) x = c.submit(inc, 1) - yield future + await future http_client = AsyncHTTPClient() "info/task/" + url_escape(future.key) + ".html", - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/info/task/" % s.services["dashboard"].port + url_escape(future.key) + ".html" @@ -167,13 +227,13 @@ def test_task_page(c, s, a, b): } }, ) -def test_allow_websocket_origin(c, s, a, b): +async def test_allow_websocket_origin(c, s, a, b): url = ( "ws://localhost:%d/status/ws?bokeh-protocol-version=1.0&bokeh-session-id=1" % s.services["dashboard"].port ) with pytest.raises(HTTPClientError) as err: - yield websocket_connect( + await websocket_connect( HTTPRequest(url, headers={"Origin": "http://evil.invalid"}) ) assert err.value.code == 403 From 06c0fc27f54b43f97ed89c5af48d8baeeb9175f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20Sodr=C3=A9?= Date: Fri, 13 Dec 2019 12:07:53 -0500 Subject: [PATCH 0607/1550] Add plugin entry point for out-of-tree comms library (#3305) * Add distributed.comm.backends plugin entry-point - Allows end-users to create their own comms plugins and have them discovered at run-time without having to preload the library. - Existing behavior remains unchanged, i.e. entry-points has lower precedence than distributed.comm.registry * Add tests * Only import pkg_resources if backend is not in default registry. * Update backends entry-point documentation. * Cache the Backend classes found through package metadata. * Indent the code-block. * fix typo --- distributed/comm/registry.py | 22 ++++++++++++++++++-- distributed/comm/tests/test_comms.py | 30 +++++++++++++++++++++++++++- docs/source/communications.rst | 16 +++++++++++++++ 3 files changed, 65 insertions(+), 3 deletions(-) diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index 369f2415c35..8fb7a6026f8 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -60,10 +60,28 @@ def get_local_address_for(self, loc): def get_backend(scheme): """ Get the Backend instance for the given *scheme*. 
+ It looks for matching scheme in dask's internal cache, and falls-back to + package metadata for the group name ``distributed.comm.backends`` """ + backend = backends.get(scheme) if backend is None: - raise ValueError( - "unknown address scheme %r (known schemes: %s)" % (scheme, sorted(backends)) + import pkg_resources + + backend = next( + iter( + backend_class_ep.load()() + for backend_class_ep in pkg_resources.iter_entry_points( + "distributed.comm.backends", scheme + ) + ), + None, ) + if backend is None: + raise ValueError( + "unknown address scheme %r (known schemes: %s)" + % (scheme, sorted(backends)) + ) + else: + backends[scheme] = backend return backend diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 465be11c7a4..470c667b989 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -1,15 +1,18 @@ import asyncio +import types from functools import partial import os import sys import threading import warnings +import pkg_resources import pytest from tornado import ioloop, locks, queues from tornado.concurrent import Future +import distributed from distributed.metrics import time from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import ( @@ -23,7 +26,7 @@ from distributed.protocol import to_serialize, Serialized, serialize, deserialize -from distributed.comm.registry import backends +from distributed.comm.registry import backends, get_backend from distributed.comm import ( tcp, inproc, @@ -1154,3 +1157,28 @@ async def test_tls_adresses(): async def test_inproc_adresses(): a, b = await get_inproc_comm_pair() await check_addresses(a, b) + + +def test_register_backend_entrypoint(): + # Code adapted from pandas backend entry point testing + # https://github.com/pandas-dev/pandas/blob/2470690b9f0826a8feb426927694fa3500c3e8d2/pandas/tests/plotting/test_backend.py#L50-L76 + + dist = pkg_resources.get_distribution("distributed") + if dist.module_path not in distributed.__file__: + # We are running from a non-installed distributed, and this test is invalid + pytest.skip("Testing a non-installed distributed") + + mod = types.ModuleType("dask_udp") + mod.UDPBackend = lambda: 1 + sys.modules[mod.__name__] = mod + + entry_point_name = "distributed.comm.backends" + backends_entry_map = pkg_resources.get_entry_map("distributed") + if entry_point_name not in backends_entry_map: + backends_entry_map[entry_point_name] = dict() + backends_entry_map[entry_point_name]["udp"] = pkg_resources.EntryPoint( + "udp", mod.__name__, attrs=["UDPBackend"], dist=dist + ) + + result = get_backend("udp") + assert result == 1 diff --git a/docs/source/communications.rst b/docs/source/communications.rst index 2869012ed48..b9406e6b3c5 100644 --- a/docs/source/communications.rst +++ b/docs/source/communications.rst @@ -97,6 +97,22 @@ Each transport is represented by a URI scheme (such as ``tcp``) and backed by a dedicated :class:`Backend` implementation, which provides entry points into all transport-specific routines. +Out-of-tree backends can be registered under the group ``distributed.comm.backends`` +in setuptools `entry_points`_. For example, a hypothetical ``dask_udp`` package +would register its UDP backend class by including the following in its ``setup.py`` file: + +.. code-block:: python + + setup(name="dask_udp", + entry_points={ + "distributed.comm.backends": [ + "udp=dask_udp.backend:UDPBackend", + ] + }, + ... + ) .. autoclass:: distributed.comm.registry.Backend :members: + +.. 
_entry_points: https://packaging.python.org/guides/creating-and-discovering-plugins/#using-package-metadata From d747f639d73c5f9b8961f2f80e407868de7da4cb Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 13 Dec 2019 17:46:55 +0000 Subject: [PATCH 0608/1550] Switch startstops to dicts and add worker name to transfer (#3319) * Switch startstops to dicts and add worker name to transfer * Fix task stream * Rename `worker` to `source` --- distributed/diagnostics/progress_stream.py | 12 ++++----- distributed/diagnostics/task_stream.py | 20 ++++++++------- distributed/scheduler.py | 6 ++++- distributed/tests/test_worker.py | 2 +- distributed/worker.py | 29 ++++++++++++++-------- 5 files changed, 42 insertions(+), 27 deletions(-) diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index d127ecfeb7e..e417ee8e35b 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -156,17 +156,17 @@ def task_stream_append(lists, msg, workers): name = key_split(key) startstops = msg.get("startstops", []) - for action, start, stop in startstops: - color = colors[action] + for startstop in startstops: + color = colors[startstop["action"]] if type(color) is not str: color = color(msg) - lists["start"].append((start + stop) / 2 * 1000) - lists["duration"].append(1000 * (stop - start)) + lists["start"].append((startstop["start"] + startstop["stop"]) / 2 * 1000) + lists["duration"].append(1000 * (startstop["stop"] - startstop["start"])) lists["key"].append(key) - lists["name"].append(prefix[action] + name) + lists["name"].append(prefix[startstop["action"]] + name) lists["color"].append(color) - lists["alpha"].append(alphas[action]) + lists["alpha"].append(alphas[startstop["action"]]) lists["worker"].append(msg["worker"]) worker_thread = "%s-%d" % (msg["worker"], msg["thread"]) diff --git a/distributed/diagnostics/task_stream.py b/distributed/diagnostics/task_stream.py index 2491c8a89c0..c319ca73d69 100644 --- a/distributed/diagnostics/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -42,7 +42,9 @@ def bisect(target, left, right): return left mid = (left + right) // 2 - value = max(stop for _, start, stop in self.buffer[mid]["startstops"]) + value = max( + startstop["stop"] for startstop in self.buffer[mid]["startstops"] + ) if value < target: return bisect(target, mid + 1, right) @@ -119,20 +121,20 @@ def rectangles(msgs, workers=None, start_boundary=0): if worker_thread not in workers: workers[worker_thread] = len(workers) / 2 - for action, start, stop in startstops: - if start < start_boundary: + for startstop in startstops: + if startstop["start"] < start_boundary: continue - color = colors[action] + color = colors[startstop["action"]] if type(color) is not str: color = color(msg) - L_start.append((start + stop) / 2 * 1000) - L_duration.append(1000 * (stop - start)) - L_duration_text.append(format_time(stop - start)) + L_start.append((startstop["start"] + startstop["stop"]) / 2 * 1000) + L_duration.append(1000 * (startstop["stop"] - startstop["start"])) + L_duration_text.append(format_time(startstop["stop"] - startstop["start"])) L_key.append(key) - L_name.append(prefix[action] + name) + L_name.append(prefix[startstop["action"]] + name) L_color.append(color) - L_alpha.append(alphas[action]) + L_alpha.append(alphas[startstop["action"]]) L_worker.append(msg["worker"]) L_worker_thread.append(worker_thread) L_y.append(workers[worker_thread]) diff --git a/distributed/scheduler.py 
b/distributed/scheduler.py index 46851c2c059..c6b12e14e93 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4031,7 +4031,11 @@ def transition_processing_memory( return {} if startstops: - L = [(b, c) for a, b, c in startstops if a == "compute"] + L = [ + (startstop["start"], startstop["stop"]) + for startstop in startstops + if startstop["action"] == "compute" + ] if L: compute_start, compute_stop = L[0] else: # This is very rare diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2cf316ccbca..9cf32c1eec7 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -668,7 +668,7 @@ def test_multiple_transfers(c, s, w1, w2, w3): yield wait(z) r = w3.startstops[z.key] - transfers = [t for t in r if t[0] == "transfer"] + transfers = [t for t in r if t["action"] == "transfer"] assert len(transfers) == 2 diff --git a/distributed/worker.py b/distributed/worker.py index 639299fd477..07695122f91 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -215,7 +215,7 @@ class Worker(ServerNode): The exception caused by running a task if it erred * **tracebacks**: ``{key: traceback}`` The exception caused by running a task if it erred - * **startstops**: ``{key: [(str, float, float)]}`` + * **startstops**: ``{key: [{startstop}]}`` Log of transfer, load, and compute times for a task * **priorities**: ``{key: tuple}`` @@ -1866,7 +1866,9 @@ def put_key_in_memory(self, key, value, transition=True): self.data[key] = value stop = time() if stop - start > 0.020: - self.startstops[key].append(("disk-write", start, stop)) + self.startstops[key].append( + {"action": "disk-write", "start": start, "stop": stop} + ) if key not in self.nbytes: self.nbytes[key] = sizeof(value) @@ -1933,11 +1935,12 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): if cause: self.startstops[cause].append( - ( - "transfer", - start + self.scheduler_delay, - stop + self.scheduler_delay, - ) + { + "action": "transfer", + "start": start + self.scheduler_delay, + "stop": stop + self.scheduler_delay, + "source": worker, + } ) total_bytes = sum(self.nbytes.get(dep, 0) for dep in response["data"]) @@ -2383,7 +2386,9 @@ def _maybe_deserialize_task(self, key): stop = time() if stop - start > 0.010: - self.startstops[key].append(("deserialize", start, stop)) + self.startstops[key].append( + {"action": "deserialize", "start": start, "stop": stop} + ) return function, args, kwargs except Exception as e: logger.warning("Could not deserialize task", exc_info=True) @@ -2456,7 +2461,9 @@ async def execute(self, key, report=False): kwargs2 = pack_data(kwargs, data, key_types=(bytes, str)) stop = time() if stop - start > 0.005: - self.startstops[key].append(("disk-read", start, stop)) + self.startstops[key].append( + {"action": "disk-read", "start": start, "stop": stop} + ) if self.digests is not None: self.digests["disk-load-duration"].add(stop - start) @@ -2487,7 +2494,9 @@ async def execute(self, key, report=False): result["key"] = key value = result.pop("result", None) - self.startstops[key].append(("compute", result["start"], result["stop"])) + self.startstops[key].append( + {"action": "compute", "start": result["start"], "stop": result["stop"]} + ) self.threads[key] = result["thread"] if result["op"] == "task-finished": From 5b1a1a99feeb934534d66621430e239f9564ee5a Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Mon, 16 Dec 2019 11:57:24 -0500 Subject: [PATCH 0609/1550] Add performance report docs and color definitions to 
docs (#3325) --- docs/source/diagnosing-performance.rst | 29 ++++++++++++++++++++++++++ docs/source/web.rst | 9 ++++++++ 2 files changed, 38 insertions(+) diff --git a/docs/source/diagnosing-performance.rst b/docs/source/diagnosing-performance.rst index 773a5d2316b..76ecfed944f 100644 --- a/docs/source/diagnosing-performance.rst +++ b/docs/source/diagnosing-performance.rst @@ -115,6 +115,35 @@ command on the workers: client.run(lambda dask_worker: dask_worker.incoming_transfer_log) +Performance Reports +------------------- + +Often when benchmarking and/or profiling, users may want to record a +particular computation or even a full workflow. Dask can save the bokeh +dashboards as static HTML plots including the task stream, worker profiles, +bandwidths, etc. This is done wrapping a computation with the ``performance_report`` context manager: + +.. code-block:: python + + from dask.distributed import performance_report + + with performance_report(filename="dask-report.html): + ## some dask computation + +The following video demonstrates the ``performance_report`` context manager in greater +detail: + +.. raw:: html + + + + A note about times ------------------ diff --git a/docs/source/web.rst b/docs/source/web.rst index cfef4902c35..6a5b58fac5e 100644 --- a/docs/source/web.rst +++ b/docs/source/web.rst @@ -126,6 +126,15 @@ accordingly. .. image:: https://raw.githubusercontent.com/dask/dask-org/master/images/bokeh-task-stream.gif :alt: Task stream plot of Dask web interface +The colors signifying the following: + +1. Serialization (gray) +2. Communication between workers (red) +3. Disk I/O (orange) +4. Error (black) +5. Execution times (colored by task: purple, green, yellow, etc) + + If data transfer occurs between workers a *red* bar appears preceding the task bar showing the duration of the transfer. If an error occurs than a *black* bar replaces the normal color. This plot show the last 1000 tasks. From 94bf2ce3d28ee2e286086ef7c101d2c6c5d3cd89 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 17 Dec 2019 13:03:39 -0600 Subject: [PATCH 0610/1550] Add missing `"` in performance report example (#3329) --- docs/source/diagnosing-performance.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/diagnosing-performance.rst b/docs/source/diagnosing-performance.rst index 76ecfed944f..330194076d3 100644 --- a/docs/source/diagnosing-performance.rst +++ b/docs/source/diagnosing-performance.rst @@ -127,7 +127,7 @@ bandwidths, etc. 
This is done wrapping a computation with the ``performance_repo from dask.distributed import performance_report - with performance_report(filename="dask-report.html): + with performance_report(filename="dask-report.html"): ## some dask computation The following video demonstrates the ``performance_report`` context manager in greater From 0a8ccd8d720f001c03b51ef5b9be7c549b341aae Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 18 Dec 2019 10:54:30 -0800 Subject: [PATCH 0611/1550] Use TaskPrefix.name in Graph layout (#3328) Fixes #3327 --- distributed/dashboard/components/scheduler.py | 2 +- distributed/dashboard/tests/test_scheduler_bokeh.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index ad7cecea024..9319937346c 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1219,7 +1219,7 @@ def add_new_nodes_edges(self, new, new_edges, update=False): node_x.append(xx) node_y.append(yy) node_state.append(task.state) - node_name.append(task.prefix) + node_name.append(task.prefix.name) for a, b in new_edges: try: diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 0f262ec5809..39d2ce84156 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -442,6 +442,8 @@ def test_TaskGraph(c, s, a, b): gp.update() assert set(map(len, gp.node_source.data.values())) == {6} assert set(map(len, gp.edge_source.data.values())) == {5} + json.dumps(gp.edge_source.data) + json.dumps(gp.node_source.data) da = pytest.importorskip("dask.array") x = da.random.random((20, 20), chunks=(10, 10)).persist() From 54efd79fa535eaf2926f52eb056c4a5a6c4e0d3b Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 18 Dec 2019 14:28:56 -0600 Subject: [PATCH 0612/1550] Add setuptools to dependencies (#3320) --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index b17e4620be6..67ab27f4edf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ toolz >= 0.7.4 tornado >= 5 zict >= 0.1.3 pyyaml +setuptools From 94f0219b26721bddb3af5539120add0a2a901298 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 21 Dec 2019 11:25:15 -0800 Subject: [PATCH 0613/1550] Add lock around dumps_function cache (#3337) Fixes #2727 --- distributed/worker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 07695122f91..716806a67d6 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3234,15 +3234,19 @@ def execute_task(task): cache_dumps = LRU(maxsize=100) +_cache_lock = threading.Lock() + def dumps_function(func): """ Dump a function to bytes, cache functions """ try: - result = cache_dumps[func] + with _cache_lock: + result = cache_dumps[func] except KeyError: result = pickle.dumps(func) if len(result) < 100000: - cache_dumps[func] = result + with _cache_lock: + cache_dumps[func] = result except TypeError: # Unhashable function result = pickle.dumps(func) return result From 45eb9bf12c322be429a13979967b5854ceca31f1 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 27 Dec 2019 10:53:20 -0800 Subject: [PATCH 0614/1550] bump version to 2.9.1 --- docs/source/changelog.rst | 24 ++++++++++++++++++++++++ requirements.txt | 2 +- 2 files changed, 25 insertions(+), 1 
deletion(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 3cff2b92ab5..3e4a05fcc79 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,28 @@ Changelog ========= +2.9.1 - 2019-12-27 +------------------ + +- Add lock around dumps_function cache (:pr:`3337`) `Matthew Rocklin`_ +- Add setuptools to dependencies (:pr:`3320`) `James Bourbeau`_ +- Use TaskPrefix.name in Graph layout (:pr:`3328`) `Matthew Rocklin`_ +- Add missing `"` in performance report example (:pr:`3329`) `John Kirkham`_ +- Add performance report docs and color definitions to docs (:pr:`3325`) `Benjamin Zaitlen`_ +- Switch startstops to dicts and add worker name to transfer (:pr:`3319`) `Jacob Tomlinson`_ +- Add plugin entry point for out-of-tree comms library (:pr:`3305`) `Patrick Sodré`_ +- All scheduler task states in prometheus (:pr:`3307`) `fjetter`_ +- Use worker name in logs (:pr:`3309`) `Stephan Erb`_ +- Add TaskGroup and TaskPrefix scheduler state (:pr:`3262`) `Matthew Rocklin`_ +- Update latencies with heartbeats (:pr:`3310`) `fjetter`_ +- Update inlining Futures in task graph in Client._graph_to_futures (:pr:`3303`) `James Bourbeau`_ +- Use hostname as default IP address rather than localhost (:pr:`3308`) `Matthew Rocklin`_ +- Clean up flaky test_nanny_throttle (:pr:`3295`) `Tom Augspurger`_ +- Add lock to scheduler for sensitive operations (:pr:`3259`) `Matthew Rocklin`_ +- Log address for each of the Scheduler listerners (:pr:`3306`) `Matthew Rocklin`_ +- Make ConnectionPool.close asynchronous (:pr:`3304`) `Matthew Rocklin`_ + + 2.9.0 - 2019-12-06 ------------------ @@ -1447,3 +1469,5 @@ significantly without many new features. .. _`He Jia`: https://github.com/HerculesJack .. _`Jim Crist-Harif`: https://github.com/jcrist .. _`fjetter`: https://github.com/fjetter +.. _`Patrick Sodré`: https://github.com/sodre +.. 
_`Stephan Erb`: https://github.com/StephanErb diff --git a/requirements.txt b/requirements.txt index 67ab27f4edf..545eba40c4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ click >= 6.6 cloudpickle >= 0.2.2 -dask >= 2.7.0 +dask >= 2.9.0 msgpack psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 From b6fb54154e8805ffe8fbec80bcb5ee7f2d24f328 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 27 Dec 2019 16:42:14 -0800 Subject: [PATCH 0615/1550] Relax intermittent failing test_profile_server (#3346) --- distributed/tests/test_client.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7aba29040be..079a673cbcf 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5693,18 +5693,27 @@ async def test_futures_of_sorted(c, s, a, b): @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": "10ms"}) async def test_profile_server(c, s, a, b): - x = c.map(slowinc, range(10), delay=0.01, workers=a.address) - await wait(x) + for i in range(5): + try: + x = c.map(slowinc, range(10), delay=0.01, workers=a.address, pure=False) + await wait(x) - await asyncio.gather( - c.run(slowinc, 1, delay=0.5), c.run_on_scheduler(slowdec, 1, delay=0.5) - ) + await asyncio.gather( + c.run(slowinc, 1, delay=0.5), c.run_on_scheduler(slowdec, 1, delay=0.5) + ) - p = await c.profile(server=True) # All worker servers - assert "slowinc" in str(p) + p = await c.profile(server=True) # All worker servers + assert "slowinc" in str(p) - p = await c.profile(scheduler=True) # Scheduler - assert "slowdec" in str(p) + p = await c.profile(scheduler=True) # Scheduler + assert "slowdec" in str(p) + except AssertionError: + if i == 4: + raise + else: + pass + else: + break @gen_cluster(client=True) From 1e634e8da244db325007fb9e101b62bf2fb634cc Mon Sep 17 00:00:00 2001 From: Mana Borwornpadungkitti Date: Tue, 31 Dec 2019 23:53:02 +0700 Subject: [PATCH 0616/1550] Avoid setting event loop policy if within IPython kernel and no running event loop (#3336) Setting asyncio event loop policy at these two places could cause problems. 1. When policy is set in Jupyter notebook server extension. This causes the notebook server to hang. This is fixed in https://github.com/dask/distributed/pull/2343. 2. When policy is set in iPython startup config (`~/.ipython/profile_default/startup/whatever.py`) or by setting `get_config().InteractiveShellApp.exec_lines` in `~/.ipython/profile_default/ipython_config.py`. This causes the kernel to hang. This can be reproduced by running either `jupyter console` or `jupyter notebook`. If running In Jupyter notebook, it will struck at "Kernel starting, please wait...". Note that manually setting the policy in notebook cell, after the kernel has started, is fine. In both cases, running `asyncio.get_running_loop()` just before setting the policy will raise `RuntimeError`, meaning there is no running event loop yet. See https://github.com/dask/distributed/issues/3202. 
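A minimal sketch of the guard described above (illustrative only, not the patch itself; `maybe_set_policy` and `running_in_kernel` are made-up names standing in for the inline module-level logic and the `is_kernel()` check in distributed/utils.py):

    import asyncio

    import tornado.platform.asyncio


    def maybe_set_policy(running_in_kernel: bool) -> None:
        # Skip the policy change inside a kernel whose event loop has not
        # started yet; setting it there can hang the kernel at startup.
        if running_in_kernel:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                return  # no running loop yet, leave the policy alone
        asyncio.set_event_loop_policy(
            tornado.platform.asyncio.AnyThreadEventLoopPolicy()
        )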
--- distributed/utils.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index d268e0b54b0..e4884e4d16c 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1177,25 +1177,33 @@ def reset_logger_locks(): handler.createLock() -# Only bother if asyncio has been loaded by Tornado -if "asyncio" in sys.modules and tornado.version_info[0] >= 5: +if tornado.version_info[0] >= 5: - jupyter_event_loop_initialized = False + is_server_extension = False if "notebook" in sys.modules: import traitlets from notebook.notebookapp import NotebookApp - jupyter_event_loop_initialized = traitlets.config.Application.initialized() and isinstance( + is_server_extension = traitlets.config.Application.initialized() and isinstance( traitlets.config.Application.instance(), NotebookApp ) - if not jupyter_event_loop_initialized: - import tornado.platform.asyncio + if not is_server_extension: + is_kernel_and_no_running_loop = False - asyncio.set_event_loop_policy( - tornado.platform.asyncio.AnyThreadEventLoopPolicy() - ) + if is_kernel(): + try: + asyncio.get_running_loop() + except RuntimeError: + is_kernel_and_no_running_loop = True + + if not is_kernel_and_no_running_loop: + import tornado.platform.asyncio + + asyncio.set_event_loop_policy( + tornado.platform.asyncio.AnyThreadEventLoopPolicy() + ) @functools.lru_cache(1000) From fd451210617787a3bd77c370a2e113ad22fdab24 Mon Sep 17 00:00:00 2001 From: Markus Mohrhard Date: Thu, 2 Jan 2020 00:23:02 +0800 Subject: [PATCH 0617/1550] Avoid calling nbytes multiple times when sending data (#3349) --- distributed/comm/tcp.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index c2f3feeb704..77876a04fbc 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -12,7 +12,6 @@ ssl = None import dask -import tornado from tornado import netutil from tornado.iostream import StreamClosedError, IOStream from tornado.tcpclient import TCPClient @@ -20,15 +19,7 @@ from ..system import MEMORY_LIMIT from ..threadpoolexecutor import ThreadPoolExecutor -from ..utils import ( - ensure_bytes, - ensure_ip, - get_ip, - get_ipv6, - nbytes, - parse_timedelta, - shutting_down, -) +from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, parse_timedelta, shutting_down from .registry import Backend, backends from .addressing import parse_host_port, unparse_host_port @@ -141,7 +132,6 @@ class TCP(Comm): An established communication based on an underlying Tornado IOStream. 
""" - _iostream_allows_memoryview = tornado.version_info >= (4, 5) # IOStream.read_into() currently proposed in # https://github.com/tornadoweb/tornado/pull/2193 _iostream_has_read_into = hasattr(IOStream, "read_into") @@ -251,14 +241,12 @@ async def write(self, msg, serializers=None, on_error="message"): else: stream.write(b"".join(length_bytes)) # avoid large memcpy, send in many - for frame in frames: + for frame, frame_bytes in zip(frames, lengths): # Can't wait for the write() Future as it may be lost # ("If write is called again before that Future has resolved, # the previous future will be orphaned and will never resolve") - if not self._iostream_allows_memoryview: - frame = ensure_bytes(frame) future = stream.write(frame) - bytes_since_last_yield += nbytes(frame) + bytes_since_last_yield += frame_bytes if bytes_since_last_yield > 32e6: await future bytes_since_last_yield = 0 @@ -271,7 +259,7 @@ async def write(self, msg, serializers=None, on_error="message"): else: raise - return sum(map(nbytes, frames)) + return sum(lengths) @gen.coroutine def close(self): From 71a8f4551f362a901d1911628c30197f3afccab1 Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Mon, 6 Jan 2020 18:30:14 +0100 Subject: [PATCH 0618/1550] Fix failures on mixed integer/string worker names (#3352) --- distributed/dashboard/components/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 9319937346c..8eec6b8b772 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1625,7 +1625,7 @@ def __init__(self, scheduler, width=800, **kwargs): def update(self): data = {name: [] for name in self.names + self.extra_names} for i, (addr, ws) in enumerate( - sorted(self.scheduler.workers.items(), key=lambda kv: kv[1].name) + sorted(self.scheduler.workers.items(), key=lambda kv: str(kv[1].name)) ): for name in self.names + self.extra_names: data[name].append(ws.metrics.get(name, None)) From 465c9975d922efd38c4c06e886bda51654ddfe04 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 6 Jan 2020 16:26:20 -0600 Subject: [PATCH 0619/1550] Return task in dask-worker on_signal function (#3354) --- distributed/cli/dask_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 9070024c430..a252fe1a232 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -398,7 +398,7 @@ def on_signal(signum): signal_fired = True if signum != signal.SIGINT: logger.info("Exiting on signal %d", signum) - asyncio.ensure_future(close_all()) + return asyncio.ensure_future(close_all()) async def run(): await asyncio.gather(*nannies) From f68119a49b6090b56c8fd5bebbc7def4090d463e Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 7 Jan 2020 16:54:26 +0000 Subject: [PATCH 0620/1550] Add websocket scheduler plugin (#3335) --- distributed/dashboard/scheduler.py | 32 +++++++++++++++++ distributed/diagnostics/websocket.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 distributed/diagnostics/websocket.py diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index cb96389344e..6a52063b9f1 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1,5 +1,6 @@ from datetime import datetime from functools import partial +import json import logging import 
dask @@ -11,6 +12,7 @@ from toolz import merge, merge_with from tornado import escape +from tornado.websocket import WebSocketHandler try: import numpy as np @@ -45,6 +47,7 @@ from .worker import counters_doc from .proxy import GlobalProxyHandler from .utils import RequestHandler, redirect +from ..diagnostics.websocket import WebsocketPlugin from ..utils import log_errors, format_time from ..scheduler import ALL_TASK_STATES @@ -319,6 +322,34 @@ def get(self): self.set_header("Content-Type", "text/plain") +class EventstreamHandler(WebSocketHandler): + def initialize(self, server=None, extra=None): + self.server = server + self.extra = extra or {} + self.plugin = WebsocketPlugin(self, server) + self.server.add_plugin(self.plugin) + + def send(self, name, data): + data["name"] = name + for k in list(data): + # Drop bytes objects for now + if isinstance(data[k], bytes): + del data[k] + self.write_message(data) + + def open(self): + for worker in self.server.workers: + self.plugin.add_worker(self.server, worker) + + def on_message(self, message): + message = json.loads(message) + if message["name"] == "ping": + self.send("pong", {"timestamp": str(datetime.now())}) + + def on_close(self): + self.server.remove_plugin(self.plugin) + + routes = [ (r"info", redirect("info/main/workers.html")), (r"info/main/workers.html", Workers), @@ -334,6 +365,7 @@ def get(self): (r"individual-plots.json", IndividualPlots), (r"metrics", PrometheusHandler), (r"health", HealthHandler), + (r"eventstream", EventstreamHandler), (r"proxy/(\d+)/(.*?)/(.*)", GlobalProxyHandler), ] diff --git a/distributed/diagnostics/websocket.py b/distributed/diagnostics/websocket.py new file mode 100644 index 00000000000..6682dd6a739 --- /dev/null +++ b/distributed/diagnostics/websocket.py @@ -0,0 +1,52 @@ +from .plugin import SchedulerPlugin +from ..utils import key_split +from .task_stream import colors + + +class WebsocketPlugin(SchedulerPlugin): + def __init__(self, socket, scheduler): + self.socket = socket + self.scheduler = scheduler + + def restart(self, scheduler, **kwargs): + """ Run when the scheduler restarts itself """ + self.socket.send("restart", {}) + + def add_worker(self, scheduler=None, worker=None, **kwargs): + """ Run when a new worker enters the cluster """ + self.socket.send("add_worker", {"worker": worker}) + + def remove_worker(self, scheduler=None, worker=None, **kwargs): + """ Run when a worker leaves the cluster""" + self.socket.send("remove_worker", {"worker": worker}) + + def transition(self, key, start, finish, *args, **kwargs): + """ Run whenever a task changes state + + Parameters + ---------- + key: string + start: string + Start state of the transition. + One of released, waiting, processing, memory, error. + finish: string + Final state of the transition. + *args, **kwargs: More options passed when transitioning + This may include worker ID, compute time, etc. 
+ """ + if key not in self.scheduler.tasks: + return + kwargs["key"] = key + startstops = kwargs.get("startstops", []) + for startstop in startstops: + color = colors[startstop["action"]] + if type(color) is not str: + color = color(kwargs) + data = { + "key": key, + "name": key_split(key), + "color": color, + **kwargs, + **startstop, + } + self.socket.send("transition", data) From 32cb96effe9287fdc42dd73be36989be13abe99d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 8 Jan 2020 18:38:21 -0800 Subject: [PATCH 0621/1550] Rework version checking (#2627) This creates automatic version checks whenever a worker or client joins the cluster and raises informative errors letting users know about the mismatched versions. --- distributed/client.py | 55 ++++++------- distributed/deploy/tests/test_adaptive.py | 30 +++---- distributed/scheduler.py | 97 ++++++++++++++--------- distributed/tests/test_client.py | 38 ++++----- distributed/tests/test_collections.py | 12 +-- distributed/tests/test_failed_workers.py | 12 +-- distributed/tests/test_scheduler.py | 41 ++++++---- distributed/tests/test_worker.py | 5 +- distributed/utils.py | 10 +++ distributed/versions.py | 58 +++++++++++--- distributed/worker.py | 21 ++++- 11 files changed, 233 insertions(+), 146 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 5ff715281ef..451a6628e73 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -77,7 +77,6 @@ log_errors, str_graph, key_split, - asciitable, thread_state, no_default, PeriodicCallback, @@ -88,7 +87,7 @@ has_keyword, format_dashboard_link, ) -from .versions import get_versions +from . import versions as version_module logger = logging.getLogger(__name__) @@ -1050,7 +1049,12 @@ async def _ensure_connected(self, timeout=None): else: await self._update_scheduler_info() await comm.write( - {"op": "register-client", "client": self.id, "reply": False} + { + "op": "register-client", + "client": self.id, + "reply": False, + "versions": version_module.get_versions(), + } ) except Exception as e: if self.status == "closed": @@ -1066,6 +1070,9 @@ async def _ensure_connected(self, timeout=None): assert len(msg) == 1 assert msg[0]["op"] == "stream-start" + if msg[0].get("warning"): + warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"])) + bcomm = BatchedSend(interval="10ms", loop=self.loop) bcomm.start(comm) self.scheduler_comm = bcomm @@ -1249,7 +1256,7 @@ async def _close(self, fast=False): if self.get == dask.config.get("get", None): del dask.config.config["get"] if self.status == "closed": - raise gen.Return() + return if ( self.scheduler_comm @@ -1274,15 +1281,21 @@ async def _close(self, fast=False): and not self.scheduler_comm.comm.closed() ): await self.scheduler_comm.close() + for key in list(self.futures): self._release_key(key=key) + if self._start_arg is None: with ignoring(AttributeError): await self.cluster.close() + await self.rpc.close() + self.status = "closed" + if _get_global_client() is self: _set_global_client(None) + coroutines = set(self.coroutines) for f in self.coroutines: # cancel() works on asyncio futures (Tornado 5) @@ -1292,11 +1305,14 @@ async def _close(self, fast=False): if f.cancelled(): coroutines.remove(f) del self.coroutines[:] + if not fast: with ignoring(TimeoutError): await gen.with_timeout(timedelta(seconds=2), list(coroutines)) + with ignoring(AttributeError): await self.scheduler.close_rpc() + self.scheduler = None self.status = "closed" @@ -3551,7 +3567,7 @@ def get_versions(self, check=False, 
packages=[]): return self.sync(self._get_versions, check=check, packages=packages) async def _get_versions(self, check=False, packages=[]): - client = get_versions(packages=packages) + client = version_module.get_versions(packages=packages) try: scheduler = await self.scheduler.versions(packages=packages) except KeyError: @@ -3565,32 +3581,9 @@ async def _get_versions(self, check=False, packages=[]): result = {"scheduler": scheduler, "workers": workers, "client": client} if check: - # we care about the required & optional packages matching - def to_packages(d): - L = list(d["packages"].values()) - return dict(sum(L, type(L[0])())) - - client_versions = to_packages(result["client"]) - versions = [("scheduler", to_packages(result["scheduler"]))] - versions.extend((w, to_packages(d)) for w, d in sorted(workers.items())) - - mismatched = defaultdict(list) - for name, vers in versions: - for pkg, cv in client_versions.items(): - v = vers.get(pkg, "MISSING") - if cv != v: - mismatched[pkg].append((name, v)) - - if mismatched: - errs = [] - for pkg, versions in sorted(mismatched.items()): - rows = [("client", client_versions[pkg])] - rows.extend(versions) - errs.append("%s\n%s" % (pkg, asciitable(["", "version"], rows))) - - raise ValueError( - "Mismatched versions found\n\n%s" % ("\n\n".join(errs)) - ) + msg = version_module.error_message(scheduler, workers, client) + if msg: + raise ValueError(msg) return result diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 90f56c4bfde..2eddeeceff8 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -272,33 +272,29 @@ def test_adapt_quickly(): @gen_test(timeout=None) -def test_adapt_down(): +async def test_adapt_down(): """ Ensure that redefining adapt with a lower maximum removes workers """ - cluster = yield LocalCluster( + async with LocalCluster( 0, asynchronous=True, processes=False, scheduler_port=0, silence_logs=False, dashboard_address=None, - ) - client = yield Client(cluster, asynchronous=True) - cluster.adapt(interval="20ms", maximum=5) + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + cluster.adapt(interval="20ms", maximum=5) - try: - futures = client.map(slowinc, range(1000), delay=0.1) - while len(cluster.scheduler.workers) < 5: - yield gen.sleep(0.1) + futures = client.map(slowinc, range(1000), delay=0.1) + while len(cluster.scheduler.workers) < 5: + await gen.sleep(0.1) - cluster.adapt(maximum=2) + cluster.adapt(maximum=2) - start = time() - while len(cluster.scheduler.workers) != 2: - yield gen.sleep(0.1) - assert time() < start + 1 - finally: - yield client.close() - yield cluster.close() + start = time() + while len(cluster.scheduler.workers) != 2: + await gen.sleep(0.1) + assert time() < start + 1 @gen_test(timeout=30) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c6b12e14e93..66827935b19 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -63,6 +63,7 @@ ) from .utils_comm import scatter_to_workers, gather_from_workers, retry_operation from .utils_perf import enable_gc_diagnosis, disable_gc_diagnosis +from . 
import versions as version_module from .publish import PublishExtension from .queues import QueueExtension @@ -115,12 +116,13 @@ class ClientState(object): """ - __slots__ = ("client_key", "wants_what", "last_seen") + __slots__ = ("client_key", "wants_what", "last_seen", "versions") - def __init__(self, client): + def __init__(self, client, versions=None): self.client_key = client self.wants_what = set() self.last_seen = time() + self.versions = versions or {} def __repr__(self): return "" % (self.client_key,) @@ -232,6 +234,7 @@ class WorkerState(object): "status", "time_delay", "used_resources", + "versions", ) def __init__( @@ -243,6 +246,7 @@ def __init__( memory_limit=0, local_directory=None, services=None, + versions=None, nanny=None, extra=None, ): @@ -253,6 +257,7 @@ def __init__( self.memory_limit = memory_limit self.local_directory = local_directory self.services = services or {} + self.versions = versions or {} self.nanny = nanny self.status = "running" @@ -1533,33 +1538,30 @@ def heartbeat_worker( host = get_address_host(address) local_now = time() now = now or time() - metrics = metrics or {} + assert metrics host_info = host_info or {} self.host_info[host]["last-seen"] = local_now frac = 1 / len(self.workers) - try: - self.bandwidth = ( - self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac - ) - for other, (bw, count) in metrics["bandwidth"]["workers"].items(): - if (address, other) not in self.bandwidth_workers: - self.bandwidth_workers[address, other] = bw / count - else: - alpha = (1 - frac) ** count - self.bandwidth_workers[address, other] = self.bandwidth_workers[ - address, other - ] * alpha + bw * (1 - alpha) - for typ, (bw, count) in metrics["bandwidth"]["types"].items(): - if typ not in self.bandwidth_types: - self.bandwidth_types[typ] = bw / count - else: - alpha = (1 - frac) ** count - self.bandwidth_types[typ] = self.bandwidth_types[ - typ - ] * alpha + bw * (1 - alpha) - except KeyError: - pass + self.bandwidth = ( + self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + ) + for other, (bw, count) in metrics["bandwidth"]["workers"].items(): + if (address, other) not in self.bandwidth_workers: + self.bandwidth_workers[address, other] = bw / count + else: + alpha = (1 - frac) ** count + self.bandwidth_workers[address, other] = self.bandwidth_workers[ + address, other + ] * alpha + bw * (1 - alpha) + for typ, (bw, count) in metrics["bandwidth"]["types"].items(): + if typ not in self.bandwidth_types: + self.bandwidth_types[typ] = bw / count + else: + alpha = (1 - frac) ** count + self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( + 1 - alpha + ) ws = self.workers[address] @@ -1603,6 +1605,7 @@ async def add_worker( pid=0, services=None, local_directory=None, + versions=None, nanny=None, extra=None, ): @@ -1624,6 +1627,7 @@ async def add_worker( name=name, local_directory=local_directory, services=services, + versions=versions, nanny=nanny, extra=extra, ) @@ -1699,15 +1703,27 @@ async def add_worker( self.log_event("all", {"action": "add-worker", "worker": address}) logger.info("Register worker %s", ws) + msg = { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(self.workers)), + "worker-plugins": self.worker_plugins, + } + + version_warning = version_module.error_message( + version_module.get_versions(), + merge( + {w: ws.versions for w, ws in self.workers.items()}, + {c: cs.versions for c, cs in self.clients.items() if cs.versions}, + ), + versions, + client_name="This Worker", + ) + 
if version_warning: + msg["warning"] = version_warning + if comm: - await comm.write( - { - "status": "OK", - "time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), - "worker-plugins": self.worker_plugins, - } - ) + await comm.write(msg) await self.handle_worker(comm=comm, worker=address) def update_graph( @@ -2438,7 +2454,7 @@ def report(self, msg, ts=None, client=None): if self.status == "running": logger.critical("Tried writing to closed comm: %s", msg) - async def add_client(self, comm, client=None): + async def add_client(self, comm, client=None, versions=None): """ Add client to network We listen to all future messages from this Comm. @@ -2447,12 +2463,21 @@ async def add_client(self, comm, client=None): comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) self.log_event(["all", client], {"action": "add-client", "client": client}) - self.clients[client] = ClientState(client) + self.clients[client] = ClientState(client, versions=versions) + try: bcomm = BatchedSend(interval="2ms", loop=self.loop) bcomm.start(comm) self.client_comms[client] = bcomm - bcomm.send({"op": "stream-start"}) + msg = {"op": "stream-start"} + version_warning = version_module.error_message( + version_module.get_versions(), + {w: ws.versions for w, ws in self.workers.items()}, + versions, + ) + if version_warning: + msg["warning"] = version_warning + bcomm.send(msg) try: await self.handle_stream(comm=comm, extra={"client": client}) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 079a673cbcf..9450f08fd75 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3724,7 +3724,7 @@ def test_get_versions(c): # that this does not raise v = c.get_versions(packages=["requests"]) - assert dict(v["client"]["packages"]["optional"])["requests"] == requests.__version__ + assert v["client"]["packages"]["requests"] == requests.__version__ @gen_cluster(client=True) @@ -4007,7 +4007,11 @@ def test_retire_many_workers(c, s, *workers): results = yield c.gather(futures) assert results == list(range(100)) + while len(s.workers) != 3: + yield gen.sleep(0.01) + assert len(s.has_what) == len(s.nthreads) == 3 + assert all(future.done() for future in futures) assert all(s.tasks[future.key].state == "memory" for future in futures) for w, keys in s.has_what.items(): @@ -5284,40 +5288,39 @@ def test_client_active_bad_port(): @pytest.mark.parametrize("direct", [True, False]) def test_turn_off_pickle(direct): @gen_cluster() - def test(s, a, b): + async def test(s, a, b): import numpy as np - c = yield Client(s.address, asynchronous=True, serializers=["dask", "msgpack"]) - try: - assert (yield c.submit(inc, 1)) == 2 - yield c.submit(np.ones, 5) - yield c.scatter(1) + async with Client( + s.address, asynchronous=True, serializers=["dask", "msgpack"] + ) as c: + assert (await c.submit(inc, 1)) == 2 + await c.submit(np.ones, 5) + await c.scatter(1) # Can't send complex data with pytest.raises(TypeError): - future = yield c.scatter(inc) + future = await c.scatter(inc) # can send complex tasks (this uses pickle regardless) future = c.submit(lambda x: x, inc) - yield wait(future) + await wait(future) # but can't receive complex results with pytest.raises(TypeError): - yield c.gather(future, direct=direct) + await c.gather(future, direct=direct) # Run works - result = yield c.run(lambda: 1) + result = await c.run(lambda: 1) assert list(result.values()) == [1, 1] - result = yield c.run_on_scheduler(lambda: 1) + result = await 
c.run_on_scheduler(lambda: 1) assert result == 1 # But not with complex return values with pytest.raises(TypeError): - yield c.run(lambda: inc) + await c.run(lambda: inc) with pytest.raises(TypeError): - yield c.run_on_scheduler(lambda: inc) - finally: - yield c.close() + await c.run_on_scheduler(lambda: inc) test() @@ -5697,7 +5700,6 @@ async def test_profile_server(c, s, a, b): try: x = c.map(slowinc, range(10), delay=0.01, workers=a.address, pure=False) await wait(x) - await asyncio.gather( c.run(slowinc, 1, delay=0.5), c.run_on_scheduler(slowdec, 1, delay=0.5) ) @@ -5821,7 +5823,7 @@ async def ff(): assert c.sync(ff) == 1 -@pytest.mark.xfail(reason="known intermittent failure") +@pytest.mark.skip(reason="known intermittent failure") @gen_cluster(client=True) async def test_dont_hold_on_to_large_messages(c, s, a, b): np = pytest.importorskip("numpy") diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 7fe8467b14b..0843d711761 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -109,19 +109,19 @@ def test_bag_groupby_tasks_default(c, s, a, b): @pytest.mark.parametrize("wait", [wait, lambda x: None]) def test_dataframe_set_index_sync(wait, client): - df = dd.demo.make_timeseries( - "2000", - "2001", - {"value": float, "name": str, "id": int}, + df = dask.datasets.timeseries( + start="2000", + end="2001", + dtypes={"value": float, "name": str, "id": int}, freq="2H", partition_freq="1M", seed=1, ) - df = client.persist(df) + df = df.persist() wait(df) df2 = df.set_index("name", shuffle="tasks") - df2 = client.persist(df2) + df2 = df2.persist() assert len(df2) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 27bce439da4..8f790edf20e 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -415,8 +415,9 @@ def test_restart_timeout_on_long_running_task(c, s, a): assert "timeout" not in text.lower() -@gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "100ms"}) +@gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "500ms"}) def test_worker_time_to_live(c, s, a, b): + assert set(s.workers) == {a.address, b.address} a.periodic_callbacks["heartbeat"].stop() yield gen.sleep(0.010) assert set(s.workers) == {a.address, b.address} @@ -424,13 +425,6 @@ def test_worker_time_to_live(c, s, a, b): start = time() while set(s.workers) == {a.address, b.address}: yield gen.sleep(0.050) - assert time() < start + 1 + assert time() < start + 2 set(s.workers) == {b.address} - - start = time() - while b.status == "running": - yield gen.sleep(0.050) - assert time() < start + 1 - - assert b.status in ("closed", "closing") diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 077c5530260..a8b07ea499e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -408,8 +408,8 @@ def test_delete(c, s, a): def test_filtered_communication(s, a, b): c = yield connect(s.address) f = yield connect(s.address) - yield c.write({"op": "register-client", "client": "c"}) - yield f.write({"op": "register-client", "client": "f"}) + yield c.write({"op": "register-client", "client": "c", "versions": {}}) + yield f.write({"op": "register-client", "client": "f", "versions": {}}) yield c.read() yield f.read() @@ -942,10 +942,11 @@ def test_worker_arrives_with_processing_data(c, s, a, b): yield w.close() +@pytest.mark.slow @gen_cluster(client=True, 
nthreads=[("127.0.0.1", 1)]) def test_worker_breaks_and_returns(c, s, a): future = c.submit(slowinc, 1, delay=0.1) - for i in range(10): + for i in range(20): future = c.submit(slowinc, future, delay=0.1) yield wait(future) @@ -957,10 +958,10 @@ def test_worker_breaks_and_returns(c, s, a): yield wait(future, timeout=10) end = time() - assert end - start < 1 + assert end - start < 2 states = frequencies(ts.state for ts in s.tasks.values()) - assert states == {"memory": 1, "released": 10} + assert states == {"memory": 1, "released": 20} @gen_cluster(client=True, nthreads=[]) @@ -1211,14 +1212,15 @@ def test_profile_metadata(c, s, a, b): @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) def test_profile_metadata_keys(c, s, a, b): - start = time() - 1 x = c.map(slowinc, range(10), delay=0.05) y = c.map(slowdec, range(10), delay=0.05) yield wait(x + y) meta = yield s.get_profile_metadata(profile_cycle_interval=0.100) assert set(meta["keys"]) == {"slowinc", "slowdec"} - assert len(meta["counts"]) == len(meta["keys"]["slowinc"]) + assert ( + len(meta["counts"]) - 3 <= len(meta["keys"]["slowinc"]) <= len(meta["counts"]) + ) @gen_cluster(client=True) @@ -1501,21 +1503,32 @@ def qux(x): yield f -@gen_cluster(client=True, config={"distributed.scheduler.idle-timeout": "200ms"}) -def test_idle_timeout(c, s, a, b): +@gen_cluster(client=True) +def test_collect_versions(c, s, a, b): + cs = s.clients[c.id] + (w1, w2) = s.workers.values() + assert cs.versions + assert w1.versions + assert w2.versions + assert "dask" in str(cs.versions) + assert cs.versions == w1.versions == w2.versions + + +@gen_cluster(client=True, config={"distributed.scheduler.idle-timeout": "500ms"}) +async def test_idle_timeout(c, s, a, b): future = c.submit(slowinc, 1) - yield future + await future assert s.status != "closed" start = time() while s.status != "closed": - yield gen.sleep(0.01) - assert time() < start + 3 + await gen.sleep(0.01) + assert time() < start + 3 start = time() while not (a.status == "closed" and b.status == "closed"): - yield gen.sleep(0.01) + await gen.sleep(0.01) assert time() < start + 1 @@ -1851,7 +1864,7 @@ def inc_slow(x): # need to sleep for at least 0.5 seconds to give the worker a chance to # reconnect (Heartbeat timing) if x in ALREADY_CALCULATED: - time.sleep(0.5) + time.sleep(1) ALREADY_CALCULATED.append(x) return x + 1 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 9cf32c1eec7..df886b9431a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1168,9 +1168,10 @@ def test_statistical_profiling_cycle(c, s, a, b): x = yield a.get_profile(start=time() + 10, stop=time() + 20) assert not x["count"] - x = yield a.get_profile(start=0, stop=time()) + x = yield a.get_profile(start=0, stop=time() + 10) + recent = a.profile_recent["count"] actual = sum(p["count"] for _, p in a.profile_history) + a.profile_recent["count"] - x2 = yield a.get_profile(start=0, stop=time()) + x2 = yield a.get_profile(start=0, stop=time() + 10) assert x["count"] <= actual <= x2["count"] y = yield a.get_profile(start=end - 0.300, stop=time()) diff --git a/distributed/utils.py b/distributed/utils.py index e4884e4d16c..df601291f06 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -81,6 +81,16 @@ def _initialize_mp_context(): preload = ["distributed"] if "pkg_resources" in sys.modules: preload.append("pkg_resources") + + from .versions import required_packages, optional_packages + + for pkg, _ in required_packages + 
optional_packages: + try: + importlib.import_module(pkg) + except ImportError: + pass + else: + preload.append(pkg) ctx.set_forkserver_preload(preload) return ctx diff --git a/distributed/versions.py b/distributed/versions.py index a769c9ab032..e3a5e5d0f17 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -1,5 +1,8 @@ """ utilities for package version introspection """ +from __future__ import print_function, division, absolute_import + +from collections import defaultdict import platform import struct import os @@ -7,8 +10,6 @@ import locale import importlib -from .utils import ignoring - required_packages = [ ("dask", lambda p: p.__version__), @@ -21,10 +22,7 @@ optional_packages = [ ("numpy", lambda p: p.__version__), - ("pandas", lambda p: p.__version__), - ("bokeh", lambda p: p.__version__), ("lz4", lambda p: p.__version__), - ("dask_ml", lambda p: p.__version__), ("blosc", lambda p: p.__version__), ] @@ -38,11 +36,11 @@ def get_versions(packages=None): d = { "host": get_system_info(), - "packages": { - "required": get_package_info(required_packages), - "optional": get_package_info(optional_packages + list(packages)), - }, + "packages": get_package_info( + required_packages + optional_packages + list(packages) + ), } + return d @@ -66,6 +64,8 @@ def get_system_info(): def version_of_package(pkg): """ Try a variety of common ways to get the version of a package """ + from .utils import ignoring + with ignoring(AttributeError): return pkg.__version__ with ignoring(AttributeError): @@ -96,4 +96,42 @@ def get_package_info(pkgs): except Exception: pversions.append((modname, None)) - return pversions + return dict(pversions) + + +def error_message(scheduler, workers, client, client_name="client"): + # we care about the required & optional packages matching + try: + client_versions = client["packages"] + versions = [("scheduler", scheduler["packages"])] + versions.extend((w, d["packages"]) for w, d in sorted(workers.items())) + except KeyError: + return ( + "Version mismatch for dask.distributed. 
" + "The scheduler has version >= 1.28.0 " + "but some other component is less than this" + ) + + mismatched = defaultdict(list) + for name, vers in versions: + for pkg, cv in client_versions.items(): + v = vers.get(pkg, "MISSING") + if cv != v: + mismatched[pkg].append((name, v)) + + if mismatched: + from .utils import asciitable + + errs = [] + for pkg, versions in sorted(mismatched.items()): + rows = [(client_name, client_versions[pkg])] + rows.extend(versions) + errs.append("%s\n%s" % (pkg, asciitable(["", "version"], rows))) + + return "Mismatched versions found\n" "\n" "%s" % ("\n\n".join(errs)) + else: + return "" + + +class VersionMismatchWarning(Warning): + """Indicates version mismatch between nodes""" diff --git a/distributed/worker.py b/distributed/worker.py index 716806a67d6..407c381c83f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -65,6 +65,7 @@ ) from .utils_comm import pack_data, gather_from_workers, retry_operation from .utils_perf import ThrottledGC, enable_gc_diagnosis, disable_gc_diagnosis +from .versions import get_versions logger = logging.getLogger(__name__) @@ -724,6 +725,10 @@ def __repr__(self): ) ) + @property + def logs(self): + return self._deque_handler.deque + @property def worker_address(self): """ For API compatibility with Nanny """ @@ -797,7 +802,6 @@ async def _register_with_scheduler(self): while True: try: _start = time() - types = {k: typename(v) for k, v in self.data.items()} comm = await connect( self.scheduler.address, connection_args=self.connection_args ) @@ -812,7 +816,7 @@ async def _register_with_scheduler(self): nthreads=self.nthreads, name=self.name, nbytes=self.nbytes, - types=types, + types={k: typename(v) for k, v in self.data.items()}, now=time(), resources=self.total_resources, memory_limit=self.memory_limit, @@ -820,13 +824,18 @@ async def _register_with_scheduler(self): services=self.service_ports, nanny=self.nanny, pid=os.getpid(), + versions=get_versions(), metrics=await self.get_metrics(), extra=await self.get_startup_information(), ), serializers=["msgpack"], ) future = comm.read(deserializers=["msgpack"]) + response = await future + if response.get("warning"): + logger.warning(response["warning"]) + _end = time() middle = (_start + _end) / 2 self._update_latency(_end - start) @@ -886,7 +895,13 @@ async def heartbeat(self): self._update_latency(end - start) if response["status"] == "missing": - await self._register_with_scheduler() + for i in range(10): + if self.status != "running": + break + else: + await asyncio.sleep(0.05) + else: + await self._register_with_scheduler() return self.scheduler_delay = response["time"] - middle self.periodic_callbacks["heartbeat"].callback_time = ( From baf5903c6994a286e02c36f9df48c87776b062d7 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 10 Jan 2020 17:04:43 +0000 Subject: [PATCH 0622/1550] Remove locale check that fails on OS X (#3360) --- distributed/versions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/distributed/versions.py b/distributed/versions.py index e3a5e5d0f17..0b97a6f7ac0 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -7,7 +7,6 @@ import struct import os import sys -import locale import importlib @@ -56,7 +55,6 @@ def get_system_info(): ("byteorder", "%s" % sys.byteorder), ("LC_ALL", "%s" % os.environ.get("LC_ALL", "None")), ("LANG", "%s" % os.environ.get("LANG", "None")), - ("LOCALE", "%s.%s" % locale.getlocale()), ] return host From e7ac25a0e9655cf68f81ee5e30e423b6ab375f4b Mon Sep 17 00:00:00 2001 From: James 
Bourbeau Date: Sat, 11 Jan 2020 11:41:07 -0600 Subject: [PATCH 0623/1550] Add --worker-class option to dask-worker CLI (#3364) --- distributed/cli/dask_worker.py | 18 ++++++++-- distributed/cli/tests/test_dask_worker.py | 43 +++++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index a252fe1a232..e76bed2e9bc 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -11,7 +11,7 @@ import dask from dask.utils import ignoring from dask.system import CPU_COUNT -from distributed import Nanny, Worker +from distributed import Nanny from distributed.security import Security from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.comm import get_address_host_port @@ -20,7 +20,7 @@ enable_proctitle_on_children, enable_proctitle_on_current, ) -from distributed.utils import deserialize_for_cli +from distributed.utils import deserialize_for_cli, import_term from toolz import valmap from tornado.ioloop import IOLoop, TimeoutError @@ -190,6 +190,13 @@ show_default=True, help="Random amount by which to stagger lifetime values", ) +@click.option( + "--worker-class", + type=str, + default="dask.distributed.Worker", + show_default=True, + help="Worker class used to instantiate workers from.", +) @click.option( "--lifetime-restart/--no-lifetime-restart", "lifetime_restart", @@ -233,6 +240,7 @@ def main( tls_cert, tls_key, dashboard_address, + worker_class, **kwargs ): g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 @@ -339,13 +347,17 @@ def del_pid_file(): loop = IOLoop.current() + worker_class = import_term(worker_class) + if nanny: + kwargs["worker_class"] = worker_class + if nanny: kwargs.update({"worker_port": worker_port, "listen_address": listen_address}) t = Nanny else: if nanny_port: kwargs["service_ports"] = {"nanny": nanny_port} - t = Worker + t = worker_class if ( not scheduler diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index c509772d113..767613f2a26 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -340,3 +340,46 @@ async def test_integer_names(cleanup): await asyncio.sleep(0.01) [ws] = s.workers.values() assert ws.name == 123 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) +async def test_worker_class(cleanup, tmp_path, nanny): + # Create module with custom worker class + WORKER_CLASS_TEXT = """ +from distributed.worker import Worker + +class MyWorker(Worker): + pass +""" + tmpdir = str(tmp_path) + tmpfile = str(tmp_path / "myworker.py") + with open(tmpfile, "w") as f: + f.write(WORKER_CLASS_TEXT) + + # Put module on PYTHONPATH + env = os.environ.copy() + if "PYTHONPATH" in env: + env["PYTHONPATH"] = tmpdir + ":" + env["PYTHONPATH"] + else: + env["PYTHONPATH"] = tmpdir + + async with Scheduler(port=0) as s: + async with Client(s.address, asynchronous=True) as c: + with popen( + [ + "dask-worker", + s.address, + nanny, + "--worker-class", + "myworker.MyWorker", + ], + env=env, + ) as worker: + await c.wait_for_workers(1) + + def worker_type(dask_worker): + return type(dask_worker).__name__ + + worker_types = await c.run(worker_type) + assert all(name == "MyWorker" for name in worker_types.values()) From 44d24098b503d463ae5abfe1683934627973c14c Mon Sep 17 00:00:00 2001 From: byjott Date: Mon, 13 Jan 2020 17:38:21 +0100 Subject: [PATCH 0624/1550] Fix scheduler 
state in case of worker name collision (#3366) --- distributed/scheduler.py | 24 +++++++++++++----------- distributed/tests/test_scheduler.py | 24 +++++++++++++++++++++++- distributed/worker.py | 2 +- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 66827935b19..d16026b1c24 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1619,6 +1619,16 @@ async def add_worker( if ws is not None: raise ValueError("Worker already exists %s" % ws) + if name in self.aliases: + msg = { + "status": "error", + "message": "name taken, %s" % name, + "time": time(), + } + if comm: + await comm.write(msg) + return + self.workers[address] = ws = WorkerState( address=address, pid=pid, @@ -1632,16 +1642,6 @@ async def add_worker( extra=extra, ) - if name in self.aliases: - msg = { - "status": "error", - "message": "name taken, %s" % name, - "time": time(), - } - if comm: - await comm.write(msg) - return - if "addresses" not in self.host_info[host]: self.host_info[host].update({"addresses": set(), "nthreads": 0}) @@ -2118,10 +2118,12 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): with log_errors(): if self.status == "closed": return + + address = self.coerce_address(address) + if address not in self.workers: return "already-removed" - address = self.coerce_address(address) host = get_address_host(address) ws = self.workers[address] diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a8b07ea499e..c493c41dd21 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -217,6 +217,15 @@ def test_remove_worker_from_scheduler(s, a, b): s.validate_state() +@gen_cluster() +def test_remove_worker_by_name_from_scheduler(s, a, b): + assert a.address in s.stream_comms + assert s.remove_worker(address=a.name) == "OK" + assert a.address not in s.nthreads + assert s.remove_worker(address=a.address) == "already-removed" + s.validate_state() + + @gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "10 ms"}) def test_clear_events_worker_removal(s, a, b): assert a.address in s.events @@ -566,7 +575,7 @@ def test_coerce_address(): "tcp://127.0.0.1:8000", "tcp://[::1]:8000", ) - assert s.coerce_address(u"localhost:8000") in ( + assert s.coerce_address("localhost:8000") in ( "tcp://127.0.0.1:8000", "tcp://[::1]:8000", ) @@ -1966,3 +1975,16 @@ async def test_multiple_listeners(cleanup): log = log.getvalue() assert re.search(r"Scheduler at:\s*tcp://", log) assert re.search(r"Scheduler at:\s*inproc://", log) + + +@gen_cluster(nthreads=[("127.0.0.1", 1)]) +async def test_worker_name_collision(s, a): + # test that a name collision for workers produces the expected respsone + # and leaves the data structures of Scheduler in a good state + # is not updated by the second worker + with pytest.raises(ValueError, match=f"name taken, {a.name!r}"): + await Worker(s.address, name=a.name, loop=s.loop, host="127.0.0.1") + + s.validate_state() + assert set(s.workers) == {a.address} + assert s.aliases == {a.name: a.address} diff --git a/distributed/worker.py b/distributed/worker.py index 407c381c83f..f4a662ce44f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1096,7 +1096,7 @@ async def close( for pc in self.periodic_callbacks.values(): pc.stop() with ignoring(EnvironmentError, gen.TimeoutError): - if report: + if report and self.contact_address is not None: await gen.with_timeout( timedelta(seconds=timeout), 
self.scheduler.unregister( From d5864834af508016fa272db882829a4cbdd74dd0 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 13 Jan 2020 12:48:08 -0600 Subject: [PATCH 0625/1550] Close connection comm on retry (#3365) --- distributed/comm/core.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 11f74a1aba8..42c95e3579e 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -1,11 +1,9 @@ from abc import ABC, abstractmethod, abstractproperty import asyncio -from datetime import timedelta import logging import weakref import dask -from tornado import gen from ..metrics import time from ..utils import parse_timedelta, ignoring @@ -211,11 +209,9 @@ def _raise(error): future = connector.connect( loc, deserialize=deserialize, **(connection_args or {}) ) - with ignoring(gen.TimeoutError): - comm = await gen.with_timeout( - timedelta(seconds=min(deadline - time(), 1)), - future, - quiet_exceptions=EnvironmentError, + with ignoring(asyncio.TimeoutError): + comm = await asyncio.wait_for( + future, timeout=min(deadline - time(), 1) ) break if not comm: From 8472a0371ef9ddab5a49e089ef88d2fd16036448 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 14 Jan 2020 11:23:31 +0000 Subject: [PATCH 0626/1550] Add cluster map dashboard (#3361) * New diagram style dashboard * Minor improvements * Switch greensock for anime, refactor and apply prettier * Move sending to the socket handler * Add websocket scheduler plugin * Rename eventstream endpoint and remove incorrect type hint * Switch to event stream * Revert conftest * Simplify using anime.js timelines * Tidy up code, fox reset bugs and reconnect websocket * Add title and meta * Check status code instead of looking for body content --- distributed/dashboard/scheduler.py | 11 +- .../static/css/individual-cluster-map.css | 54 +++ .../static/individual-cluster-map.html | 27 ++ distributed/dashboard/static/js/anime.min.js | 8 + .../static/js/individual-cluster-map.js | 367 ++++++++++++++++++ .../static/js/reconnecting-websocket.min.js | 8 + .../dashboard/tests/test_scheduler_bokeh.py | 2 +- 7 files changed, 475 insertions(+), 2 deletions(-) create mode 100644 distributed/dashboard/static/css/individual-cluster-map.css create mode 100644 distributed/dashboard/static/individual-cluster-map.html create mode 100644 distributed/dashboard/static/js/anime.min.js create mode 100644 distributed/dashboard/static/js/individual-cluster-map.js create mode 100644 distributed/dashboard/static/js/reconnecting-websocket.min.js diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 6a52063b9f1..2c0520161b3 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1,5 +1,7 @@ from datetime import datetime from functools import partial +import os +import os.path import json import logging @@ -234,11 +236,18 @@ def get(self): class IndividualPlots(RequestHandler): def get(self): bokeh_server = self.server.services["dashboard"] - result = { + individual_bokeh = { uri.strip("/").replace("-", " ").title(): uri for uri in bokeh_server.apps if uri.lstrip("/").startswith("individual-") and not uri.endswith(".json") } + individual_static = { + uri.strip("/").replace(".html", "").replace("-", " ").title(): "/statics/" + + uri + for uri in os.listdir(os.path.join(os.path.dirname(__file__), "static")) + if uri.lstrip("/").startswith("individual-") and uri.endswith(".html") + } + result = {**individual_bokeh, 
**individual_static} self.write(result) diff --git a/distributed/dashboard/static/css/individual-cluster-map.css b/distributed/dashboard/static/css/individual-cluster-map.css new file mode 100644 index 00000000000..e45b735dd31 --- /dev/null +++ b/distributed/dashboard/static/css/individual-cluster-map.css @@ -0,0 +1,54 @@ +#vis { + height: 100%; +} + +.node { + stroke: black; + stroke-width: 3px; + stroke-linejoin: round; + filter: drop-shadow(-1px 1px 2px rgba(0, 0, 0, .4)); +} + +.worker { + fill: #ECB172; +} + +.scheduler { + fill: #c5c1ff +} + +.projectile { + stroke-width: 4; + stroke-linecap: round; + fill: transparent; + stroke-dasharray: 20 1000; + stroke-dashoffset: 0; + animation: shoot 0.5s linear infinite; +} + +.transfer { + stroke-width: 4; + stroke-linecap: round; + fill: transparent; + stroke-dasharray: 10; + stroke-dashoffset: 10; + animation: transfer 1s linear infinite; +} + +@keyframes transfer { + from { + stroke-dashoffset: -100; + } + to { + stroke-dashoffset: 100; + } +} + +@keyframes shoot { + from { + stroke-dashoffset: 0; + } + to { + stroke-dashoffset: 1000; + } +} \ No newline at end of file diff --git a/distributed/dashboard/static/individual-cluster-map.html b/distributed/dashboard/static/individual-cluster-map.html new file mode 100644 index 00000000000..20ab9d53f5b --- /dev/null +++ b/distributed/dashboard/static/individual-cluster-map.html @@ -0,0 +1,27 @@ + + + Dask: Cluster Map + + + + + + + + + + + + + + diff --git a/distributed/dashboard/static/js/anime.min.js b/distributed/dashboard/static/js/anime.min.js new file mode 100644 index 00000000000..99b263aaebc --- /dev/null +++ b/distributed/dashboard/static/js/anime.min.js @@ -0,0 +1,8 @@ +/* + * anime.js v3.1.0 + * (c) 2019 Julian Garnier + * Released under the MIT license + * animejs.com + */ + +!function(n,e){"object"==typeof exports&&"undefined"!=typeof module?module.exports=e():"function"==typeof define&&define.amd?define(e):n.anime=e()}(this,function(){"use strict";var n={update:null,begin:null,loopBegin:null,changeBegin:null,change:null,changeComplete:null,loopComplete:null,complete:null,loop:1,direction:"normal",autoplay:!0,timelineOffset:0},e={duration:1e3,delay:0,endDelay:0,easing:"easeOutElastic(1, .5)",round:0},r=["translateX","translateY","translateZ","rotate","rotateX","rotateY","rotateZ","scale","scaleX","scaleY","scaleZ","skew","skewX","skewY","perspective"],t={CSS:{},springs:{}};function a(n,e,r){return Math.min(Math.max(n,e),r)}function o(n,e){return n.indexOf(e)>-1}function u(n,e){return n.apply(null,e)}var i={arr:function(n){return Array.isArray(n)},obj:function(n){return o(Object.prototype.toString.call(n),"Object")},pth:function(n){return i.obj(n)&&n.hasOwnProperty("totalLength")},svg:function(n){return n instanceof SVGElement},inp:function(n){return n instanceof HTMLInputElement},dom:function(n){return n.nodeType||i.svg(n)},str:function(n){return"string"==typeof n},fnc:function(n){return"function"==typeof n},und:function(n){return void 0===n},hex:function(n){return/(^#[0-9A-F]{6}$)|(^#[0-9A-F]{3}$)/i.test(n)},rgb:function(n){return/^rgb/.test(n)},hsl:function(n){return/^hsl/.test(n)},col:function(n){return i.hex(n)||i.rgb(n)||i.hsl(n)},key:function(r){return!n.hasOwnProperty(r)&&!e.hasOwnProperty(r)&&"targets"!==r&&"keyframes"!==r}};function c(n){var e=/\(([^)]+)\)/.exec(n);return e?e[1].split(",").map(function(n){return parseFloat(n)}):[]}function s(n,e){var 
r=c(n),o=a(i.und(r[0])?1:r[0],.1,100),u=a(i.und(r[1])?100:r[1],.1,100),s=a(i.und(r[2])?10:r[2],.1,100),f=a(i.und(r[3])?0:r[3],.1,100),l=Math.sqrt(u/o),d=s/(2*Math.sqrt(u*o)),p=d<1?l*Math.sqrt(1-d*d):0,h=1,v=d<1?(d*l-f)/p:-f+l;function g(n){var r=e?e*n/1e3:n;return r=d<1?Math.exp(-r*d*l)*(h*Math.cos(p*r)+v*Math.sin(p*r)):(h+v*r)*Math.exp(-r*l),0===n||1===n?n:1-r}return e?g:function(){var e=t.springs[n];if(e)return e;for(var r=0,a=0;;)if(1===g(r+=1/6)){if(++a>=16)break}else a=0;var o=r*(1/6)*1e3;return t.springs[n]=o,o}}function f(n){return void 0===n&&(n=10),function(e){return Math.round(e*n)*(1/n)}}var l,d,p=function(){var n=11,e=1/(n-1);function r(n,e){return 1-3*e+3*n}function t(n,e){return 3*e-6*n}function a(n){return 3*n}function o(n,e,o){return((r(e,o)*n+t(e,o))*n+a(e))*n}function u(n,e,o){return 3*r(e,o)*n*n+2*t(e,o)*n+a(e)}return function(r,t,a,i){if(0<=r&&r<=1&&0<=a&&a<=1){var c=new Float32Array(n);if(r!==t||a!==i)for(var s=0;s=.001?function(n,e,r,t){for(var a=0;a<4;++a){var i=u(e,r,t);if(0===i)return e;e-=(o(e,r,t)-n)/i}return e}(t,l,r,a):0===d?l:function(n,e,r,t,a){for(var u,i,c=0;(u=o(i=e+(r-e)/2,t,a)-n)>0?r=i:e=i,Math.abs(u)>1e-7&&++c<10;);return i}(t,i,i+e,r,a)}}}(),h=(l={linear:function(){return function(n){return n}}},d={Sine:function(){return function(n){return 1-Math.cos(n*Math.PI/2)}},Circ:function(){return function(n){return 1-Math.sqrt(1-n*n)}},Back:function(){return function(n){return n*n*(3*n-2)}},Bounce:function(){return function(n){for(var e,r=4;n<((e=Math.pow(2,--r))-1)/11;);return 1/Math.pow(4,3-r)-7.5625*Math.pow((3*e-2)/22-n,2)}},Elastic:function(n,e){void 0===n&&(n=1),void 0===e&&(e=.5);var r=a(n,1,10),t=a(e,.1,2);return function(n){return 0===n||1===n?n:-r*Math.pow(2,10*(n-1))*Math.sin((n-1-t/(2*Math.PI)*Math.asin(1/r))*(2*Math.PI)/t)}}},["Quad","Cubic","Quart","Quint","Expo"].forEach(function(n,e){d[n]=function(){return function(n){return Math.pow(n,e+2)}}}),Object.keys(d).forEach(function(n){var e=d[n];l["easeIn"+n]=e,l["easeOut"+n]=function(n,r){return function(t){return 1-e(n,r)(1-t)}},l["easeInOut"+n]=function(n,r){return function(t){return t<.5?e(n,r)(2*t)/2:1-e(n,r)(-2*t+2)/2}}}),l);function v(n,e){if(i.fnc(n))return n;var r=n.split("(")[0],t=h[r],a=c(n);switch(r){case"spring":return s(n,e);case"cubicBezier":return u(p,a);case"steps":return u(f,a);default:return u(t,a)}}function g(n){try{return document.querySelectorAll(n)}catch(n){return}}function m(n,e){for(var r=n.length,t=arguments.length>=2?arguments[1]:void 0,a=[],o=0;o1&&(r-=1),r<1/6?n+6*(e-n)*r:r<.5?e:r<2/3?n+(e-n)*(2/3-r)*6:n}if(0==u)e=r=t=i;else{var f=i<.5?i*(1+u):i+u-i*u,l=2*i-f;e=s(l,f,o+1/3),r=s(l,f,o),t=s(l,f,o-1/3)}return"rgba("+255*e+","+255*r+","+255*t+","+c+")"}(n):void 0;var e,r,t,a}function C(n){var e=/[+-]?\d*\.?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?(%|px|pt|em|rem|in|cm|mm|ex|ch|pc|vw|vh|vmin|vmax|deg|rad|turn)?$/.exec(n);if(e)return e[1]}function B(n,e){return i.fnc(n)?n(e.target,e.id,e.total):n}function P(n,e){return n.getAttribute(e)}function I(n,e,r){if(M([r,"deg","rad","turn"],C(e)))return e;var a=t.CSS[e+r];if(!i.und(a))return a;var o=document.createElement(n.tagName),u=n.parentNode&&n.parentNode!==document?n.parentNode:document.body;u.appendChild(o),o.style.position="absolute",o.style.width=100+r;var c=100/o.offsetWidth;u.removeChild(o);var s=c*parseFloat(e);return t.CSS[e+r]=s,s}function T(n,e,r){if(e in n.style){var t=e.replace(/([a-z])([A-Z])/g,"$1-$2").toLowerCase(),a=n.style[e]||getComputedStyle(n).getPropertyValue(t)||"0";return r?I(n,a,r):a}}function D(n,e){return 
i.dom(n)&&!i.inp(n)&&(P(n,e)||i.svg(n)&&n[e])?"attribute":i.dom(n)&&M(r,e)?"transform":i.dom(n)&&"transform"!==e&&T(n,e)?"css":null!=n[e]?"object":void 0}function E(n){if(i.dom(n)){for(var e,r=n.style.transform||"",t=/(\w+)\(([^)]*)\)/g,a=new Map;e=t.exec(r);)a.set(e[1],e[2]);return a}}function F(n,e,r,t){var a,u=o(e,"scale")?1:0+(o(a=e,"translate")||"perspective"===a?"px":o(a,"rotate")||o(a,"skew")?"deg":void 0),i=E(n).get(e)||u;return r&&(r.transforms.list.set(e,i),r.transforms.last=e),t?I(n,i,t):i}function N(n,e,r,t){switch(D(n,e)){case"transform":return F(n,e,t,r);case"css":return T(n,e,r);case"attribute":return P(n,e);default:return n[e]||0}}function A(n,e){var r=/^(\*=|\+=|-=)/.exec(n);if(!r)return n;var t=C(n)||0,a=parseFloat(e),o=parseFloat(n.replace(r[0],""));switch(r[0][0]){case"+":return a+o+t;case"-":return a-o+t;case"*":return a*o+t}}function L(n,e){if(i.col(n))return O(n);if(/\s/g.test(n))return n;var r=C(n),t=r?n.substr(0,n.length-r.length):n;return e?t+e:t}function j(n,e){return Math.sqrt(Math.pow(e.x-n.x,2)+Math.pow(e.y-n.y,2))}function S(n){for(var e,r=n.points,t=0,a=0;a0&&(t+=j(e,o)),e=o}return t}function q(n){if(n.getTotalLength)return n.getTotalLength();switch(n.tagName.toLowerCase()){case"circle":return o=n,2*Math.PI*P(o,"r");case"rect":return 2*P(a=n,"width")+2*P(a,"height");case"line":return j({x:P(t=n,"x1"),y:P(t,"y1")},{x:P(t,"x2"),y:P(t,"y2")});case"polyline":return S(n);case"polygon":return r=(e=n).points,S(e)+j(r.getItem(r.numberOfItems-1),r.getItem(0))}var e,r,t,a,o}function $(n,e){var r=e||{},t=r.el||function(n){for(var e=n.parentNode;i.svg(e)&&i.svg(e.parentNode);)e=e.parentNode;return e}(n),a=t.getBoundingClientRect(),o=P(t,"viewBox"),u=a.width,c=a.height,s=r.viewBox||(o?o.split(" "):[0,0,u,c]);return{el:t,viewBox:s,x:s[0]/1,y:s[1]/1,w:u/s[2],h:c/s[3]}}function X(n,e){function r(r){void 0===r&&(r=0);var t=e+r>=1?e+r:0;return n.el.getPointAtLength(t)}var t=$(n.el,n.svg),a=r(),o=r(-1),u=r(1);switch(n.property){case"x":return(a.x-t.x)*t.w;case"y":return(a.y-t.y)*t.h;case"angle":return 180*Math.atan2(u.y-o.y,u.x-o.x)/Math.PI}}function Y(n,e){var r=/[+-]?\d*\.?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?/g,t=L(i.pth(n)?n.totalLength:n,e)+"";return{original:t,numbers:t.match(r)?t.match(r).map(Number):[0],strings:i.str(n)||e?t.split(r):[]}}function Z(n){return m(n?y(i.arr(n)?n.map(b):b(n)):[],function(n,e,r){return r.indexOf(n)===e})}function Q(n){var e=Z(n);return e.map(function(n,r){return{target:n,id:r,total:e.length,transforms:{list:E(n)}}})}function V(n,e){var r=x(e);if(/^spring/.test(r.easing)&&(r.duration=s(r.easing)),i.arr(n)){var t=n.length;2===t&&!i.obj(n[0])?n={value:n}:i.fnc(e.duration)||(r.duration=e.duration/t)}var a=i.arr(n)?n:[n];return a.map(function(n,r){var t=i.obj(n)&&!i.pth(n)?n:{value:n};return i.und(t.delay)&&(t.delay=r?0:e.delay),i.und(t.endDelay)&&(t.endDelay=r===a.length-1?e.endDelay:0),t}).map(function(n){return k(n,r)})}function z(n,e){var r=[],t=e.keyframes;for(var a in t&&(e=k(function(n){for(var e=m(y(n.map(function(n){return Object.keys(n)})),function(n){return i.key(n)}).reduce(function(n,e){return n.indexOf(e)<0&&n.push(e),n},[]),r={},t=function(t){var a=e[t];r[a]=n.map(function(n){var e={};for(var r in n)i.key(r)?r==a&&(e.value=n[r]):e[r]=n[r];return e})},a=0;a-1&&(_.splice(o,1),r=_.length)}else a.tick(e);t++}n()}else U=cancelAnimationFrame(U)}return n}();function rn(r){void 0===r&&(r={});var t,o=0,u=0,i=0,c=0,s=null;function f(n){var e=window.Promise&&new Promise(function(n){return s=n});return n.finished=e,e}var 
l,d,p,h,v,g,y,b,M=(d=w(n,l=r),p=w(e,l),h=z(p,l),v=Q(l.targets),g=W(v,h),y=J(g,p),b=K,K++,k(d,{id:b,children:[],animatables:v,animations:g,duration:y.duration,delay:y.delay,endDelay:y.endDelay}));f(M);function x(){var n=M.direction;"alternate"!==n&&(M.direction="normal"!==n?"normal":"reverse"),M.reversed=!M.reversed,t.forEach(function(n){return n.reversed=M.reversed})}function O(n){return M.reversed?M.duration-n:n}function C(){o=0,u=O(M.currentTime)*(1/rn.speed)}function B(n,e){e&&e.seek(n-e.timelineOffset)}function P(n){for(var e=0,r=M.animations,t=r.length;e2||(b=Math.round(b*p)/p)),h.push(b)}var k=d.length;if(k){g=d[0];for(var O=0;O0&&(M.began=!0,I("begin")),!M.loopBegan&&M.currentTime>0&&(M.loopBegan=!0,I("loopBegin")),d<=r&&0!==M.currentTime&&P(0),(d>=l&&M.currentTime!==e||!e)&&P(e),d>r&&d=e&&(u=0,M.remaining&&!0!==M.remaining&&M.remaining--,M.remaining?(o=i,I("loopComplete"),M.loopBegan=!1,"alternate"===M.direction&&x()):(M.paused=!0,M.completed||(M.completed=!0,I("loopComplete"),I("complete"),!M.passThrough&&"Promise"in window&&(s(),f(M)))))}return M.reset=function(){var n=M.direction;M.passThrough=!1,M.currentTime=0,M.progress=0,M.paused=!0,M.began=!1,M.loopBegan=!1,M.changeBegan=!1,M.completed=!1,M.changeCompleted=!1,M.reversePlayback=!1,M.reversed="reverse"===n,M.remaining=M.loop,t=M.children;for(var e=c=t.length;e--;)M.children[e].reset();(M.reversed&&!0!==M.loop||"alternate"===n&&1===M.loop)&&M.remaining++,P(M.reversed?M.duration:0)},M.set=function(n,e){return R(n,e),M},M.tick=function(n){i=n,o||(o=i),T((i+(u-o))*rn.speed)},M.seek=function(n){T(O(n))},M.pause=function(){M.paused=!0,C()},M.play=function(){M.paused&&(M.completed&&M.reset(),M.paused=!1,_.push(M),C(),U||en())},M.reverse=function(){x(),C()},M.restart=function(){M.reset(),M.play()},M.reset(),M.autoplay&&M.play(),M}function tn(n,e){for(var r=e.length;r--;)M(n,e[r].animatable.target)&&e.splice(r,1)}return"undefined"!=typeof document&&document.addEventListener("visibilitychange",function(){document.hidden?(_.forEach(function(n){return n.pause()}),nn=_.slice(0),rn.running=_=[]):nn.forEach(function(n){return n.play()})}),rn.version="3.1.0",rn.speed=1,rn.running=_,rn.remove=function(n){for(var e=Z(n),r=_.length;r--;){var t=_[r],a=t.animations,o=t.children;tn(e,a);for(var u=o.length;u--;){var i=o[u],c=i.animations;tn(e,c),c.length||i.children.length||o.splice(u,1)}a.length||o.length||t.pause()}},rn.get=N,rn.set=R,rn.convertPx=I,rn.path=function(n,e){var r=i.str(n)?g(n)[0]:n,t=e||100;return function(n){return{property:n,el:r,svg:$(r),totalLength:q(r)*(t/100)}}},rn.setDashoffset=function(n){var e=q(n);return n.setAttribute("stroke-dasharray",e),e},rn.stagger=function(n,e){void 0===e&&(e={});var r=e.direction||"normal",t=e.easing?v(e.easing):null,a=e.grid,o=e.axis,u=e.from||0,c="first"===u,s="center"===u,f="last"===u,l=i.arr(n),d=l?parseFloat(n[0]):parseFloat(n),p=l?parseFloat(n[1]):0,h=C(l?n[1]:n)||0,g=e.start||0+(l?d:0),m=[],y=0;return function(n,e,i){if(c&&(u=0),s&&(u=(i-1)/2),f&&(u=i-1),!m.length){for(var v=0;v-1&&_.splice(o,1);for(var s=0;s this.dashboard.removeChild(worker) + }); + + // Remove worker from list of workers + let index = this.workers.indexOf(id); + if (index > -1) { + this.workers.splice(index, 1); + } + + // Reposition other workers to fill in the gap + this.update_worker_positions(); + } + + connected() { + anime({ + targets: "#scheduler", + opacity: 1, + duration: 1000 + }); + } + + disconnected() { + while (this.workers.length > 0) { + this.remove_worker(this.workers[0]); + } + anime({ + targets: 
"#scheduler", + opacity: 0.3, + duration: 1000 + }); + } + + update_worker_positions() { + // Calculate a circle around the scheduler and position our workers equally around it + for (var i = 0; i < this.workers.length; i++) { + let θ = (2 * Math.PI * i) / this.workers.length; + let r = 40; + let h = 50; + let k = 50; + let x = h + r * Math.cos(θ); + let y = k + r * Math.sin(θ); + anime({ + targets: "#" + this.workers[i], + r: workerIdleSize, + cx: x + "%", + cy: y + "%", + easing: "easeInOutQuint", + duration: 500 + }); + } + } + + run_task(worker_id, task_name, duration, color) { + let worker = document.getElementById(worker_id); + let scheduler = document.getElementById("scheduler"); + let arc = this.draw_arc(scheduler, worker, color, "projectile"); + + anime + .timeline({ + targets: "#" + worker_id + }) + .add({ + begin: () => this.dashboard.insertBefore(arc, this.schedulerNode) + }) + .add( + { + fill: color, + r: workerBusySize, + begin: () => this.dashboard.removeChild(arc) + }, + 500 + ) + .add({ fill: workerColor, r: workerIdleSize }, "+=" + duration); + } + + run_transfer(start_worker, end_worker, duration) { + start_worker = document.getElementById(start_worker); + end_worker = document.getElementById(end_worker); + duration = Math.max(250, duration / 1000); + let color = "rgba(255, 0, 0, .6)"; + let arc = this.draw_arc(start_worker, end_worker, color, "transfer"); + + anime + .timeline({ + targets: ["#" + start_worker, "#" + end_worker], + duration: 250 + }) + .add({ + fill: color, + r: workerBusySize, + begin: () => this.dashboard.insertBefore(arc, this.schedulerNode) + }) + .add( + { + fill: workerColor, + r: workerIdleSize, + begin: () => this.dashboard.removeChild(arc) + }, + "+=" + duration + ); + } + + run_swap(worker, duration) { + anime + .timeline({ + targets: "#" + worker, + duration: 250 + }) + .add({ + fill: "#D67548", + r: workerBusySize + }) + .add( + { + fill: workerColor, + r: workerIdleSize + }, + "+=" + duration + ); + } + + run_deserialize(worker, duration) { + anime + .timeline({ + targets: "#" + worker, + duration: 250 + }) + .add({ + fill: "gray", + r: workerBusySize + }) + .add( + { + fill: workerColor, + r: workerIdleSize + }, + "+=" + duration + ); + } + + kill_worker(worker) { + anime({ + targets: "#" + worker, + fill: "rgba(0, 0, 0, 1)", + duration: 250 + }); + } + + reset() { + for (var i = 0; i < this.workers.length; i++) { + anime({ + targets: "#" + this.workers[i], + fill: workerColor, + r: workerIdleSize, + duration: 250 + }); + } + } + + calculate_arc(start_x, start_y, end_x, end_y) { + // mid-point of line: + let mpx = (start_x + end_x) * 0.5; + let mpy = (start_y + end_y) * 0.5; + + // angle of perpendicular to line: + let theta = Math.atan2(start_y - end_y, start_x - end_x) - Math.PI / 2; + + // distance of control point from mid-point of line: + let offset = Math.random() * 50; + if (Math.random() >= 0.5) { + offset = -offset; + } + + // location of control point: + let c1x = mpx + offset * Math.cos(theta); + let c1y = mpy + offset * Math.sin(theta); + + // construct the command to draw a quadratic curve + return ( + "M" + + end_x + + " " + + end_y + + " Q " + + c1x + + " " + + c1y + + " " + + start_x + + " " + + start_y + ); + } + + draw_arc(start_element, end_element, color, class_name) { + let curve = this.calculate_arc( + this.getAbsoluteXY(start_element)[0], + this.getAbsoluteXY(start_element)[1], + this.getAbsoluteXY(end_element)[0], + this.getAbsoluteXY(end_element)[1] + ); + + let arc = 
document.createElementNS("http://www.w3.org/2000/svg", "path"); + arc.setAttributeNS(null, "id", class_name); + arc.setAttributeNS(null, "class", class_name); + arc.setAttributeNS(null, "stroke", color); + arc.setAttribute("d", curve); + return arc; + } + + getAbsoluteXY(element) { + var box = element.getBoundingClientRect(); + var x = box.left + box.width / 4; + var y = box.top + box.height / 4; + return [x, y]; + } +} + +function get_websocket_url(endpoint) { + var l = window.location; + return ( + (l.protocol === "https:" ? "wss://" : "ws://") + + l.hostname + + (l.port != 80 && l.port != 443 ? ":" + l.port : "") + + endpoint + ); +} + +function main() { + dashboard = new Dashboard(); + + var ws = new ReconnectingWebSocket(get_websocket_url("/eventstream")); + ws.onopen = function() { + dashboard.connected(); + }; + ws.onmessage = function(event) { + dashboard.handle_event(JSON.parse(event.data)); + }; + ws.onclose = function() { + dashboard.disconnected(); + }; +} + +window.addEventListener("load", main); diff --git a/distributed/dashboard/static/js/reconnecting-websocket.min.js b/distributed/dashboard/static/js/reconnecting-websocket.min.js new file mode 100644 index 00000000000..b2c6e624ce0 --- /dev/null +++ b/distributed/dashboard/static/js/reconnecting-websocket.min.js @@ -0,0 +1,8 @@ +/* + * reconnecting-websocket.min.js v1.0.0 + * Copyright (c) 2010-2012, Joe Walnes + * Released under the MIT license + * https://github.com/joewalnes/reconnecting-websocket + */ + +!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var 
e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a}); diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 39d2ce84156..6594ce2142f 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -679,8 +679,8 @@ def test_https_support(c, s, a, b): url="https://localhost:%d/%s" % (port, suffix), ssl_options=ctx ) response = yield http_client.fetch(req) + assert response.code < 300 body = response.body.decode() - assert "bokeh" in body.lower() assert not re.search("href=./", body) # no absolute links From 7d2ed43c794cf2a00a97f8d0f83b57f028f6be42 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 14 Jan 2020 15:58:30 +0000 Subject: [PATCH 0627/1550] Add client join and leave hooks (#3371) --- distributed/diagnostics/plugin.py | 8 +++++++- distributed/diagnostics/websocket.py | 12 ++++++++++++ distributed/scheduler.py | 12 ++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 8d56679e9a9..4d94f7c8859 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -62,7 +62,13 @@ def add_worker(self, scheduler=None, worker=None, **kwargs): """ Run when a new worker enters the cluster """ def remove_worker(self, scheduler=None, worker=None, **kwargs): - """ Run when a worker leaves the cluster""" + """ Run when a worker leaves the cluster """ + + def add_client(self, scheduler=None, client=None, **kwargs): + """ Run when a new client connects """ + + def remove_client(self, scheduler=None, client=None, **kwargs): + """ Run when a client disconnects """ class WorkerPlugin(object): diff --git a/distributed/diagnostics/websocket.py b/distributed/diagnostics/websocket.py index 6682dd6a739..641730faf54 100644 --- a/distributed/diagnostics/websocket.py +++ b/distributed/diagnostics/websocket.py @@ -20,6 +20,18 @@ def remove_worker(self, scheduler=None, worker=None, **kwargs): """ Run when a worker leaves the cluster""" self.socket.send("remove_worker", {"worker": worker}) + def add_client(self, scheduler=None, client=None, **kwargs): + """ Run when a new client connects """ + self.socket.send("add_client", {"client": client}) + + def remove_client(self, scheduler=None, client=None, **kwargs): + """ Run when a client disconnects """ + self.socket.send("remove_client", {"client": client}) + + def 
update_graph(self, scheduler, client=None, **kwargs): + """ Run when a new graph / tasks enter the scheduler """ + self.socket.send("update_graph", {"client": client}) + def transition(self, key, start, finish, *args, **kwargs): """ Run whenever a task changes state diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d16026b1c24..478eb413422 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2467,6 +2467,12 @@ async def add_client(self, comm, client=None, versions=None): self.log_event(["all", client], {"action": "add-client", "client": client}) self.clients[client] = ClientState(client, versions=versions) + for plugin in self.plugins[:]: + try: + plugin.add_client(scheduler=self, client=client) + except Exception as e: + logger.exception(e) + try: bcomm = BatchedSend(interval="2ms", loop=self.loop) bcomm.start(comm) @@ -2514,6 +2520,12 @@ def remove_client(self, client=None): ) del self.clients[client] + for plugin in self.plugins[:]: + try: + plugin.remove_client(scheduler=self, client=client) + except Exception as e: + logger.exception(e) + def remove_client_from_events(): # If the client isn't registered anymore after the delay, remove from events if client not in self.clients and client in self.events: From 6a76ca7ac6db9f6243869057b6805f14c23e3811 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 16 Jan 2020 08:39:36 +0000 Subject: [PATCH 0628/1550] Move Windows CI to GitHub Actions (#3373) * Move Windows CI to GitHub Actions * Switch to CMD and conda * Remove indirection and add conda environment file * Fix bad indentation * Remove Python version * Name steps, remove unnecessary conda activations, remove JUnit * Align pytest options with Travis config * Run single test for faster iterating * Hmm pytest is fine. 
Trying two runners to improve performance * Remove bad -n flag * Add multicore support as GitHub Actions has dual-core runners * Removing multicore again as that had some unexpected results * Try bash instead of powershell --- .github/workflows/ci-windows.yaml | 33 +++++++ appveyor.yml | 37 -------- continuous_integration/build.cmd | 6 -- continuous_integration/environment.yml | 36 ++++++++ continuous_integration/run_tests.cmd | 9 -- continuous_integration/run_with_env.cmd | 90 ------------------- .../setup_conda_environment.cmd | 61 ------------- 7 files changed, 69 insertions(+), 203 deletions(-) create mode 100644 .github/workflows/ci-windows.yaml delete mode 100644 appveyor.yml delete mode 100644 continuous_integration/build.cmd create mode 100644 continuous_integration/environment.yml delete mode 100644 continuous_integration/run_tests.cmd delete mode 100644 continuous_integration/run_with_env.cmd delete mode 100644 continuous_integration/setup_conda_environment.cmd diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml new file mode 100644 index 00000000000..ecbf29f5a3d --- /dev/null +++ b/.github/workflows/ci-windows.yaml @@ -0,0 +1,33 @@ +name: Windows CI + +on: [push, pull_request] + +jobs: + build: + runs-on: windows-latest + strategy: + matrix: + python-version: ["3.6", "3.7"] + + steps: + - name: Checkout source + uses: actions/checkout@v1 + + - name: Setup Conda Environment + uses: goanpeca/setup-miniconda@v1 + with: + miniconda-version: "latest" + python-version: ${{ matrix.python-version }} + environment-file: continuous_integration/environment.yml + activate-environment: testenv + auto-activate-base: false + + - name: Install distributed from source + shell: bash -l {0} + run: pip install -q --no-deps -e . + + - name: Run tests + shell: bash -l {0} + env: + PYTHONFAULTHANDLER: 1 + run: py.test -m "not avoid_travis" distributed --verbose -r s --timeout-method=thread --timeout=300 --durations=20 diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index e32c48f105a..00000000000 --- a/appveyor.yml +++ /dev/null @@ -1,37 +0,0 @@ -# Environment loosely based on https://github.com/conda/conda/blob/master/appveyor.yml - -environment: - - global: - # SDK v7.0 MSVC Express 2008's SetEnv.cmd script will fail if the - # /E:ON and /V:ON options are not enabled in the batch script intepreter - # See: http://stackoverflow.com/a/13751649/163740 - CMD_IN_ENV: "cmd /E:ON /V:ON /C .\\continuous_integration\\run_with_env.cmd" - JUNIT_OUT: junit-results.xml - - matrix: - # Since appveyor is quite slow, we only use a single configuration - - PYTHON: "3.6" - ARCH: "64" - CONDA_ENV: testenv - -init: - # Use AppVeyor's provided Miniconda: https://www.appveyor.com/docs/installed-software#python - - if "%ARCH%" == "64" set MINICONDA=C:\Miniconda36-x64 - - if "%ARCH%" == "32" set MINICONDA=C:\Miniconda36 - - set PATH=%MINICONDA%;%MINICONDA%/Scripts;%MINICONDA%/Library/bin;%PATH% - -install: - - continuous_integration\\setup_conda_environment.cmd - -build_script: - - continuous_integration\\build.cmd - -test_script: - # %CMD_IN_ENV% is needed for distutils/setuptools-based tests - # on certain build configurations. 
- - "%CMD_IN_ENV% continuous_integration\\run_tests.cmd" - -on_finish: - - ps: $wc = New-Object 'System.Net.WebClient' - - ps: $wc.UploadFile("https://ci.appveyor.com/api/testresults/junit/$($env:APPVEYOR_JOB_ID)", (Resolve-Path "$($env:JUNIT_OUT)")) diff --git a/continuous_integration/build.cmd b/continuous_integration/build.cmd deleted file mode 100644 index c29c3eafe82..00000000000 --- a/continuous_integration/build.cmd +++ /dev/null @@ -1,6 +0,0 @@ -call activate %CONDA_ENV% - -@echo on - -@rem Install Distributed -%PIP_INSTALL% --no-deps -e . diff --git a/continuous_integration/environment.yml b/continuous_integration/environment.yml new file mode 100644 index 00000000000..f6651254af2 --- /dev/null +++ b/continuous_integration/environment.yml @@ -0,0 +1,36 @@ +name: testenv +channels: + - defaults + - conda-forge +dependencies: + - zstandard + - bokeh + - click + - cloudpickle + - dask + - dill + - lz4 + - ipykernel + - ipywidgets + - joblib + - jupyter_client + - msgpack-python + - prometheus_client + - psutil + - pytest + - requests + - toolz + - tblib + - tornado=5 + - zict + - fsspec + - pip + - pip: + - pytest-repeat + - pytest-timeout + - pytest-faulthandler + - sortedcollections + - pytest-asyncio + - git+https://github.com/dask/dask + - git+https://github.com/joblib/joblib.git + - git+https://github.com/dask/zict diff --git a/continuous_integration/run_tests.cmd b/continuous_integration/run_tests.cmd deleted file mode 100644 index f5ba5680dc2..00000000000 --- a/continuous_integration/run_tests.cmd +++ /dev/null @@ -1,9 +0,0 @@ -call activate %CONDA_ENV% - -@echo on - -set PYTHONFAULTHANDLER=1 - -set PYTEST=py.test --tb=native --timeout=120 -r s - -%PYTEST% -v -m "not avoid_travis" --junit-xml="%JUNIT_OUT%" distributed diff --git a/continuous_integration/run_with_env.cmd b/continuous_integration/run_with_env.cmd deleted file mode 100644 index 3a56e3e840e..00000000000 --- a/continuous_integration/run_with_env.cmd +++ /dev/null @@ -1,90 +0,0 @@ -:: From https://github.com/ogrisel/python-appveyor-demo -:: -:: To build extensions for 64 bit Python 3, we need to configure environment -:: variables to use the MSVC 2010 C++ compilers from GRMSDKX_EN_DVD.iso of: -:: MS Windows SDK for Windows 7 and .NET Framework 4 (SDK v7.1) -:: -:: To build extensions for 64 bit Python 2, we need to configure environment -:: variables to use the MSVC 2008 C++ compilers from GRMSDKX_EN_DVD.iso of: -:: MS Windows SDK for Windows 7 and .NET Framework 3.5 (SDK v7.0) -:: -:: 32 bit builds, and 64-bit builds for 3.5 and beyond, do not require specific -:: environment configurations. -:: -:: Note: this script needs to be run with the /E:ON and /V:ON flags for the -:: cmd interpreter, at least for (SDK v7.0) -:: -:: More details at: -:: https://github.com/cython/cython/wiki/64BitCythonExtensionsOnWindows -:: http://stackoverflow.com/a/13751649/163740 -:: -:: Author: Olivier Grisel -:: License: CC0 1.0 Universal: http://creativecommons.org/publicdomain/zero/1.0/ -:: -:: Notes about batch files for Python people: -:: -:: Quotes in values are literally part of the values: -:: SET FOO="bar" -:: FOO is now five characters long: " b a r " -:: If you don't want quotes, don't include them on the right-hand side. -:: -:: The CALL lines at the end of this file look redundant, but if you move them -:: outside of the IF clauses, they do not run properly in the SET_SDK_64==Y -:: case, I don't know why. 
-@ECHO OFF - -SET COMMAND_TO_RUN=%* -SET WIN_SDK_ROOT=C:\Program Files\Microsoft SDKs\Windows -SET WIN_WDK=c:\Program Files (x86)\Windows Kits\10\Include\wdf - -:: Extract the major and minor versions, and allow for the minor version to be -:: more than 9. This requires the version number to have two dots in it. -SET MAJOR_PYTHON_VERSION=%PYTHON:~0,1% -IF "%PYTHON:~3,1%" == "." ( - SET MINOR_PYTHON_VERSION=%PYTHON:~2,1% -) ELSE ( - SET MINOR_PYTHON_VERSION=%PYTHON:~2,2% -) - -:: Based on the Python version, determine what SDK version to use, and whether -:: to set the SDK for 64-bit. -IF %MAJOR_PYTHON_VERSION% == 2 ( - SET WINDOWS_SDK_VERSION="v7.0" - SET SET_SDK_64=Y -) ELSE ( - IF %MAJOR_PYTHON_VERSION% == 3 ( - SET WINDOWS_SDK_VERSION="v7.1" - IF %MINOR_PYTHON_VERSION% LEQ 4 ( - SET SET_SDK_64=Y - ) ELSE ( - SET SET_SDK_64=N - IF EXIST "%WIN_WDK%" ( - :: See: https://connect.microsoft.com/VisualStudio/feedback/details/1610302/ - REN "%WIN_WDK%" 0wdf - ) - ) - ) ELSE ( - ECHO Unsupported Python version: "%MAJOR_PYTHON_VERSION%" - EXIT 1 - ) -) - -IF %ARCH% == 64 ( - IF %SET_SDK_64% == Y ( - ECHO Configuring Windows SDK %WINDOWS_SDK_VERSION% for Python %MAJOR_PYTHON_VERSION% on a 64 bit architecture - SET DISTUTILS_USE_SDK=1 - SET MSSdk=1 - "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Setup\WindowsSdkVer.exe" -q -version:%WINDOWS_SDK_VERSION% - "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Bin\SetEnv.cmd" /x64 /release - ECHO Executing: %COMMAND_TO_RUN% - call %COMMAND_TO_RUN% || EXIT 1 - ) ELSE ( - ECHO Using default MSVC build environment for 64 bit architecture - ECHO Executing: %COMMAND_TO_RUN% - call %COMMAND_TO_RUN% || EXIT 1 - ) -) ELSE ( - ECHO Using default MSVC build environment for 32 bit architecture - ECHO Executing: %COMMAND_TO_RUN% - call %COMMAND_TO_RUN% || EXIT 1 -) diff --git a/continuous_integration/setup_conda_environment.cmd b/continuous_integration/setup_conda_environment.cmd deleted file mode 100644 index 5efc7358dbe..00000000000 --- a/continuous_integration/setup_conda_environment.cmd +++ /dev/null @@ -1,61 +0,0 @@ -@rem The cmd /C hack circumvents a regression where conda installs a conda.bat -@rem script in non-root environments. 
-set CONDA=cmd /C conda -set CONDA_INSTALL=%CONDA% install -q -y -set PIP_INSTALL=pip install -q - -@echo on - -@rem Deactivate any environment -call deactivate -@rem Update conda -%CONDA% update -q -y conda -@rem Display root environment (for debugging) -%CONDA% list -@rem Clean up any left-over from a previous build -%CONDA% remove --all -q -y -n %CONDA_ENV% - -@rem Create test environment -@rem (note: no cytoolz as it seems to prevent faulthandler tracebacks on crash) -%CONDA% create -n %CONDA_ENV% -q -y ^ - zstandard ^ - bokeh ^ - click ^ - cloudpickle ^ - dask ^ - dill ^ - lz4 ^ - ipykernel ^ - ipywidgets ^ - joblib ^ - jupyter_client ^ - msgpack-python ^ - prometheus_client ^ - psutil ^ - pytest ^ - python=%PYTHON% ^ - requests ^ - toolz ^ - tblib ^ - tornado=5 ^ - zict ^ - fsspec ^ - -c conda-forge - -call activate %CONDA_ENV% - -%CONDA% uninstall -q -y --force dask joblib zict -%PIP_INSTALL% pip --upgrade -%PIP_INSTALL% git+https://github.com/dask/dask --upgrade -%PIP_INSTALL% git+https://github.com/joblib/joblib.git --upgrade -%PIP_INSTALL% git+https://github.com/dask/zict --upgrade - -%PIP_INSTALL% "pytest>=4" pytest-repeat pytest-timeout pytest-faulthandler sortedcollections pytest-asyncio - -@rem Display final environment (for reproducing) -%CONDA% list -%CONDA% list --explicit -where python -where pip -pip list -python -m site From 2999ae08d3f2ce975f05ca66bf7b84604b9d281c Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 16 Jan 2020 18:08:31 -0600 Subject: [PATCH 0629/1550] bump version to 2.9.2 --- docs/source/changelog.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 3e4a05fcc79..6410c2aa0b1 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,25 @@ Changelog ========= +2.9.2 - 2020-01-16 +------------------ + +- Move Windows CI to GitHub Actions (:pr:`3373`) `Jacob Tomlinson`_ +- Add client join and leave hooks (:pr:`3371`) `Jacob Tomlinson`_ +- Add cluster map dashboard (:pr:`3361`) `Jacob Tomlinson`_ +- Close connection comm on retry (:pr:`3365`) `James Bourbeau`_ +- Fix scheduler state in case of worker name collision (:pr:`3366`) `byjott`_ +- Add ``--worker-class`` option to ``dask-worker`` CLI (:pr:`3364`) `James Bourbeau`_ +- Remove ``locale`` check that fails on OS X (:pr:`3360`) `Jacob Tomlinson`_ +- Rework version checking (:pr:`2627`) `Matthew Rocklin`_ +- Add websocket scheduler plugin (:pr:`3335`) `Jacob Tomlinson`_ +- Return task in ``dask-worker`` ``on_signal`` function (:pr:`3354`) `James Bourbeau`_ +- Fix failures on mixed integer/string worker names (:pr:`3352`) `Benedikt Reinartz`_ +- Avoid calling ``nbytes`` multiple times when sending data (:pr:`3349`) `Markus Mohrhard`_ +- Avoid setting event loop policy if within IPython kernel and no running event loop (:pr:`3336`) `Mana Borwornpadungkitti`_ +- Relax intermittent failing ``test_profile_server`` (:pr:`3346`) `Matthew Rocklin`_ + + 2.9.1 - 2019-12-27 ------------------ @@ -1471,3 +1490,6 @@ significantly without many new features. .. _`fjetter`: https://github.com/fjetter .. _`Patrick Sodré`: https://github.com/sodre .. _`Stephan Erb`: https://github.com/StephanErb +.. _`Benedikt Reinartz`: https://github.com/filmor +.. _`Markus Mohrhard`: https://github.com/mmohrhard +.. 
_`Mana Borwornpadungkitti`: https://github.com/potpath From d88118d25e9751c54c07dd90efa2fe1008a7d6b8 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 17 Jan 2020 14:03:08 +0000 Subject: [PATCH 0630/1550] Get JavaScript document location instead of window and handle proxied url (#3382) --- distributed/dashboard/static/js/individual-cluster-map.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/dashboard/static/js/individual-cluster-map.js b/distributed/dashboard/static/js/individual-cluster-map.js index 20aed45b7b1..57b5d210a65 100644 --- a/distributed/dashboard/static/js/individual-cluster-map.js +++ b/distributed/dashboard/static/js/individual-cluster-map.js @@ -340,12 +340,12 @@ class Dashboard { } function get_websocket_url(endpoint) { - var l = window.location; + var l = document.location; return ( (l.protocol === "https:" ? "wss://" : "ws://") + l.hostname + (l.port != 80 && l.port != 443 ? ":" + l.port : "") + - endpoint + l.pathname.replace("/statics/individual-cluster-map.html", endpoint) ); } From db24547945ffd8ee126f54dd196e486edb3ea66f Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 17 Jan 2020 10:13:27 -0600 Subject: [PATCH 0631/1550] Fix get_running_loop import (#3383) --- distributed/compatibility.py | 5 +++++ distributed/utils.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/distributed/compatibility.py b/distributed/compatibility.py index 186e66e485c..33e50e429b8 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -7,3 +7,8 @@ PYPY = platform.python_implementation().lower() == "pypy" WINDOWS = sys.platform.startswith("win") + +if sys.version_info[:2] >= (3, 7): + from asyncio import get_running_loop +else: + from asyncio import _get_running_loop as get_running_loop # noqa: F401 diff --git a/distributed/utils.py b/distributed/utils.py index df601291f06..15824262ab7 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -56,7 +56,7 @@ except ImportError: PollIOLoop = None # dropped in tornado 6.0 -from .compatibility import PYPY, WINDOWS +from .compatibility import PYPY, WINDOWS, get_running_loop from .metrics import time @@ -1204,7 +1204,7 @@ def reset_logger_locks(): if is_kernel(): try: - asyncio.get_running_loop() + get_running_loop() except RuntimeError: is_kernel_and_no_running_loop = True From 7b3c6e9427082b8f2a87c1e19e055da6f2fb69ff Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 17 Jan 2020 11:11:15 -0600 Subject: [PATCH 0632/1550] Raise RuntimeError if no running loop (#3385) --- distributed/compatibility.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/distributed/compatibility.py b/distributed/compatibility.py index 33e50e429b8..0dca141e0e9 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -11,4 +11,11 @@ if sys.version_info[:2] >= (3, 7): from asyncio import get_running_loop else: - from asyncio import _get_running_loop as get_running_loop # noqa: F401 + + def get_running_loop(): + from asyncio import _get_running_loop + + loop = _get_running_loop() + if loop is None: + raise RuntimeError("no running event loop") + return loop From 26007f244911307f2dcc495357f3398e494d1237 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 17 Jan 2020 11:14:15 -0600 Subject: [PATCH 0633/1550] bump version to 2.9.3 --- docs/source/changelog.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 6410c2aa0b1..1bb19e70330 100644 --- 
a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,14 @@ Changelog ========= +2.9.3 - 2020-01-17 +------------------ + +- Raise ``RuntimeError`` if no running loop (:pr:`3385`) `James Bourbeau`_ +- Fix ``get_running_loop`` import (:pr:`3383`) `James Bourbeau`_ +- Get JavaScript document location instead of window and handle proxied url (:pr:`3382`) `Jacob Tomlinson`_ + + 2.9.2 - 2020-01-16 ------------------ From b464ae3e3abdea6b9577ed3dca5afe9f42efde60 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 17 Jan 2020 15:25:23 -0600 Subject: [PATCH 0634/1550] Replace gen.with_timeout with asyncio.wait_for (#3372) --- distributed/client.py | 23 +++----- distributed/deploy/cluster.py | 4 +- distributed/deploy/tests/test_local.py | 2 +- distributed/lock.py | 7 +-- distributed/nanny.py | 20 +++---- distributed/node.py | 2 +- distributed/process.py | 4 +- distributed/scheduler.py | 2 +- distributed/tests/test_batched.py | 10 ++-- distributed/tests/test_client.py | 8 +-- distributed/tests/test_client_executor.py | 67 +++++++++++++---------- distributed/tests/test_failed_workers.py | 4 +- distributed/tests/test_nanny.py | 2 +- distributed/tests/test_publish.py | 4 +- distributed/tests/test_queues.py | 6 +- distributed/tests/test_scheduler.py | 9 ++- distributed/tests/test_stress.py | 7 +-- distributed/tests/test_worker.py | 9 ++- distributed/utils.py | 3 +- distributed/utils_test.py | 29 ++++++---- distributed/worker.py | 6 +- 21 files changed, 113 insertions(+), 115 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 451a6628e73..0a4080cd05d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5,7 +5,6 @@ from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import contextmanager import copy -from datetime import timedelta import errno from functools import partial import html @@ -761,7 +760,7 @@ def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): ): future = func(*args, **kwargs) if callback_timeout is not None: - future = gen.with_timeout(timedelta(seconds=callback_timeout), future) + future = asyncio.wait_for(future, callback_timeout) return future else: return sync( @@ -1043,9 +1042,7 @@ async def _ensure_connected(self, timeout=None): ) comm.name = "Client->Scheduler" if timeout is not None: - await gen.with_timeout( - timedelta(seconds=timeout), self._update_scheduler_info() - ) + await asyncio.wait_for(self._update_scheduler_info(), timeout) else: await self._update_scheduler_info() await comm.write( @@ -1064,7 +1061,7 @@ async def _ensure_connected(self, timeout=None): finally: self._connecting_to_scheduler = False if timeout is not None: - msg = await gen.with_timeout(timedelta(seconds=timeout), comm.read()) + msg = await asyncio.wait_for(comm.read(), timeout) else: msg = await comm.read() assert len(msg) == 1 @@ -1268,11 +1265,9 @@ async def _close(self, fast=False): # Give the scheduler 'stream-closed' message 100ms to come through # This makes the shutdown slightly smoother and quieter - with ignoring(AttributeError, gen.TimeoutError): - await gen.with_timeout( - timedelta(milliseconds=100), - self._handle_scheduler_coroutine, - quiet_exceptions=(CancelledError,), + with ignoring(AttributeError, CancelledError, asyncio.TimeoutError): + await asyncio.wait_for( + asyncio.shield(self._handle_scheduler_coroutine), 0.1 ) if ( @@ -1308,7 +1303,7 @@ async def _close(self, fast=False): if not fast: with ignoring(TimeoutError): - await gen.with_timeout(timedelta(seconds=2), 
list(coroutines)) + await asyncio.wait_for(asyncio.gather(*coroutines), 2) with ignoring(AttributeError): await self.scheduler.close_rpc() @@ -1344,7 +1339,7 @@ def close(self, timeout=no_default): if self.asynchronous: future = self._close() if timeout: - future = gen.with_timeout(timedelta(seconds=timeout), future) + future = asyncio.wait_for(future, timeout) return future if self._start_arg is None: @@ -4077,7 +4072,7 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): future = wait_for({f._state.wait() for f in fs}) if timeout is not None: - future = gen.with_timeout(timedelta(seconds=timeout), future) + future = asyncio.wait_for(future, timeout) await future done, not_done = ( diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 2631fb502df..1b304b0a53e 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -1,10 +1,8 @@ import asyncio -from datetime import timedelta import logging import threading from dask.utils import format_bytes -from tornado import gen from .adaptive import Adaptive @@ -156,7 +154,7 @@ def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): if asynchronous: future = func(*args, **kwargs) if callback_timeout is not None: - future = gen.with_timeout(timedelta(seconds=callback_timeout), future) + future = asyncio.wait_for(future, callback_timeout) return future else: return sync(self.loop, func, *args, **kwargs) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 370423771a5..4687d0c476f 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -523,7 +523,7 @@ def test_memory_nanny(loop, n_workers): def test_death_timeout_raises(loop): - with pytest.raises(gen.TimeoutError): + with pytest.raises(asyncio.TimeoutError): with LocalCluster( scheduler_port=0, silence_logs=False, diff --git a/distributed/lock.py b/distributed/lock.py index ed3eb4313f2..c581bb5d552 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -1,9 +1,8 @@ from collections import defaultdict, deque -from datetime import timedelta import logging import uuid +import asyncio -from tornado import gen import tornado.locks from .client import _get_global_client @@ -45,10 +44,10 @@ async def acquire(self, stream=None, name=None, id=None, timeout=None): self.events[name].append(event) future = event.wait() if timeout is not None: - future = gen.with_timeout(timedelta(seconds=timeout), future) + future = asyncio.wait_for(future, timeout) try: await future - except gen.TimeoutError: + except asyncio.TimeoutError: result = False break else: diff --git a/distributed/nanny.py b/distributed/nanny.py index 7cf3c2cbbaf..19d48328f47 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -1,5 +1,4 @@ import asyncio -from datetime import timedelta import logging from multiprocessing.queues import Empty import os @@ -31,6 +30,7 @@ json_load_robust, PeriodicCallback, parse_timedelta, + ignoring, ) from .worker import run, parse_memory_limit, Worker @@ -219,14 +219,10 @@ async def _unregister(self, timeout=10): EnvironmentError, RPCClosed, ) - try: - await gen.with_timeout( - timedelta(seconds=timeout), - self.scheduler.unregister(address=self.worker_address), - quiet_exceptions=allowed_errors, + with ignoring(allowed_errors): + await asyncio.wait_for( + self.scheduler.unregister(address=self.worker_address), timeout ) - except allowed_errors: - pass @property def worker_address(self): @@ -318,8 +314,8 @@ async def 
instantiate(self, comm=None): self.auto_restart = True if self.death_timeout: try: - result = await gen.with_timeout( - timedelta(seconds=self.death_timeout), self.process.start() + result = await asyncio.wait_for( + self.process.start(), self.death_timeout ) except gen.TimeoutError: await self.close(timeout=self.death_timeout) @@ -343,8 +339,8 @@ async def _(): await self.instantiate() try: - await gen.with_timeout(timedelta(seconds=timeout), _()) - except gen.TimeoutError: + await asyncio.wait_for(_(), timeout) + except asyncio.TimeoutError: logger.error("Restart timed out, returning before finished") return "timed out" else: diff --git a/distributed/node.py b/distributed/node.py index 2d7447b1a06..6cf30f997fe 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -173,7 +173,7 @@ async def wait_for(future, timeout=None): await asyncio.wait_for(future, timeout=timeout) except Exception: await self.close(timeout=1) - raise gen.TimeoutError( + raise asyncio.TimeoutError( "{} failed to start in {} seconds".format( type(self).__name__, timeout ) diff --git a/distributed/process.py b/distributed/process.py index 38527ecd9ab..4ad86e2bb08 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -1,11 +1,11 @@ import atexit -from datetime import timedelta import logging import os from queue import Queue as PyQueue import re import threading import weakref +import asyncio import dask from .utils import mp_context @@ -282,7 +282,7 @@ def join(self, timeout=None): yield self._exit_future else: try: - yield gen.with_timeout(timedelta(seconds=timeout), self._exit_future) + yield asyncio.wait_for(self._exit_future, timeout) except gen.TimeoutError: pass diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 478eb413422..152df2d705f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2878,7 +2878,7 @@ async def restart(self, client=None, timeout=3): for nanny in nannies ] ) - resps = await gen.with_timeout(timedelta(seconds=timeout), resps) + resps = await asyncio.wait_for(resps, timeout) if not all(resp == "OK" for resp in resps): logger.error( "Not all workers responded positively: %s", resps, exc_info=True diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index a961157f948..3174f3f3022 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -1,10 +1,8 @@ import asyncio -from datetime import timedelta import random import pytest from toolz import assoc -from tornado import gen from distributed.batched import BatchedSend from distributed.core import listen, connect, CommClosedError @@ -170,7 +168,7 @@ async def send(): async def recv(): while True: - result = await gen.with_timeout(timedelta(seconds=1), comm.read()) + result = await asyncio.wait_for(comm.read(), 1) L.extend(result) if result[-1] == 9999: break @@ -205,7 +203,7 @@ async def run_traffic_jam(nsends, nbytes): # If this times out then I think it's a backpressure issue # Somehow we're able to flood the socket so that the receiving end # loses some of our messages - L = await gen.with_timeout(timedelta(seconds=5), comm.read()) + L = await asyncio.wait_for(comm.read(), 5) count += 1 results.extend(r["i"] for r in L) @@ -254,5 +252,5 @@ async def test_serializers(): msg = await comm.read() assert list(msg) == [{"x": 123}, {"x": "hello"}] - with pytest.raises(gen.TimeoutError): - msg = await gen.with_timeout(timedelta(milliseconds=100), comm.read()) + with pytest.raises(asyncio.TimeoutError): + msg = await 
asyncio.wait_for(comm.read(), 0.1) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 9450f08fd75..da23b85df0e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -496,7 +496,7 @@ def test_thread(c): assert x.result() == 2 x = c.submit(slowinc, 1, delay=0.3) - with pytest.raises(gen.TimeoutError): + with pytest.raises((gen.TimeoutError, asyncio.TimeoutError)): x.result(timeout=0.01) assert x.result() == 2 @@ -681,7 +681,7 @@ def test_wait_first_completed(c, s, a, b): @gen_cluster(client=True, timeout=2) def test_wait_timeout(c, s, a, b): future = c.submit(sleep, 0.3) - with pytest.raises(gen.TimeoutError): + with pytest.raises(asyncio.TimeoutError): yield wait(future, timeout=0.01) @@ -695,7 +695,7 @@ def test_wait_sync(c): assert x.status == y.status == "finished" future = c.submit(sleep, 0.3) - with pytest.raises(gen.TimeoutError): + with pytest.raises(asyncio.TimeoutError): wait(future, timeout=0.01) @@ -5279,7 +5279,7 @@ def test_client_active_bad_port(): http_server.listen(8080) with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): c = Client("127.0.0.1:8080", asynchronous=True) - with pytest.raises((TimeoutError, IOError)): + with pytest.raises((asyncio.TimeoutError, IOError)): yield c yield c._close(fast=True) http_server.stop() diff --git a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index 7d08a63c5b2..40639998852 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -14,7 +14,16 @@ import pytest from toolz import take -from distributed.utils_test import slowinc, slowadd, slowdec, inc, throws, varying +from distributed import Client +from distributed.utils_test import ( + slowinc, + slowadd, + slowdec, + inc, + throws, + varying, + cluster, +) from distributed.utils_test import client, cluster_fixture, loop, s, a, b # noqa: F401 @@ -218,30 +227,32 @@ def test_retries(client): exc_info.match("one") -def test_shutdown(client): - # shutdown(wait=True) waits for pending tasks to finish - e = client.get_executor() - fut = e.submit(time.sleep, 1.0) - t1 = time.time() - e.shutdown() - dt = time.time() - t1 - assert 0.5 <= dt <= 2.0 - time.sleep(0.1) # wait for future outcome to propagate - assert fut.done() - fut.result() # doesn't raise - - with pytest.raises(RuntimeError): - e.submit(time.sleep, 1.0) - - # shutdown(wait=False) cancels pending tasks - e = client.get_executor() - fut = e.submit(time.sleep, 2.0) - t1 = time.time() - e.shutdown(wait=False) - dt = time.time() - t1 - assert dt < 0.5 - time.sleep(0.1) # wait for future outcome to propagate - assert fut.cancelled() - - with pytest.raises(RuntimeError): - e.submit(time.sleep, 1.0) +def test_shutdown(loop): + with cluster(disconnect_timeout=10) as (s, [a, b]): + with Client(s["address"], loop=loop) as client: + # shutdown(wait=True) waits for pending tasks to finish + e = client.get_executor() + fut = e.submit(time.sleep, 1.0) + t1 = time.time() + e.shutdown() + dt = time.time() - t1 + assert 0.5 <= dt <= 2.0 + time.sleep(0.1) # wait for future outcome to propagate + assert fut.done() + fut.result() # doesn't raise + + with pytest.raises(RuntimeError): + e.submit(time.sleep, 1.0) + + # shutdown(wait=False) cancels pending tasks + e = client.get_executor() + fut = e.submit(time.sleep, 2.0) + t1 = time.time() + e.shutdown(wait=False) + dt = time.time() - t1 + assert dt < 0.5 + time.sleep(0.1) # wait for future outcome to propagate + assert 
fut.cancelled() + + with pytest.raises(RuntimeError): + e.submit(time.sleep, 1.0) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 8f790edf20e..3cc055b5246 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -26,7 +26,7 @@ def test_submit_after_failed_worker_sync(loop): - with cluster(active_rpc_timeout=10) as (s, [a, b]): + with cluster(active_rpc_timeout=10, disconnect_timeout=10) as (s, [a, b]): with Client(s["address"], loop=loop) as c: L = c.map(inc, range(10)) wait(L) @@ -64,7 +64,7 @@ def test_submit_after_failed_worker(c, s, a, b): def test_gather_after_failed_worker(loop): - with cluster(active_rpc_timeout=10) as (s, [a, b]): + with cluster(active_rpc_timeout=10, disconnect_timeout=10) as (s, [a, b]): with Client(s["address"], loop=loop) as c: L = c.map(inc, range(10)) wait(L) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 88910c87069..b5631f0a47f 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -184,7 +184,7 @@ def test_nanny_alt_worker_class(c, s, w1, w2): def test_nanny_death_timeout(s): yield s.close() w = Nanny(s.address, death_timeout=1) - with pytest.raises(gen.TimeoutError): + with pytest.raises(asyncio.TimeoutError): yield w assert w.status == "closed" diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index 32b2974a738..dde10b11cf1 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -54,8 +54,8 @@ def test_publish_non_string_key(s, a, b): assert name in datasets finally: - c.close() - f.close() + yield c.close() + yield f.close() @gen_cluster(client=False) diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 817bfcbcea5..80ce977e9f1 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -1,5 +1,5 @@ -from datetime import timedelta from time import sleep +import asyncio import pytest from tornado import gen @@ -181,8 +181,8 @@ def test_get_many(c, s, a, b): data = yield xx.get(batch=2) assert data == [1, 2] - with pytest.raises(gen.TimeoutError): - data = yield gen.with_timeout(timedelta(seconds=0.100), xx.get(batch=2)) + with pytest.raises(asyncio.TimeoutError): + data = yield asyncio.wait_for(xx.get(batch=2), 0.1) @gen_cluster(client=True) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index c493c41dd21..df13f7a1fc1 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2,7 +2,6 @@ import cloudpickle import pickle from collections import defaultdict -from datetime import timedelta import json import operator import re @@ -146,8 +145,8 @@ def test_no_valid_workers(client, s, a, b, c): assert s.tasks[x.key] in s.unrunnable - with pytest.raises(gen.TimeoutError): - yield gen.with_timeout(timedelta(milliseconds=50), x) + with pytest.raises(asyncio.TimeoutError): + yield asyncio.wait_for(x, 0.05) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) @@ -166,8 +165,8 @@ def test_no_workers(client, s): assert s.tasks[x.key] in s.unrunnable - with pytest.raises(gen.TimeoutError): - yield gen.with_timeout(timedelta(milliseconds=50), x) + with pytest.raises(asyncio.TimeoutError): + yield asyncio.wait_for(x, 0.05) @gen_cluster(nthreads=[]) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index db91ec0c004..5275bc47fd8 100644 --- 
a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -1,9 +1,9 @@ from concurrent.futures import CancelledError -from datetime import timedelta from operator import add import random import sys from time import sleep +import asyncio from dask import delayed import pytest @@ -111,9 +111,8 @@ def create_and_destroy_worker(delay): yield n.close() print("Killed nanny") - yield gen.with_timeout( - timedelta(minutes=1), - All([create_and_destroy_worker(0.1 * i) for i in range(20)]), + yield asyncio.wait_for( + All([create_and_destroy_worker(0.1 * i) for i in range(20)]), 60 ) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index df886b9431a..a57cbaf536c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1,5 +1,4 @@ from concurrent.futures import ThreadPoolExecutor -from datetime import timedelta import importlib import logging from numbers import Number @@ -10,6 +9,7 @@ import sys from time import sleep import traceback +import asyncio import dask from dask import delayed @@ -19,7 +19,6 @@ from toolz import pluck, sliding_window, first import tornado from tornado import gen -from tornado.ioloop import TimeoutError from distributed import ( Client, @@ -326,8 +325,8 @@ def test_worker_waits_for_scheduler(loop): def f(): w = Worker("127.0.0.1", 8007) try: - yield gen.with_timeout(timedelta(seconds=3), w) - except TimeoutError: + yield asyncio.wait_for(w, 3) + except asyncio.TimeoutError: pass else: assert False @@ -762,7 +761,7 @@ def test_worker_death_timeout(s): yield s.close() w = Worker(s.address, death_timeout=1) - with pytest.raises(gen.TimeoutError) as info: + with pytest.raises(asyncio.TimeoutError) as info: yield w assert "Worker" in str(info.value) diff --git a/distributed/utils.py b/distributed/utils.py index 15824262ab7..39bc973cdb7 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -3,7 +3,6 @@ from collections import deque, OrderedDict, UserDict from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from datetime import timedelta import functools from hashlib import md5 import html @@ -325,7 +324,7 @@ def f(): thread_state.asynchronous = True future = func(*args, **kwargs) if callback_timeout is not None: - future = gen.with_timeout(timedelta(seconds=callback_timeout), future) + future = asyncio.wait_for(future, callback_timeout) result[0] = yield future except Exception as exc: error[0] = sys.exc_info() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index aef7bde8eee..2fba3c74bfc 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2,7 +2,6 @@ import collections from contextlib import contextmanager import copy -from datetime import timedelta import functools from glob import glob import io @@ -606,7 +605,12 @@ def security(): @contextmanager def cluster( - nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, scheduler_kwargs={} + nworkers=2, + nanny=False, + worker_kwargs={}, + active_rpc_timeout=1, + disconnect_timeout=3, + scheduler_kwargs={}, ): ws = weakref.WeakSet() enable_proctitle_on_children() @@ -689,10 +693,16 @@ def cluster( loop.run_sync( lambda: disconnect_all( - [w["address"] for w in workers], timeout=0.5, rpc_kwargs=rpc_kwargs + [w["address"] for w in workers], + timeout=disconnect_timeout, + rpc_kwargs=rpc_kwargs, + ) + ) + loop.run_sync( + lambda: disconnect( + saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs ) ) - loop.run_sync(lambda: disconnect(saddr, 
timeout=0.5, rpc_kwargs=rpc_kwargs)) scheduler.terminate() scheduler_q.close() @@ -740,8 +750,7 @@ async def do_disconnect(): with rpc(addr, **rpc_kwargs) as w: await w.terminate(close=True) - with ignoring(TimeoutError): - await gen.with_timeout(timedelta(seconds=timeout), do_disconnect()) + await asyncio.wait_for(do_disconnect(), timeout=timeout) async def disconnect_all(addresses, timeout=3, rpc_kwargs=None): @@ -914,9 +923,7 @@ async def coro(): try: future = func(*args) if timeout: - future = gen.with_timeout( - timedelta(seconds=timeout), future - ) + future = asyncio.wait_for(future, timeout) result = await future if s.validate: s.validate_state() @@ -924,9 +931,7 @@ async def coro(): if client and c.status not in ("closing", "closed"): await c._close(fast=s.status == "closed") await end_cluster(s, workers) - await gen.with_timeout( - timedelta(seconds=1), cleanup_global_workers() - ) + await asyncio.wait_for(cleanup_global_workers(), 1) try: c = await default_client() diff --git a/distributed/worker.py b/distributed/worker.py index f4a662ce44f..9bb30dfd39d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1095,13 +1095,13 @@ async def close( for pc in self.periodic_callbacks.values(): pc.stop() - with ignoring(EnvironmentError, gen.TimeoutError): + with ignoring(EnvironmentError, asyncio.TimeoutError): if report and self.contact_address is not None: - await gen.with_timeout( - timedelta(seconds=timeout), + await asyncio.wait_for( self.scheduler.unregister( address=self.contact_address, safe=safe ), + timeout, ) await self.scheduler.close_rpc() self._workdir.release() From 726f65438815317bd6c430b983463cfdbe34712b Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Sun, 19 Jan 2020 18:27:19 -0600 Subject: [PATCH 0635/1550] Use latest release of black (#3388) --- .pre-commit-config.yaml | 2 +- .travis.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c72a38ce93..2b64eddd06a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/ambv/black - rev: cad4138050b86d1c8570b926883e32f7465c2880 + rev: stable hooks: - id: black language_version: python3.7 diff --git a/.travis.yml b/.travis.yml index 5d3cbf0ec0b..56c2588ff5f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,7 +23,7 @@ install: script: - if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi - if [[ $LINT == true ]]; then pip install flake8 ; flake8 distributed ; fi - - if [[ $LINT == true ]]; then pip install git+https://github.com/psf/black@cad4138050b86d1c8570b926883e32f7465c2880; black distributed --check; fi + - if [[ $LINT == true ]]; then pip install black ; black distributed --check; fi after_success: - if [[ $COVERAGE == true ]]; then coverage report; pip install -q coveralls ; coveralls ; fi From 84f220a838cd4578671652bd1fae3a46574c5e92 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 20 Jan 2020 13:22:20 -0800 Subject: [PATCH 0636/1550] Add lifecycle hooks to SchedulerPlugin (#3391) * Add lifecycle hooks to SchedulerPlugin This adds start and close async functions to Scheduler Plugins --- distributed/diagnostics/plugin.py | 15 +++++++++++ .../tests/test_scheduler_plugin.py | 26 +++++++++++++++++-- distributed/scheduler.py | 7 ++++- distributed/tests/test_utils.py | 3 ++- 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 
4d94f7c8859..1d218fe5ac8 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -37,6 +37,21 @@ class SchedulerPlugin(object): >>> scheduler.add_plugin(plugin) # doctest: +SKIP """ + async def start(self, scheduler): + """ Run when the scheduler starts up + + This runs at the end of the Scheduler startup process + """ + pass + + async def close(self): + """ Run when the scheduler closes down + + This runs at the beginning of the Scheduler shutdown process, but after + workers have been asked to shut down gracefully + """ + pass + def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs): """ Run when a new graph / tasks enter the scheduler """ diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 2903214ba32..6fc9e22f3df 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -1,5 +1,6 @@ -from distributed import Worker, SchedulerPlugin -from distributed.utils_test import inc, gen_cluster +import pytest +from distributed import Scheduler, Worker, SchedulerPlugin +from distributed.utils_test import inc, gen_cluster, cleanup # noqa: F401 @gen_cluster(client=True) @@ -67,3 +68,24 @@ def remove_worker(self, worker, scheduler): a = yield Worker(s.address) yield a.close() assert events == [] + + +@pytest.mark.asyncio +async def test_lifecycle(cleanup): + class LifeCycle(SchedulerPlugin): + def __init__(self): + self.history = [] + + async def start(self, scheduler): + self.scheduler = scheduler + self.history.append("started") + + async def close(self): + self.history.append("closed") + + plugin = LifeCycle() + async with Scheduler(plugins=[plugin]) as s: + pass + + assert plugin.history == ["started", "closed"] + assert plugin.scheduler is s diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 152df2d705f..31ce2596e7e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1041,6 +1041,7 @@ def __init__( dashboard_address=None, preload=None, preload_argv=(), + plugins=(), **kwargs ): self._setup_logging(logger) @@ -1210,7 +1211,7 @@ def __init__( ] self.extensions = {} - self.plugins = [] + self.plugins = list(plugins) self.transition_log = deque( maxlen=dask.config.get("distributed.scheduler.transition-log-length") ) @@ -1437,6 +1438,8 @@ def del_scheduler_file(): preload_modules(self.preload, parameter=self, argv=self.preload_argv) + await asyncio.gather(*[plugin.start(self) for plugin in self.plugins]) + self.start_periodic_callbacks() setproctitle("dask-scheduler [%s]" % (self.address,)) @@ -1467,6 +1470,8 @@ async def close(self, comm=None, fast=False, close_workers=False): else: break + await asyncio.gather(*[plugin.close() for plugin in self.plugins]) + for pc in self.periodic_callbacks.values(): pc.stop() self.periodic_callbacks.clear() diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index ff2e42313ac..cf15985eb7a 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -1,3 +1,4 @@ +import asyncio import array import datetime from functools import partial @@ -109,7 +110,7 @@ def function2(x): def test_sync_timeout(loop_in_thread): loop = loop_in_thread - with pytest.raises(gen.TimeoutError): + with pytest.raises((asyncio.TimeoutError, gen.TimeoutError)): sync(loop_in_thread, gen.sleep, 0.5, callback_timeout=0.05) From 3a2f8a84121d27decf432cd4f379724907ea2b3d Mon Sep 17 00:00:00 2001 
From: Matthew Rocklin Date: Tue, 21 Jan 2020 14:10:08 -0800 Subject: [PATCH 0637/1550] Replace gen.TimeoutError with utils.TimeoutError (#3394) * Replace gen.TimeoutError with utils.TimeoutError Previously we had a mix of gen.TimeoutError and asyncio.TimeoutError within the code. This import asyncio.TimeoutError within utils.py (and our top level imports) and then uses that consistently throughout the codebase * Translate gen.TimeoutErrors into asyncio.TimeoutErrors Sometimes when we use Tornado objects like events and queues we experience gen.TimeoutErrors. In these cases we raise instead asyncio.TimeoutErrors --- distributed/__init__.py | 4 +--- distributed/cfexecutor.py | 4 ++-- distributed/client.py | 6 ++--- distributed/comm/core.py | 4 ++-- distributed/deploy/spec.py | 5 +++-- distributed/deploy/tests/test_local.py | 4 ++-- distributed/lock.py | 4 ++-- distributed/nanny.py | 18 ++++++--------- distributed/node.py | 4 ++-- distributed/process.py | 4 ++-- distributed/pubsub.py | 9 +++++--- distributed/queues.py | 22 +++++++++++++----- distributed/scheduler.py | 31 +++++++++++++------------- distributed/tests/test_batched.py | 4 ++-- distributed/tests/test_client.py | 16 ++++++------- distributed/tests/test_nanny.py | 4 ++-- distributed/tests/test_queues.py | 12 +++++----- distributed/tests/test_scheduler.py | 10 ++++----- distributed/tests/test_utils.py | 4 ++-- distributed/tests/test_variable.py | 6 ++--- distributed/tests/test_worker.py | 6 ++--- distributed/utils.py | 3 ++- distributed/utils_test.py | 4 ++-- distributed/variable.py | 11 +++++---- distributed/worker.py | 7 +++--- 25 files changed, 110 insertions(+), 96 deletions(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index 06136dd72a2..9238d57ccc9 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -24,13 +24,11 @@ from .queues import Queue from .scheduler import Scheduler from .threadpoolexecutor import rejoin -from .utils import sync +from .utils import sync, TimeoutError from .variable import Variable from .worker import Worker, get_worker, get_client, secede, Reschedule from .worker_client import local_client, worker_client -from tornado.gen import TimeoutError - from ._version import get_versions versions = get_versions() diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index 373a3c4eb28..985a407bdb9 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -6,7 +6,7 @@ from tornado import gen from .metrics import time -from .utils import sync +from .utils import sync, TimeoutError @gen.coroutine @@ -135,7 +135,7 @@ def result_iterator(): if timeout is not None: try: yield future.result(end_time - time()) - except gen.TimeoutError: + except TimeoutError: raise cf.TimeoutError else: yield future.result() diff --git a/distributed/client.py b/distributed/client.py index 0a4080cd05d..7a39ec4b235 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -37,7 +37,6 @@ except ImportError: single_key = first from tornado import gen -from tornado.gen import TimeoutError from tornado.locks import Event, Condition, Semaphore from tornado.ioloop import IOLoop from tornado.queues import Queue @@ -85,6 +84,7 @@ Any, has_keyword, format_dashboard_link, + TimeoutError, ) from . 
import versions as version_module @@ -1265,7 +1265,7 @@ async def _close(self, fast=False): # Give the scheduler 'stream-closed' message 100ms to come through # This makes the shutdown slightly smoother and quieter - with ignoring(AttributeError, CancelledError, asyncio.TimeoutError): + with ignoring(AttributeError, CancelledError, TimeoutError): await asyncio.wait_for( asyncio.shield(self._handle_scheduler_coroutine), 0.1 ) @@ -1957,7 +1957,7 @@ async def _scatter( if nthreads is not None: await asyncio.sleep(0.1) if time() > start + timeout: - raise gen.TimeoutError("No valid workers found") + raise TimeoutError("No valid workers found") nthreads = await self.scheduler.ncores(workers=workers) if not nthreads: raise ValueError("No valid workers") diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 42c95e3579e..e801242bb40 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -6,7 +6,7 @@ import dask from ..metrics import time -from ..utils import parse_timedelta, ignoring +from ..utils import parse_timedelta, ignoring, TimeoutError from . import registry from .addressing import parse_address @@ -209,7 +209,7 @@ def _raise(error): future = connector.connect( loc, deserialize=deserialize, **(connection_args or {}) ) - with ignoring(asyncio.TimeoutError): + with ignoring(TimeoutError): comm = await asyncio.wait_for( future, timeout=min(deadline - time(), 1) ) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index fb06057bb64..96279d15323 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -6,8 +6,8 @@ import weakref import dask -from tornado import gen from tornado.locks import Event +from tornado import gen from .adaptive import Adaptive from .cluster import Cluster @@ -19,6 +19,7 @@ parse_bytes, parse_timedelta, import_term, + TimeoutError, ) from ..scheduler import Scheduler from ..security import Security @@ -602,6 +603,6 @@ async def run_spec(spec: dict, *args): @atexit.register def close_clusters(): for cluster in list(SpecCluster._instances): - with ignoring(gen.TimeoutError): + with ignoring((gen.TimeoutError, TimeoutError)): if cluster.status != "closed": cluster.close(timeout=10) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 4687d0c476f..98d04c78d17 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -31,7 +31,7 @@ tls_only_security, ) from distributed.utils_test import loop # noqa: F401 -from distributed.utils import sync +from distributed.utils import sync, TimeoutError from distributed.deploy.utils_test import ClusterTest @@ -523,7 +523,7 @@ def test_memory_nanny(loop, n_workers): def test_death_timeout_raises(loop): - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): with LocalCluster( scheduler_port=0, silence_logs=False, diff --git a/distributed/lock.py b/distributed/lock.py index c581bb5d552..c230a8e861c 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -6,7 +6,7 @@ import tornado.locks from .client import _get_global_client -from .utils import log_errors +from .utils import log_errors, TimeoutError from .worker import get_worker logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ async def acquire(self, stream=None, name=None, id=None, timeout=None): future = asyncio.wait_for(future, timeout) try: await future - except asyncio.TimeoutError: + except TimeoutError: result = False break else: diff --git a/distributed/nanny.py b/distributed/nanny.py index 
19d48328f47..945f33041d3 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -11,9 +11,9 @@ import dask from dask.system import CPU_COUNT -from tornado import gen -from tornado.ioloop import IOLoop, TimeoutError +from tornado.ioloop import IOLoop from tornado.locks import Event +from tornado import gen from .comm import get_address_host, unparse_host_port from .comm.addressing import address_from_user_args @@ -31,6 +31,7 @@ PeriodicCallback, parse_timedelta, ignoring, + TimeoutError, ) from .worker import run, parse_memory_limit, Worker @@ -213,12 +214,7 @@ async def _unregister(self, timeout=10): if worker_address is None: return - allowed_errors = ( - gen.TimeoutError, - CommClosedError, - EnvironmentError, - RPCClosed, - ) + allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed) with ignoring(allowed_errors): await asyncio.wait_for( self.scheduler.unregister(address=self.worker_address), timeout @@ -317,7 +313,7 @@ async def instantiate(self, comm=None): result = await asyncio.wait_for( self.process.start(), self.death_timeout ) - except gen.TimeoutError: + except TimeoutError: await self.close(timeout=self.death_timeout) logger.error( "Timed out connecting Nanny '%s' to scheduler '%s'", @@ -340,7 +336,7 @@ async def _(): try: await asyncio.wait_for(_(), timeout) - except asyncio.TimeoutError: + except TimeoutError: logger.error("Restart timed out, returning before finished") return "timed out" else: @@ -729,7 +725,7 @@ async def run(): try: loop.run_sync(run) - except TimeoutError: + except (TimeoutError, gen.TimeoutError): # Loop was stopped before wait_until_closed() returned, ignore pass except KeyboardInterrupt: diff --git a/distributed/node.py b/distributed/node.py index 6cf30f997fe..edee3e2dd7b 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -9,7 +9,7 @@ from .core import Server, ConnectionPool from .versions import get_versions -from .utils import DequeHandler +from .utils import DequeHandler, TimeoutError class Node(object): @@ -173,7 +173,7 @@ async def wait_for(future, timeout=None): await asyncio.wait_for(future, timeout=timeout) except Exception: await self.close(timeout=1) - raise asyncio.TimeoutError( + raise TimeoutError( "{} failed to start in {} seconds".format( type(self).__name__, timeout ) diff --git a/distributed/process.py b/distributed/process.py index 4ad86e2bb08..5899c853385 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -8,7 +8,7 @@ import asyncio import dask -from .utils import mp_context +from .utils import mp_context, TimeoutError from tornado import gen from tornado.concurrent import Future @@ -283,7 +283,7 @@ def join(self, timeout=None): else: try: yield asyncio.wait_for(self._exit_future, timeout) - except gen.TimeoutError: + except TimeoutError: pass def close(self): diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 0a5d82897fd..9de133ddb47 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -8,7 +8,7 @@ from tornado import gen from .core import CommClosedError -from .utils import sync +from .utils import sync, TimeoutError from .protocol.serialize import to_serialize logger = logging.getLogger(__name__) @@ -400,10 +400,13 @@ async def _get(self, timeout=None): if timeout is not None: timeout2 = timeout - (datetime.datetime.now() - start) if timeout2.total_seconds() < 0: - raise gen.TimeoutError() + raise TimeoutError() else: timeout2 = None - await self.condition.wait(timeout=timeout2) + try: + await self.condition.wait(timeout=timeout2) + except 
gen.TimeoutError: + raise TimeoutError("Timed out waiting on Sub") return self.buffer.popleft() diff --git a/distributed/queues.py b/distributed/queues.py index 1d0c2c0bdd3..6d1fc76571b 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -5,9 +5,10 @@ import tornado.queues from tornado.locks import Event +from tornado import gen from .client import Future, _get_global_client, Client -from .utils import tokey, sync, thread_state +from .utils import tokey, sync, thread_state, TimeoutError from .worker import get_client logger = logging.getLogger(__name__) @@ -78,7 +79,10 @@ async def put( record = {"type": "msgpack", "value": data} if timeout is not None: timeout = datetime.timedelta(seconds=timeout) - await self.queues[name].put(record, timeout=timeout) + try: + await self.queues[name].put(record, timeout=timeout) + except gen.TimeoutError: + raise TimeoutError("Timed out waiting for Queue") def future_release(self, name=None, key=None, client=None): self.future_refcount[name, key] -= 1 @@ -124,7 +128,10 @@ def process(record): else: if timeout is not None: timeout = datetime.timedelta(seconds=timeout) - record = await self.queues[name].get(timeout=timeout) + try: + record = await self.queues[name].get(timeout=timeout) + except gen.TimeoutError: + raise TimeoutError("Timed out waiting for Queue") record = process(record) return record @@ -225,9 +232,12 @@ def qsize(self, **kwargs): return self.client.sync(self._qsize, **kwargs) async def _get(self, timeout=None, batch=False): - resp = await self.client.scheduler.queue_get( - timeout=timeout, name=self.name, batch=batch - ) + try: + resp = await self.client.scheduler.queue_get( + timeout=timeout, name=self.name, batch=batch + ) + except gen.TimeoutError: + raise TimeoutError("Timed out waiting for Queue") def process(d): if d["type"] == "Future": diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 31ce2596e7e..8080a1186da 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -24,7 +24,6 @@ except ImportError: from toolz import frequencies, merge, pluck, merge_sorted, first, merge_with from toolz import valmap, second, compose, groupby -from tornado import gen from tornado.ioloop import IOLoop import dask @@ -60,6 +59,7 @@ key_split_group, empty_context, tmpfile, + TimeoutError, ) from .utils_comm import scatter_to_workers, gather_from_workers, retry_operation from .utils_perf import enable_gc_diagnosis, disable_gc_diagnosis @@ -2744,7 +2744,7 @@ async def scatter( while not self.workers: await asyncio.sleep(0.2) if time() > start + timeout: - raise gen.TimeoutError("No workers found") + raise TimeoutError("No workers found") if workers is None: nthreads = {w: ws.nthreads for w, ws in self.workers.items()} @@ -2874,25 +2874,26 @@ async def restart(self, client=None, timeout=3): if nanny_address is not None ] + resps = All( + [ + nanny.restart( + close=True, timeout=timeout * 0.8, executor_wait=False + ) + for nanny in nannies + ] + ) try: - resps = All( - [ - nanny.restart( - close=True, timeout=timeout * 0.8, executor_wait=False - ) - for nanny in nannies - ] - ) resps = await asyncio.wait_for(resps, timeout) - if not all(resp == "OK" for resp in resps): - logger.error( - "Not all workers responded positively: %s", resps, exc_info=True - ) - except gen.TimeoutError: + except TimeoutError: logger.error( "Nannies didn't report back restarted within " "timeout. 
Continuuing with restart process" ) + else: + if not all(resp == "OK" for resp in resps): + logger.error( + "Not all workers responded positively: %s", resps, exc_info=True + ) finally: await asyncio.gather(*[nanny.close_rpc() for nanny in nannies]) diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index 3174f3f3022..07dd32f4c68 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -7,7 +7,7 @@ from distributed.batched import BatchedSend from distributed.core import listen, connect, CommClosedError from distributed.metrics import time -from distributed.utils import All +from distributed.utils import All, TimeoutError from distributed.utils_test import captured_logger from distributed.protocol import to_serialize @@ -252,5 +252,5 @@ async def test_serializers(): msg = await comm.read() assert list(msg) == [{"x": 123}, {"x": "hello"}] - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): msg = await asyncio.wait_for(comm.read(), 0.1) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index da23b85df0e..77180eefa35 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -496,7 +496,7 @@ def test_thread(c): assert x.result() == 2 x = c.submit(slowinc, 1, delay=0.3) - with pytest.raises((gen.TimeoutError, asyncio.TimeoutError)): + with pytest.raises(TimeoutError): x.result(timeout=0.01) assert x.result() == 2 @@ -681,7 +681,7 @@ def test_wait_first_completed(c, s, a, b): @gen_cluster(client=True, timeout=2) def test_wait_timeout(c, s, a, b): future = c.submit(sleep, 0.3) - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): yield wait(future, timeout=0.01) @@ -695,7 +695,7 @@ def test_wait_sync(c): assert x.status == y.status == "finished" future = c.submit(sleep, 0.3) - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): wait(future, timeout=0.01) @@ -1394,7 +1394,7 @@ def test_scatter_direct_broadcast_target(c, s, *workers): @gen_cluster(client=True, nthreads=[]) def test_scatter_direct_empty(c, s): - with pytest.raises((ValueError, gen.TimeoutError)): + with pytest.raises((ValueError, TimeoutError)): yield c.scatter(123, direct=True, timeout=0.1) @@ -1801,12 +1801,12 @@ def test_allow_restrictions(c, s, a, b): def test_bad_address(): try: Client("123.123.123.123:1234", timeout=0.1) - except (IOError, gen.TimeoutError) as e: + except (IOError, TimeoutError) as e: assert "connect" in str(e).lower() try: Client("127.0.0.1:1234", timeout=0.1) - except (IOError, gen.TimeoutError) as e: + except (IOError, TimeoutError) as e: assert "connect" in str(e).lower() @@ -3501,7 +3501,7 @@ def test_persist_optimize_graph(c, s, a, b): @gen_cluster(client=True, nthreads=[]) def test_scatter_raises_if_no_workers(c, s): - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): yield c.scatter(1, timeout=0.5) @@ -5279,7 +5279,7 @@ def test_client_active_bad_port(): http_server.listen(8080) with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): c = Client("127.0.0.1:8080", asynchronous=True) - with pytest.raises((asyncio.TimeoutError, IOError)): + with pytest.raises((TimeoutError, IOError)): yield c yield c._close(fast=True) http_server.stop() diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index b5631f0a47f..2ddc3b7e5db 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -20,7 +20,7 @@ from distributed.core 
import CommClosedError from distributed.metrics import time from distributed.protocol.pickle import dumps -from distributed.utils import ignoring, tmpfile +from distributed.utils import ignoring, tmpfile, TimeoutError from distributed.utils_test import ( # noqa: F401 gen_cluster, gen_test, @@ -184,7 +184,7 @@ def test_nanny_alt_worker_class(c, s, w1, w2): def test_nanny_death_timeout(s): yield s.close() w = Nanny(s.address, death_timeout=1) - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): yield w assert w.status == "closed" diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 80ce977e9f1..d797433d6b4 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -4,7 +4,7 @@ import pytest from tornado import gen -from distributed import Client, Queue, Nanny, worker_client, wait +from distributed import Client, Queue, Nanny, worker_client, wait, TimeoutError from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -24,7 +24,7 @@ def test_queue(c, s, a, b): future2 = yield xx.get() assert future.key == future2.key - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): yield x.get(timeout=0.1) del future, future2 @@ -50,7 +50,7 @@ def test_queue_with_data(c, s, a, b): assert data == (1, "hello") - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): yield x.get(timeout=0.1) @@ -181,7 +181,7 @@ def test_get_many(c, s, a, b): data = yield xx.get(batch=2) assert data == [1, 2] - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): data = yield asyncio.wait_for(xx.get(batch=2), 0.1) @@ -248,7 +248,7 @@ def test_timeout(c, s, a, b): q = yield Queue("v", maxsize=1) start = time() - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): yield q.get(timeout=0.3) stop = time() assert 0.2 < stop - start < 2.0 @@ -256,7 +256,7 @@ def test_timeout(c, s, a, b): yield q.put(1) start = time() - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): yield q.put(2, timeout=0.3) stop = time() assert 0.1 < stop - start < 2.0 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index df13f7a1fc1..fd0775ad003 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -25,7 +25,7 @@ from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.worker import dumps_function, dumps_task -from distributed.utils import tmpfile, typename +from distributed.utils import tmpfile, typename, TimeoutError from distributed.utils_test import ( # noqa: F401 captured_logger, cleanup, @@ -145,7 +145,7 @@ def test_no_valid_workers(client, s, a, b, c): assert s.tasks[x.key] in s.unrunnable - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): yield asyncio.wait_for(x, 0.05) @@ -165,7 +165,7 @@ def test_no_workers(client, s): assert s.tasks[x.key] in s.unrunnable - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(TimeoutError): yield asyncio.wait_for(x, 0.05) @@ -656,11 +656,11 @@ def test_story(c, s, a, b): @gen_cluster(nthreads=[], client=True) def test_scatter_no_workers(c, s): - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): yield s.scatter(data={"x": 1}, client="alice", timeout=0.1) start = time() - with pytest.raises(gen.TimeoutError): + with 
pytest.raises(TimeoutError): yield c.scatter(123, timeout=0.1) assert time() < start + 1.5 diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index cf15985eb7a..93b843358a8 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -1,4 +1,3 @@ -import asyncio import array import datetime from functools import partial @@ -44,6 +43,7 @@ warn_on_duration, format_dashboard_link, LRU, + TimeoutError, ) from distributed.utils_test import loop, loop_in_thread # noqa: F401 from distributed.utils_test import div, has_ipv6, inc, throws, gen_test, captured_logger @@ -110,7 +110,7 @@ def function2(x): def test_sync_timeout(loop_in_thread): loop = loop_in_thread - with pytest.raises((asyncio.TimeoutError, gen.TimeoutError)): + with pytest.raises(TimeoutError): sync(loop_in_thread, gen.sleep, 0.5, callback_timeout=0.05) diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 6dcca9c9cf4..962b7a40e42 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -6,7 +6,7 @@ import pytest from tornado import gen -from distributed import Client, Variable, worker_client, Nanny, wait +from distributed import Client, Variable, worker_client, Nanny, wait, TimeoutError from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -84,7 +84,7 @@ def test_timeout(c, s, a, b): v = Variable("v") start = time() - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): yield v.get(timeout=0.1) stop = time() assert 0.1 < stop - start < 2.0 @@ -93,7 +93,7 @@ def test_timeout(c, s, a, b): def test_timeout_sync(client): v = Variable("v") start = time() - with pytest.raises(gen.TimeoutError): + with pytest.raises(TimeoutError): v.get(timeout=0.1) stop = time() assert 0.1 < stop - start < 2.0 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index a57cbaf536c..11ab461ae25 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -35,7 +35,7 @@ from distributed.scheduler import Scheduler from distributed.metrics import time from distributed.worker import Worker, error_message, logger, parse_memory_limit -from distributed.utils import tmpfile +from distributed.utils import tmpfile, TimeoutError from distributed.utils_test import ( # noqa: F401 cleanup, inc, @@ -326,7 +326,7 @@ def f(): w = Worker("127.0.0.1", 8007) try: yield asyncio.wait_for(w, 3) - except asyncio.TimeoutError: + except TimeoutError: pass else: assert False @@ -761,7 +761,7 @@ def test_worker_death_timeout(s): yield s.close() w = Worker(s.address, death_timeout=1) - with pytest.raises(asyncio.TimeoutError) as info: + with pytest.raises(TimeoutError) as info: yield w assert "Worker" in str(info.value) diff --git a/distributed/utils.py b/distributed/utils.py index 39bc973cdb7..ee9d7948a11 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,4 +1,5 @@ import asyncio +from asyncio import TimeoutError import atexit from collections import deque, OrderedDict, UserDict from concurrent.futures import ThreadPoolExecutor @@ -335,7 +336,7 @@ def f(): loop.add_callback(f) if callback_timeout is not None: if not e.wait(callback_timeout): - raise gen.TimeoutError("timed out after %s s." % (callback_timeout,)) + raise TimeoutError("timed out after %s s." 
% (callback_timeout,)) else: while not e.is_set(): e.wait(10) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 2fba3c74bfc..1650e0426ef 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -34,7 +34,6 @@ import dask from toolz import merge, memoize, assoc from tornado import gen, queues -from tornado.gen import TimeoutError from tornado.ioloop import IOLoop from . import system @@ -60,6 +59,7 @@ iscoroutinefunction, thread_state, _offload_executor, + TimeoutError, ) from .worker import Worker from .nanny import Nanny @@ -141,7 +141,7 @@ def start(): except RuntimeError as e: if not re.match("IOLoop is clos(ed|ing)", str(e)): raise - except gen.TimeoutError: + except TimeoutError: pass else: is_stopped.wait() diff --git a/distributed/variable.py b/distributed/variable.py index 2169c287f61..677e2997b32 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -3,8 +3,8 @@ import logging import uuid -from tornado import gen import tornado.locks +from tornado import gen try: from cytoolz import merge @@ -13,7 +13,7 @@ from .client import Future, _get_global_client, Client from .metrics import time -from .utils import tokey, log_errors +from .utils import tokey, log_errors, TimeoutError from .worker import get_client logger = logging.getLogger(__name__) @@ -82,8 +82,11 @@ async def get(self, stream=None, name=None, client=None, timeout=None): else: left = None if left and left < 0: - raise gen.TimeoutError() - await self.started.wait(timeout=left) + raise TimeoutError() + try: + await self.started.wait(timeout=left) + except gen.TimeoutError: + raise TimeoutError("Timed out waiting for Variable.get") record = self.variables[name] if record["type"] == "Future": key = record["value"] diff --git a/distributed/worker.py b/distributed/worker.py index 9bb30dfd39d..7062aba5b87 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -62,6 +62,7 @@ iscoroutinefunction, warn_on_duration, LRU, + TimeoutError, ) from .utils_comm import pack_data, gather_from_workers, retry_operation from .utils_perf import ThrottledGC, enable_gc_diagnosis, disable_gc_diagnosis @@ -845,7 +846,7 @@ async def _register_with_scheduler(self): except EnvironmentError: logger.info("Waiting to connect to: %26s", self.scheduler.address) await asyncio.sleep(0.1) - except gen.TimeoutError: + except TimeoutError: logger.info("Timed out when connecting to scheduler") if response["status"] != "OK": raise ValueError("Unexpected response from register: %r" % (response,)) @@ -1095,7 +1096,7 @@ async def close( for pc in self.periodic_callbacks.values(): pc.stop() - with ignoring(EnvironmentError, asyncio.TimeoutError): + with ignoring(EnvironmentError, TimeoutError): if report and self.contact_address is not None: await asyncio.wait_for( self.scheduler.unregister( @@ -1113,7 +1114,7 @@ async def close( self.batched_stream.send({"op": "close-stream"}) if self.batched_stream: - with ignoring(gen.TimeoutError): + with ignoring(TimeoutError): await self.batched_stream.close(timedelta(seconds=timeout)) self.actor_executor._work_queue.queue.clear() From a891a8316a2f28eeea77145a125e56570e51bc39 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 21 Jan 2020 18:40:22 -0800 Subject: [PATCH 0638/1550] Support args and kwargs in offload (#3392) --- distributed/tests/test_utils.py | 7 +++++++ distributed/utils.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 93b843358a8..e162b9fc2e1 
100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -43,6 +43,7 @@ warn_on_duration, format_dashboard_link, LRU, + offload, TimeoutError, ) from distributed.utils_test import loop, loop_in_thread # noqa: F401 @@ -618,3 +619,9 @@ def test_lru(): l["d"] = 4 assert len(l) == 3 assert list(l.keys()) == ["c", "a", "d"] + + +@pytest.mark.asyncio +async def test_offload(): + assert (await offload(inc, 1)) == 2 + assert (await offload(lambda x, y: x + y, 1, y=2)) == 3 diff --git a/distributed/utils.py b/distributed/utils.py index ee9d7948a11..c2e32d849c4 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1394,7 +1394,7 @@ def import_term(name: str): async def offload(fn, *args, **kwargs): loop = asyncio.get_event_loop() - return await loop.run_in_executor(_offload_executor, fn, *args, **kwargs) + return await loop.run_in_executor(_offload_executor, lambda: fn(*args, **kwargs)) def serialize_for_cli(data): From 84ee20548892c8d6dd3e9d0f44712392d9a66e04 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 23 Jan 2020 10:01:39 -0600 Subject: [PATCH 0639/1550] Add GitHub actions badge to README for windows build (#3403) --- README.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 3d9c02915fc..eb4c601bc8b 100644 --- a/README.rst +++ b/README.rst @@ -1,13 +1,15 @@ Distributed =========== -|Build Status| |Doc Status| |Gitter| |Version Status| |NumFOCUS| +|Linux Build Status| |Windows Build Status| |Doc Status| |Gitter| |Version Status| |NumFOCUS| A library for distributed computation. See documentation_ for more details. .. _documentation: https://distributed.dask.org -.. |Build Status| image:: https://travis-ci.org/dask/distributed.svg?branch=master +.. |Linux Build Status| image:: https://travis-ci.org/dask/distributed.svg?branch=master :target: https://travis-ci.org/dask/distributed +.. |Windows Build Status| image:: https://github.com/dask/distributed/workflows/Windows%20CI/badge.svg?branch=master + :target: https://github.com/dask/distributed/actions?query=workflow%3A%22Windows+CI%22 .. |Doc Status| image:: https://readthedocs.org/projects/distributed/badge/?version=latest :target: https://distributed.dask.org :alt: Documentation Status From 036bcba0c1fa1e682f3b0d8a7b8c4dfccb7cf3e7 Mon Sep 17 00:00:00 2001 From: Darren Weber Date: Fri, 24 Jan 2020 11:29:24 -0800 Subject: [PATCH 0640/1550] Revise develop-docs: functional test example (#3398) --- docs/source/develop.rst | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/docs/source/develop.rst b/docs/source/develop.rst index 9b3b70afb15..f7a10c64471 100644 --- a/docs/source/develop.rst +++ b/docs/source/develop.rst @@ -78,28 +78,33 @@ The test suite contains three kinds of tests These are rare and mostly for testing the command line interface. If you are comfortable with the Tornado interface then you will be happiest -using the ``@gen_cluster`` style of test +using the ``@gen_cluster`` style of test, e.g. .. 
code-block:: python - from distributed.utils_test import gen_cluster + # tests/test_submit.py - @gen_cluster(client=True) - def test_submit(c, s, a, b): - assert isinstance(c, Client) - assert isinstance(s, Scheduler) - assert isinstance(a, Worker) - assert isinstance(b, Worker) + from distributed.utils_test import gen_cluster, inc + from distributed import Client, Future, Scheduler, Worker - future = c.submit(inc, 1) - assert future.key in c.futures + @gen_cluster(client=True) + def test_submit(c, s, a, b): + assert isinstance(c, Client) + assert isinstance(s, Scheduler) + assert isinstance(a, Worker) + assert isinstance(b, Worker) - # result = future.result() # This synchronous API call would block - result = yield future - assert result == 2 + future = c.submit(inc, 1) + assert isinstance(future, Future) + assert future.key in c.futures + + # result = future.result() # This synchronous API call would block + result = yield future + assert result == 2 + + assert future.key in s.tasks + assert future.key in a.data or future.key in b.data - assert future.key in s.tasks - assert future.key in a.data or future.key in b.data The ``@gen_cluster`` decorator sets up a scheduler, client, and workers for you and cleans them up after the test. It also allows you to directly inspect From fe20eaaee60cfe5e4a8ad66dd733927a751985aa Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 24 Jan 2020 16:58:16 -0600 Subject: [PATCH 0641/1550] Use instance-level client instead of class-level (#3408) --- distributed/comm/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 77876a04fbc..ce6b7fa0b44 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -345,7 +345,7 @@ async def connect(self, address, deserialize=True, **connection_args): kwargs = self._get_connect_args(**connection_args) try: - stream = await BaseTCPConnector.client.connect( + stream = await self.client.connect( ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs ) From 4081a0e09febd9099f04a1db9772f6cd95203b2e Mon Sep 17 00:00:00 2001 From: jakirkham Date: Fri, 24 Jan 2020 21:58:21 -0800 Subject: [PATCH 0642/1550] Enable WorkStealing case-by-case (#3410) * Enable work-stealing dynamically * Include test for work-stealing config * Rewrite test to use `@pytest.mark.asyncio` --- distributed/scheduler.py | 7 +++---- distributed/tests/test_scheduler.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8080a1186da..ab2e615cef2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -89,9 +89,6 @@ PubSubSchedulerExtension, ] -if dask.config.get("distributed.scheduler.work-stealing"): - DEFAULT_EXTENSIONS.append(WorkStealing) - ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"} @@ -1333,7 +1330,9 @@ def __init__( self.periodic_callbacks["idle-timeout"] = pc if extensions is None: - extensions = DEFAULT_EXTENSIONS + extensions = list(DEFAULT_EXTENSIONS) + if dask.config.get("distributed.scheduler.work-stealing"): + extensions.append(WorkStealing) for ext in extensions: ext(self) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index fd0775ad003..3ce681ea546 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -596,6 +596,19 @@ def test_coerce_address(): yield [w.close() for w in [a, b, c]] +@pytest.mark.asyncio +async def 
test_config_stealing(cleanup): + # Regression test for https://github.com/dask/distributed/issues/3409 + + with dask.config.set({"distributed.scheduler.work-stealing": True}): + async with Scheduler(port=0) as s: + assert "stealing" in s.extensions + + with dask.config.set({"distributed.scheduler.work-stealing": False}): + async with Scheduler(port=0) as s: + assert "stealing" not in s.extensions + + @pytest.mark.skipif( sys.platform.startswith("win"), reason="file descriptors not really a thing" ) From 4f61b3d341c5098dfb9267a2a8fc4a6e74483dd4 Mon Sep 17 00:00:00 2001 From: Chrysostomos Nanakos Date: Mon, 27 Jan 2020 18:57:05 +0200 Subject: [PATCH 0643/1550] Respect dashboard prefix when redirecting root (#3387) When --dashboard-prefix is used root location is always redirected to /status without adding the prefix. Fixes https://github.com/dask/distributed/issues/3405 Signed-off-by: Chrysostomos Nanakos --- distributed/dashboard/core.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index 41e7c289c17..82cfe92da17 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -5,6 +5,7 @@ import bokeh from bokeh.server.server import Server from tornado import web +from urllib.parse import urljoin if LooseVersion(bokeh.__version__) < LooseVersion("0.13.0"): @@ -34,7 +35,13 @@ def listen(self, addr): check_unused_sessions_milliseconds=500, allow_websocket_origin=["*"], use_index=False, - extra_patterns=[(r"/", web.RedirectHandler, {"url": "/status"})], + extra_patterns=[ + ( + r"/", + web.RedirectHandler, + {"url": urljoin(self.prefix.rstrip("/") + "/", r"status")}, + ) + ], ) server_kwargs.update(self.server_kwargs) self.server = Server(self.apps, **server_kwargs) From 77ffa7b78a1096ce67a0366abffeaed071ed02fa Mon Sep 17 00:00:00 2001 From: jakirkham Date: Mon, 27 Jan 2020 13:32:36 -0800 Subject: [PATCH 0644/1550] Drop custom cuDF serialization (#3404) * Drop custom cuDF serialization This is since handled in cuDF 0.9.0, which was released a while ago. So go ahead and drop this from Distributed. * Restore cuDF import of serializers --- distributed/protocol/__init__.py | 7 +------ distributed/protocol/cudf.py | 23 ----------------------- 2 files changed, 1 insertion(+), 29 deletions(-) delete mode 100644 distributed/protocol/cudf.py diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index ef8b5564bbb..30ae3935498 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -83,9 +83,4 @@ def _register_numba(): @cuda_serialize.register_lazy("cudf") @cuda_deserialize.register_lazy("cudf") def _register_cudf(): - import cudf - - if LooseVersion(cudf.__version__) > "0.9": - from cudf.comm import serialize - else: - from . 
import cudf + from cudf.comm import serialize diff --git a/distributed/protocol/cudf.py b/distributed/protocol/cudf.py deleted file mode 100644 index f236a6c1f0c..00000000000 --- a/distributed/protocol/cudf.py +++ /dev/null @@ -1,23 +0,0 @@ -import pickle -import cudf -import cudf.groupby.groupby -from .cuda import cuda_serialize, cuda_deserialize -from ..utils import log_errors - - -# all (de-)serializtion code lives in the cudf codebase -# here we ammend the returned headers with `is_gpu` for -# UCX buffer consumption -@cuda_serialize.register((cudf.DataFrame, cudf.Series, cudf.groupby.groupby._Groupby)) -def serialize_cudf_dataframe(x): - with log_errors(): - header, frames = x.serialize() - return header, frames - - -@cuda_deserialize.register((cudf.DataFrame, cudf.Series, cudf.groupby.groupby._Groupby)) -def deserialize_cudf_dataframe(header, frames): - with log_errors(): - cudf_typ = pickle.loads(header["type"]) - cudf_obj = cudf_typ.deserialize(header, frames) - return cudf_obj From 241d0d44d8a108bf2412696455dd7486bdfbff17 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 27 Jan 2020 16:25:03 -0600 Subject: [PATCH 0645/1550] Add CI documentation build (#3411) --- .github/workflows/ci-docs.yaml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 .github/workflows/ci-docs.yaml diff --git a/.github/workflows/ci-docs.yaml b/.github/workflows/ci-docs.yaml new file mode 100644 index 00000000000..e80e07b9c33 --- /dev/null +++ b/.github/workflows/ci-docs.yaml @@ -0,0 +1,28 @@ +name: Documentation CI + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v1 + + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: 3.7 + + - name: Install Distributed + run: | + python -m pip install --upgrade pip + pip install -e . 
+ + - name: Install doc dependencies + run: pip install -r docs/requirements.txt + + - name: Build docs + run: | + cd docs + make html From 6cf1afe012435216bb5a1f5e6fa3f9c8258dde00 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 27 Jan 2020 19:06:18 -0600 Subject: [PATCH 0646/1550] Ignore no-worker state in TaskProgress (#3407) --- distributed/dashboard/components/scheduler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 8eec6b8b772..2c9953e97e3 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1418,12 +1418,13 @@ def update(self): } for tp in self.scheduler.task_prefixes.values(): - if any(tp.active_states.values()): - state["memory"][tp.name] = tp.active_states["memory"] - state["erred"][tp.name] = tp.active_states["erred"] - state["released"][tp.name] = tp.active_states["released"] - state["processing"][tp.name] = tp.active_states["processing"] - state["waiting"][tp.name] = tp.active_states["waiting"] + active_states = tp.active_states + if any(active_states.get(s) for s in state.keys()): + state["memory"][tp.name] = active_states["memory"] + state["erred"][tp.name] = active_states["erred"] + state["released"][tp.name] = active_states["released"] + state["processing"][tp.name] = active_states["processing"] + state["waiting"][tp.name] = active_states["waiting"] state["all"] = { k: sum(v[k] for v in state.values()) for k in state["memory"] From 457281ba826df106737ed96c9327124cc85f6d29 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 28 Jan 2020 12:29:29 -0600 Subject: [PATCH 0647/1550] DOC: Update changelog for 2.10.0 (#3421) [ci skip] --- docs/source/changelog.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 1bb19e70330..cd311309a30 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,16 @@ Changelog ========= +2.10.0 - 2020-01-28 +------------------- + +- Fixed ``ZeroDivisionError`` in dashboard when no workers were present (:pr:`3407`) `James Bourbeau`_ +- Respect the ``dashboard-prefix`` when redirecting from the root (:pr:`3387`) `Chrysostomos Nanakos`_ +- Allow enabling / disabling work-stealing after the cluster has started (:pr:`3410`) `John Kirkham`_ +- Support ``*args`` and ``**kwargs`` in offload (:pr:`3392`) `Matthew Rocklin`_ +- Add lifecycle hooks to SchedulerPlugin (:pr:`3391`) `Matthew Rocklin`_ + + 2.9.3 - 2020-01-17 ------------------ @@ -1501,3 +1511,4 @@ significantly without many new features. .. _`Benedikt Reinartz`: https://github.com/filmor .. _`Markus Mohrhard`: https://github.com/mmohrhard .. _`Mana Borwornpadungkitti`: https://github.com/potpath +.. _`Chrysostomos Nanakos`: https://github.com/cnanakos From 6083181cdcf4a8bc23697910d6d7f5712bfd8b47 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 29 Jan 2020 08:55:30 -0600 Subject: [PATCH 0648/1550] Add Mac OS build to CI (#3358) --- .travis.yml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 56c2588ff5f..c0c30316c9a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,4 @@ -language: python +language: generic # sudo shouldn't be required, but currently tests fail when run in a container # on travis instead of a vm. See https://github.com/dask/distributed/pull/1563. 
sudo: required @@ -13,8 +13,15 @@ matrix: fast_finish: true include: - os: linux + language: python python: 3.6 env: LINT=true + - os: osx + env: PYTHON=3.7 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 + if: type != pull_request OR commit_message =~ test-osx # Skip on PRs unless the commit message contains "test-osx" + + allow_failures: + - os: osx install: - if [[ $TESTS == true ]]; then source continuous_integration/travis/install.sh ; fi From fecf588c800f3230f08fe63d3f4f57360c3dd7cf Mon Sep 17 00:00:00 2001 From: Igor Gotlibovych Date: Wed, 29 Jan 2020 23:02:39 +0000 Subject: [PATCH 0649/1550] Support version checking with older versions of Dask (#3390) The test fails of one of the components' version info is None (we treat it as UNKNOWN). In addition, the function fails to report if the `client` is missing a package. --- distributed/tests/test_versions.py | 119 +++++++++++++++++++++++++++++ distributed/versions.py | 61 +++++++++------ 2 files changed, 155 insertions(+), 25 deletions(-) create mode 100644 distributed/tests/test_versions.py diff --git a/distributed/tests/test_versions.py b/distributed/tests/test_versions.py new file mode 100644 index 00000000000..25087df795a --- /dev/null +++ b/distributed/tests/test_versions.py @@ -0,0 +1,119 @@ +import re + +import pytest + +from distributed.versions import get_versions, error_message +from distributed import Client, Worker +from distributed.utils_test import gen_cluster + + +# if one of the nodes reports this version, there's a mismatch +mismatched_version = get_versions() +mismatched_version["packages"]["distributed"] = "0.0.0.dev0" + +# for really old versions, the `package` key is missing - version is UNKNOWN +key_err_version = {} + +# if no key is available for one package, we assume it's MISSING +missing_version = get_versions() +del missing_version["packages"]["distributed"] + +# if a node doesn't report any version info, we treat them as UNKNOWN +# the happens if the node is pre-32cb96e, i.e. 
<=2.9.1 +unknown_version = None + + +@pytest.fixture +def kwargs_matching(): + return dict( + scheduler=get_versions(), + workers={f"worker-{i}": get_versions() for i in range(3)}, + client=get_versions(), + ) + + +def test_versions_match(kwargs_matching): + assert error_message(**kwargs_matching) == "" + + +@pytest.fixture(params=["client", "scheduler", "worker-1"]) +def node(request): + """Node affected by version mismatch.""" + return request.param + + +@pytest.fixture(params=["MISMATCHED", "MISSING", "KEY_ERROR", "NONE"]) +def effect(request): + """Specify type of mismatch.""" + return request.param + + +@pytest.fixture +def kwargs_not_matching(kwargs_matching, node, effect): + affected_version = { + "MISMATCHED": mismatched_version, + "MISSING": missing_version, + "KEY_ERROR": key_err_version, + "NONE": unknown_version, + }[effect] + kwargs = kwargs_matching + if node in kwargs["workers"]: + kwargs["workers"][node] = affected_version + else: + kwargs[node] = affected_version + return kwargs + + +@pytest.fixture +def pattern(effect): + """Pattern to match in the right-hand column.""" + return { + "MISMATCHED": r"0\.0\.0\.dev0", + "MISSING": "MISSING", + "KEY_ERROR": "UNKNOWN", + "NONE": "UNKNOWN", + }[effect] + + +def test_version_mismatch(node, effect, kwargs_not_matching, pattern): + msg = error_message(**kwargs_not_matching) + + assert "Mismatched versions found" in msg + assert "distributed" in msg + assert re.search(node + r"\s+\|\s+" + pattern, msg) + + +def test_scheduler_mismatched_irrelevant_package(kwargs_matching): + """An irrelevant package on the scheduler can have any version.""" + kwargs_matching["scheduler"]["packages"]["numpy"] = "0.0.0" + assert "numpy" in kwargs_matching["client"]["packages"] + + assert error_message(**kwargs_matching) == "" + + +def test_scheduler_additional_irrelevant_package(kwargs_matching): + """An irrelevant package on the scheduler does not need to be present elsewhere.""" + kwargs_matching["scheduler"]["packages"]["pyspark"] = "0.0.0" + + assert error_message(**kwargs_matching) == "" + + +@gen_cluster() +async def test_version_warning_in_cluster(s, a, b): + s.workers[a.address].versions["packages"]["dask"] = "0.0.0" + + with pytest.warns(None) as record: + async with Client(s.address, asynchronous=True) as client: + pass + + assert record + assert any("dask" in str(r.message) for r in record) + assert any("0.0.0" in str(r.message) for r in record) + assert any(a.address in str(r.message) for r in record) + + async with Worker(s.address) as w: + assert any("This Worker" in line.message for line in w.logs) + assert any("dask" in line.message for line in w.logs) + assert any( + "0.0.0" in line.message and a.address in line.message for line in w.logs + ) diff --git a/distributed/versions.py b/distributed/versions.py index 0b97a6f7ac0..a7022c830f7 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -26,9 +26,15 @@ ] +# only these scheduler packages will be checked for version mismatch +scheduler_relevant_packages = set(pkg for pkg, _ in required_packages) | set( + ["lz4", "blosc"] +) + + def get_versions(packages=None): """ - Return basic information on our software installation, and out installed versions of packages. + Return basic information on our software installation, and our installed versions of packages. 
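# --- Illustrative sketch, not part of the diff above -----------------------
# The idea behind the reworked version check in miniature: gather a
# node -> {package: version} mapping, substitute UNKNOWN for nodes that report
# nothing (pre-2.9.x) and MISSING for packages a node lacks, then flag any
# package whose set of reported versions is not a single value.
# `find_mismatches` is a hypothetical helper, not the real error_message().
def find_mismatches(node_info):
    per_node = {}
    packages = set()
    for node, info in node_info.items():
        pkgs = (info or {}).get("packages")
        per_node[node] = pkgs if isinstance(pkgs, dict) else None
        if per_node[node]:
            packages.update(per_node[node])

    mismatched = {}
    for pkg in sorted(packages):
        seen = {
            node: ("UNKNOWN" if pkgs is None else pkgs.get(pkg, "MISSING"))
            for node, pkgs in per_node.items()
        }
        if len(set(seen.values())) > 1:
            mismatched[pkg] = seen
    return mismatched
# ---------------------------------------------------------------------------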
""" if packages is None: packages = [] @@ -98,34 +104,39 @@ def get_package_info(pkgs): def error_message(scheduler, workers, client, client_name="client"): - # we care about the required & optional packages matching - try: - client_versions = client["packages"] - versions = [("scheduler", scheduler["packages"])] - versions.extend((w, d["packages"]) for w, d in sorted(workers.items())) - except KeyError: - return ( - "Version mismatch for dask.distributed. " - "The scheduler has version >= 1.28.0 " - "but some other component is less than this" - ) + from .utils import asciitable - mismatched = defaultdict(list) - for name, vers in versions: - for pkg, cv in client_versions.items(): - v = vers.get(pkg, "MISSING") - if cv != v: - mismatched[pkg].append((name, v)) + nodes = {**{client_name: client}, **{"scheduler": scheduler}, **workers} - if mismatched: - from .utils import asciitable + # Hold all versions, e.g. versions["scheduler"]["distributed"] = 2.9.3 + node_packages = defaultdict(dict) - errs = [] - for pkg, versions in sorted(mismatched.items()): - rows = [(client_name, client_versions[pkg])] - rows.extend(versions) - errs.append("%s\n%s" % (pkg, asciitable(["", "version"], rows))) + # Collect all package versions + packages = set() + for node, info in nodes.items(): + if info is None or not (isinstance(info, dict)) or "packages" not in info: + node_packages[node] = defaultdict(lambda: "UNKNOWN") + else: + node_packages[node] = defaultdict(lambda: "MISSING") + for pkg, version in info["packages"].items(): + node_packages[node][pkg] = version + packages.add(pkg) + + errs = [] + for pkg in sorted(packages): + versions = set( + node_packages[node][pkg] + for node in nodes + if node != "scheduler" or pkg in scheduler_relevant_packages + ) + if len(versions) <= 1: + continue + rows = [ + (node_name, node_packages[node_name][pkg]) for node_name in nodes.keys() + ] + errs.append("%s\n%s" % (pkg, asciitable(["", "version"], rows))) + if errs: return "Mismatched versions found\n" "\n" "%s" % ("\n\n".join(errs)) else: return "" From 8eadf5e8fa56673aef9063fa91c52f4715638b6d Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Thu, 30 Jan 2020 01:09:32 +0100 Subject: [PATCH 0650/1550] Make _get_ip return an IP address when defaulting (#3418) * default to fully-qualified domain names in _get_ip * make _get_ip return an IP address (not a hostname) --- distributed/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index c2e32d849c4..efc948cff19 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -138,13 +138,15 @@ def _get_ip(host, port, family): ip = sock.getsockname()[0] return ip except EnvironmentError as e: - # XXX Should first try getaddrinfo() on socket.gethostname() and getfqdn() warnings.warn( "Couldn't detect a suitable IP address for " "reaching %r, defaulting to hostname: %s" % (host, e), RuntimeWarning, ) - return socket.gethostname() + addr_info = socket.getaddrinfo( + socket.gethostname(), port, family, socket.SOCK_DGRAM, socket.IPPROTO_UDP + )[0] + return addr_info[4][0] finally: sock.close() From 3e7bbbdb47c88b656e08ceab70e65a1877a7bee4 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 30 Jan 2020 17:28:22 +0100 Subject: [PATCH 0651/1550] Allow memory monitor to evict data more aggressively (#3424) --- distributed/worker.py | 59 +++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 
7062aba5b87..e429cb75a4b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2584,36 +2584,44 @@ async def memory_monitor(self): memory = proc.memory_info().rss frac = memory / self.memory_limit - # Pause worker threads if above 80% memory use - if self.memory_pause_fraction and frac > self.memory_pause_fraction: - # Try to free some memory while in paused state - self._throttled_gc.collect() - if not self.paused: + def check_pause(memory): + frac = memory / self.memory_limit + # Pause worker threads if above 80% memory use + if self.memory_pause_fraction and frac > self.memory_pause_fraction: + # Try to free some memory while in paused state + self._throttled_gc.collect() + if not self.paused: + logger.warning( + "Worker is at %d%% memory usage. Pausing worker. " + "Process memory: %s -- Worker memory limit: %s", + int(frac * 100), + format_bytes(memory), + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None", + ) + self.paused = True + elif self.paused: logger.warning( - "Worker is at %d%% memory usage. Pausing worker. " + "Worker is at %d%% memory usage. Resuming worker. " "Process memory: %s -- Worker memory limit: %s", int(frac * 100), - format_bytes(proc.memory_info().rss), + format_bytes(memory), format_bytes(self.memory_limit) if self.memory_limit is not None else "None", ) - self.paused = True - elif self.paused: - logger.warning( - "Worker is at %d%% memory usage. Resuming worker. " - "Process memory: %s -- Worker memory limit: %s", - int(frac * 100), - format_bytes(proc.memory_info().rss), - format_bytes(self.memory_limit) - if self.memory_limit is not None - else "None", - ) - self.paused = False - self.ensure_computing() + self.paused = False + self.ensure_computing() + check_pause(memory) # Dump data to disk if above 70% if self.memory_spill_fraction and frac > self.memory_spill_fraction: + logger.debug( + "Worker is at %d%% memory usage. Start spilling data to disk.", + int(frac * 100), + ) + start = time() target = self.memory_limit * self.memory_target_fraction count = 0 need = memory - target @@ -2624,7 +2632,7 @@ async def memory_monitor(self): "to store to disk. Perhaps some other process " "is leaking memory? Process memory: %s -- " "Worker memory limit: %s", - format_bytes(proc.memory_info().rss), + format_bytes(memory), format_bytes(self.memory_limit) if self.memory_limit is not None else "None", @@ -2634,7 +2642,13 @@ async def memory_monitor(self): del k, v total += weight count += 1 - await asyncio.sleep(0) + # If the current buffer is filled with a lot of small values, + # evicting one at a time is very slow and the worker might + # generate new data faster than it is able to evict. Therefore, + # only pass on control if we spent at least 0.5s evicting + if time() - start > 0.5: + await asyncio.sleep(0) + start = time() memory = proc.memory_info().rss if total > need and memory > target: # Issue a GC to ensure that the evicted data is actually @@ -2642,6 +2656,7 @@ async def memory_monitor(self): # before trying to evict even more data. 
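# --- Illustrative sketch, not part of the patch above -----------------------
# The change to memory_monitor boils down to time-based throttling of the
# eviction loop: keep evicting synchronously and only yield to the event loop
# after roughly 0.5 s of work, so a buffer holding many small values cannot
# force a yield per evicted key and let new data pile up faster than it is
# spilled. `evict_one` and `spilled_enough` are hypothetical stand-ins for the
# worker's real spill-to-disk buffer logic.
import asyncio
import time


async def spill_until_under_target(evict_one, spilled_enough):
    last_yield = time.monotonic()
    while not spilled_enough():
        evict_one()  # synchronously move one key from memory to disk
        if time.monotonic() - last_yield > 0.5:
            await asyncio.sleep(0)  # let heartbeats and comms run
            last_yield = time.monotonic()
# -----------------------------------------------------------------------------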
self._throttled_gc.collect() memory = proc.memory_info().rss + check_pause(memory) if count: logger.debug( "Moved %d pieces of data data and %s to disk", From a298fdaceca133b414b8182cb5e2e331fab06585 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 30 Jan 2020 17:36:58 +0000 Subject: [PATCH 0652/1550] Add dashboard_link property to Client (#3429) --- distributed/client.py | 74 ++++++++++++++++++-------------- distributed/tests/test_client.py | 3 +- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 7a39ec4b235..5bde92a4bb1 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -752,6 +752,22 @@ def asynchronous(self): """ return self._asynchronous and self.loop is IOLoop.current() + @property + def dashboard_link(self): + scheduler, info = self._get_scheduler_info() + try: + return self.cluster.dashboard_link + except AttributeError: + protocol, rest = scheduler.address.split("://") + + port = info["services"]["dashboard"] + if protocol == "inproc": + host = "localhost" + else: + host = rest.split(":")[0] + + return format_dashboard_link(host, port) + def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): if ( asynchronous @@ -767,6 +783,29 @@ def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): self.loop, func, *args, callback_timeout=callback_timeout, **kwargs ) + def _get_scheduler_info(self): + from .scheduler import Scheduler + + if ( + self.cluster + and hasattr(self.cluster, "scheduler") + and isinstance(self.cluster.scheduler, Scheduler) + ): + info = self.cluster.scheduler.identity() + scheduler = self.cluster.scheduler + elif ( + self._loop_runner.is_started() + and self.scheduler + and not (self.asynchronous and self.loop is IOLoop.current()) + ): + info = sync(self.loop, self.scheduler.identity) + scheduler = self.scheduler + else: + info = self._scheduler_identity + scheduler = self.scheduler + + return scheduler, info + def __repr__(self): # Note: avoid doing I/O here... info = self._scheduler_identity @@ -796,25 +835,7 @@ def __repr__(self): return "<%s: not connected>" % (self.__class__.__name__,) def _repr_html_(self): - from .scheduler import Scheduler - - if ( - self.cluster - and hasattr(self.cluster, "scheduler") - and isinstance(self.cluster.scheduler, Scheduler) - ): - info = self.cluster.scheduler.identity() - scheduler = self.cluster.scheduler - elif ( - self._loop_runner.is_started() - and self.scheduler - and not (self.asynchronous and self.loop is IOLoop.current()) - ): - info = sync(self.loop, self.scheduler.identity) - scheduler = self.scheduler - else: - info = self._scheduler_identity - scheduler = self.scheduler + scheduler, info = self._get_scheduler_info() text = ( '

<h3 style="text-align: left;">Client</h3>\n' @@ -826,22 +847,9 @@ def _repr_html_(self): text += "  <li><b>Scheduler: not connected</b></li>\n" if info and "dashboard" in info["services"]: - try: - address = self.cluster.dashboard_link - except AttributeError: - protocol, rest = scheduler.address.split("://") - - port = info["services"]["dashboard"] - if protocol == "inproc": - host = "localhost" - else: - host = rest.split(":")[0] - - address = format_dashboard_link(host, port) - text += ( "  <li><b>Dashboard: </b><a href='%(web)s' target='_blank'>%(web)s</a></li>\n" - % {"web": address} + % {"web": self.dashboard_link} ) text += "
        \n" diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 77180eefa35..a205452b2f8 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5249,8 +5249,9 @@ def test_dashboard_link(loop, monkeypatch): with dask.config.set( {"distributed.dashboard.link": "{scheme}://foo-{USER}:{port}/status"} ): - text = c._repr_html_() link = "http://foo-myusername:12355/status" + assert link == c.dashboard_link + text = c._repr_html_() assert link in text From 12a4f2d38f131ddb63849d1be0805f414fa6d816 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Fri, 31 Jan 2020 12:48:36 -0700 Subject: [PATCH 0653/1550] Remove object from class hierarchy (#3432) --- distributed/_concurrent_futures_thread.py | 2 +- distributed/actor.py | 4 ++-- distributed/batched.py | 2 +- distributed/client.py | 6 +++--- distributed/comm/inproc.py | 4 ++-- distributed/comm/tcp.py | 2 +- distributed/comm/tests/test_comms.py | 2 +- distributed/core.py | 8 ++++---- distributed/counter.py | 4 ++-- distributed/dashboard/components/__init__.py | 2 +- distributed/dashboard/core.py | 2 +- distributed/dashboard/scheduler.py | 2 +- distributed/dashboard/worker.py | 2 +- distributed/deploy/cluster.py | 2 +- distributed/deploy/old_ssh.py | 2 +- distributed/deploy/spec.py | 2 +- .../deploy/tests/test_slow_adaptive.py | 2 +- distributed/deploy/utils_test.py | 2 +- distributed/diagnostics/plugin.py | 4 ++-- distributed/diagnostics/progressbar.py | 4 ++-- distributed/diskutils.py | 4 ++-- distributed/lock.py | 4 ++-- distributed/locket.py | 8 ++++---- distributed/metrics.py | 2 +- distributed/nanny.py | 2 +- distributed/node.py | 2 +- distributed/process.py | 4 ++-- distributed/protocol/cupy.py | 2 +- distributed/protocol/serialize.py | 6 +++--- distributed/protocol/tests/test_serialize.py | 10 +++++----- distributed/publish.py | 2 +- distributed/pubsub.py | 10 +++++----- distributed/pytest_resourceleaks.py | 4 ++-- distributed/queues.py | 4 ++-- distributed/recreate_exceptions.py | 4 ++-- distributed/scheduler.py | 10 +++++----- distributed/security.py | 2 +- distributed/system_monitor.py | 2 +- distributed/tests/test_actor.py | 20 +++++++++---------- distributed/tests/test_batched.py | 2 +- distributed/tests/test_client.py | 14 ++++++------- distributed/tests/test_core.py | 2 +- distributed/tests/test_steal.py | 2 +- distributed/tests/test_utils_perf.py | 2 +- distributed/tests/test_worker.py | 10 +++++----- distributed/utils.py | 4 ++-- distributed/utils_comm.py | 2 +- distributed/utils_perf.py | 6 +++--- distributed/utils_test.py | 2 +- distributed/variable.py | 4 ++-- distributed/worker.py | 2 +- docs/source/actors.rst | 2 +- docs/source/adaptive.rst | 4 ++-- docs/source/serialization.rst | 2 +- 54 files changed, 111 insertions(+), 111 deletions(-) diff --git a/distributed/_concurrent_futures_thread.py b/distributed/_concurrent_futures_thread.py index 02ff7c649aa..b26da12cb7a 100644 --- a/distributed/_concurrent_futures_thread.py +++ b/distributed/_concurrent_futures_thread.py @@ -50,7 +50,7 @@ def _python_exit(): atexit.register(_python_exit) -class _WorkItem(object): +class _WorkItem: def __init__(self, future, fn, args, kwargs): self.future = future self.fn = fn diff --git a/distributed/actor.py b/distributed/actor.py index e7e4afaacf0..37f43b69358 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -178,7 +178,7 @@ def client(self): return self._future.client -class ProxyRPC(object): +class ProxyRPC: """ An rpc-like object that uses the scheduler's 
rpc to connect to a worker """ @@ -196,7 +196,7 @@ async def func(**msg): return func -class ActorFuture(object): +class ActorFuture: """ Future to an actor's method call Whenever you call a method on an Actor you get an ActorFuture immediately diff --git a/distributed/batched.py b/distributed/batched.py index e066fcf7588..13c241d1e1b 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class BatchedSend(object): +class BatchedSend: """ Batch messages in batches on a stream This takes an IOStream and an interval (in ms) and ensures that we send no diff --git a/distributed/client.py b/distributed/client.py index 5bde92a4bb1..706101d2fe4 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -412,7 +412,7 @@ def __await__(self): return self.result().__await__() -class FutureState(object): +class FutureState: """A Future's internal state. This is shared between all Futures with the same key and client. @@ -4142,7 +4142,7 @@ async def _first_completed(futures): return result -class as_completed(object): +class as_completed: """ Return futures in the order in which they complete @@ -4480,7 +4480,7 @@ def fire_and_forget(obj): ) -class get_task_stream(object): +class get_task_stream: """ Collect task stream within a context block diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index e46c2804ed1..c0191f024f6 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -25,7 +25,7 @@ ) -class Manager(object): +class Manager: """ An object coordinating listeners and their addresses. """ @@ -87,7 +87,7 @@ class QueueEmpty(Exception): pass -class Queue(object): +class Queue: """ A single-reader, single-writer, non-threadsafe, peekable queue. """ diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index ce6b7fa0b44..7003053ce06 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -324,7 +324,7 @@ def _expect_tls_context(connection_args): return ctx -class RequireEncryptionMixin(object): +class RequireEncryptionMixin: def _check_encryption(self, address, connection_args): if not self.encrypted and connection_args.get("require_encryption"): # XXX Should we have a dedicated SecurityError class? 
diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 470c667b989..b486912f281 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -1070,7 +1070,7 @@ def _raise_eoferror(): raise EOFError -class _EOFRaising(object): +class _EOFRaising: def __reduce__(self): return _raise_eoferror, () diff --git a/distributed/core.py b/distributed/core.py index 81cd7adf8e4..3dad1223030 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -60,7 +60,7 @@ def _raise(*args, **kwargs): LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -class Server(object): +class Server: """ Dask Distributed Server Superclass for endpoints in a distributed cluster, such as Worker @@ -569,7 +569,7 @@ def addr_from_args(addr=None, ip=None, port=None): return normalize_address(addr) -class rpc(object): +class rpc: """ Conveniently interact with a remote server >>> remote = rpc(address) # doctest: +SKIP @@ -728,7 +728,7 @@ def __repr__(self): return "" % (self.address, len(self.comms)) -class PooledRPCCall(object): +class PooledRPCCall: """ The result of ConnectionPool()('host:port') See Also: @@ -777,7 +777,7 @@ def __repr__(self): return "" % (self.addr,) -class ConnectionPool(object): +class ConnectionPool: """ A maximum sized pool of Comm objects. This provides a connect method that mirrors the normal distributed.connect diff --git a/distributed/counter.py b/distributed/counter.py index f41961e87ac..ebc8cda6104 100644 --- a/distributed/counter.py +++ b/distributed/counter.py @@ -11,7 +11,7 @@ pass else: - class Digest(object): + class Digest: def __init__(self, loop=None, intervals=(5, 60, 3600)): self.intervals = intervals self.components = [TDigest() for i in self.intervals] @@ -39,7 +39,7 @@ def size(self): return sum(d.size() for d in self.components) -class Counter(object): +class Counter: def __init__(self, loop=None, intervals=(5, 60, 3600)): self.intervals = intervals self.components = [defaultdict(lambda: 0) for i in self.intervals] diff --git a/distributed/dashboard/components/__init__.py b/distributed/dashboard/components/__init__.py index a66be2eced6..bb8269083e9 100644 --- a/distributed/dashboard/components/__init__.py +++ b/distributed/dashboard/components/__init__.py @@ -42,7 +42,7 @@ profile_interval = parse_timedelta(profile_interval, default="ms") -class DashboardComponent(object): +class DashboardComponent: """ Base class for Dask.distributed UI dashboard components. 
This class must have two attributes, ``root`` and ``source``, and one diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index 82cfe92da17..9b919917a67 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -16,7 +16,7 @@ raise ImportError("Dask needs bokeh >= 0.13.0") -class BokehServer(object): +class BokehServer: server_kwargs = {} def listen(self, addr): diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 2c0520161b3..a030ba434f7 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -251,7 +251,7 @@ def get(self): self.write(result) -class _PrometheusCollector(object): +class _PrometheusCollector: def __init__(self, server): self.server = server diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index 5a34a261bf1..db29480666b 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -35,7 +35,7 @@ } -class _PrometheusCollector(object): +class _PrometheusCollector: def __init__(self, server): self.worker = server self.logger = logging.getLogger("distributed.dask_worker") diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 1b304b0a53e..ad071a214be 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class Cluster(object): +class Cluster: """ Superclass for cluster objects This class contains common functionality for Dask Cluster manager classes. diff --git a/distributed/deploy/old_ssh.py b/distributed/deploy/old_ssh.py index 30f6f819224..86d49c9cf15 100644 --- a/distributed/deploy/old_ssh.py +++ b/distributed/deploy/old_ssh.py @@ -335,7 +335,7 @@ def start_worker( return merge(cmd_dict, {"thread": thread}) -class SSHCluster(object): +class SSHCluster: def __init__( self, scheduler_addr, diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 96279d15323..537fa3201f4 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -92,7 +92,7 @@ async def __aexit__(self, *args, **kwargs): await self.close() -class NoOpAwaitable(object): +class NoOpAwaitable: """An awaitable object that always returns None. 
Useful to return from a method that can be called in both asynchronous and diff --git a/distributed/deploy/tests/test_slow_adaptive.py b/distributed/deploy/tests/test_slow_adaptive.py index 09113fe3b23..e7021fc854a 100644 --- a/distributed/deploy/tests/test_slow_adaptive.py +++ b/distributed/deploy/tests/test_slow_adaptive.py @@ -6,7 +6,7 @@ from distributed.metrics import time -class SlowWorker(object): +class SlowWorker: def __init__(self, *args, delay=0, **kwargs): self.worker = Worker(*args, **kwargs) self.delay = delay diff --git a/distributed/deploy/utils_test.py b/distributed/deploy/utils_test.py index 2bb55c7da08..fd6ba03aae9 100644 --- a/distributed/deploy/utils_test.py +++ b/distributed/deploy/utils_test.py @@ -3,7 +3,7 @@ import pytest -class ClusterTest(object): +class ClusterTest: Cluster = None kwargs = {} diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 1d218fe5ac8..12e7ad6ec3f 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -3,7 +3,7 @@ logger = logging.getLogger(__name__) -class SchedulerPlugin(object): +class SchedulerPlugin: """ Interface to extend the Scheduler The scheduler operates by triggering and responding to events like @@ -86,7 +86,7 @@ def remove_client(self, scheduler=None, client=None, **kwargs): """ Run when a client disconnects """ -class WorkerPlugin(object): +class WorkerPlugin: """ Interface to extend the Worker A worker plugin enables custom code to run at different stages of the Workers' diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 01dc9bbea39..ab7800c2125 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -24,7 +24,7 @@ def get_scheduler(scheduler): return coerce_to_address(scheduler) -class ProgressBar(object): +class ProgressBar: def __init__(self, keys, scheduler=None, interval="100ms", complete=True): self.scheduler = get_scheduler(scheduler) @@ -207,7 +207,7 @@ def _draw_bar(self, remaining, all, **kwargs): ) -class MultiProgressBar(object): +class MultiProgressBar: def __init__( self, keys, diff --git a/distributed/diskutils.py b/distributed/diskutils.py index 075ec7750c8..64124b753a8 100644 --- a/distributed/diskutils.py +++ b/distributed/diskutils.py @@ -30,7 +30,7 @@ def safe_unlink(path): logger.error("Failed to remove %r", str(e)) -class WorkDir(object): +class WorkDir: """ A temporary work directory inside a WorkSpace. """ @@ -102,7 +102,7 @@ def _finalize(cls, workspace, lock_path, lock_file, dir_path): safe_unlink(lock_path) -class WorkSpace(object): +class WorkSpace: """ An on-disk workspace that tracks disposable work directories inside it. 
If a process crashes or another event left stale directories behind, diff --git a/distributed/lock.py b/distributed/lock.py index c230a8e861c..48d538915f0 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class LockExtension(object): +class LockExtension: """ An extension for the scheduler to manage Locks This adds the following routes to the scheduler @@ -73,7 +73,7 @@ def release(self, stream=None, name=None, id=None): del self.events[name] -class Lock(object): +class Lock: """ Distributed Centralized Lock Parameters diff --git a/distributed/locket.py b/distributed/locket.py index 1ed7b023085..65a10f195f7 100644 --- a/distributed/locket.py +++ b/distributed/locket.py @@ -114,7 +114,7 @@ def _acquire_non_blocking(acquire, timeout, retry_period, path): time.sleep(retry_period) -class _LockSet(object): +class _LockSet: def __init__(self, locks): self._locks = locks @@ -136,7 +136,7 @@ def release(self): lock.release() -class _ThreadLock(object): +class _ThreadLock: def __init__(self, path): self._path = path self._lock = threading.Lock() @@ -156,7 +156,7 @@ def release(self): self._lock.release() -class _LockFile(object): +class _LockFile: def __init__(self, path): self._path = path self._file = None @@ -181,7 +181,7 @@ def release(self): self._file = None -class _Locker(object): +class _Locker: """ A lock wrapper to always apply the given *timeout* and *retry_period* to acquire() calls. diff --git a/distributed/metrics.py b/distributed/metrics.py index fefdfeb2e4c..f28e9f2ac7f 100755 --- a/distributed/metrics.py +++ b/distributed/metrics.py @@ -36,7 +36,7 @@ def wrapper(): net_io_counters = _psutil_caller("net_io_counters") -class _WindowsTime(object): +class _WindowsTime: """ Combine time.time() and time.perf_counter() to get an absolute clock with fine resolution. diff --git a/distributed/nanny.py b/distributed/nanny.py index 945f33041d3..9c95dd4a07a 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -450,7 +450,7 @@ async def close(self, comm=None, timeout=5, report=None): await ServerNode.close(self) -class WorkerProcess(object): +class WorkerProcess: def __init__( self, worker_kwargs, diff --git a/distributed/node.py b/distributed/node.py index edee3e2dd7b..4e26defeb08 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -12,7 +12,7 @@ from .utils import DequeHandler, TimeoutError -class Node(object): +class Node: """ Base class for nodes in a distributed cluster. """ diff --git a/distributed/process.py b/distributed/process.py index 5899c853385..b070342b340 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -40,13 +40,13 @@ def _call_and_set_future(loop, future, func, *args, **kwargs): _loop_add_callback(loop, future.set_result, res) -class _ProcessState(object): +class _ProcessState: is_alive = False pid = None exitcode = None -class AsyncProcess(object): +class AsyncProcess: """ A coroutine-compatible multiprocessing.Process-alike. All normally blocking methods are wrapped in Tornado coroutines. diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 26a5accc6af..087de6f9663 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -5,7 +5,7 @@ from .cuda import cuda_serialize, cuda_deserialize -class PatchedCudaArrayInterface(object): +class PatchedCudaArrayInterface: """This class do two things: 1) Makes sure that __cuda_array_interface__['strides'] behaves as specified in the protocol. 
diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 8d1d37a283e..ddab6130765 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -268,7 +268,7 @@ def deserialize(header, frames, deserializers=None): return loads(header, frames) -class Serialize(object): +class Serialize: """ Mark an object that should be serialized Example @@ -301,7 +301,7 @@ def __hash__(self): to_serialize = Serialize -class Serialized(object): +class Serialized: """ An object that is already serialized into header and frames @@ -484,7 +484,7 @@ def register_serialization(cls, serialize, deserialize): Examples -------- - >>> class Human(object): + >>> class Human: ... def __init__(self, name): ... self.name = name diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 6ba70f676a1..b5a202f1520 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -26,7 +26,7 @@ from distributed.comm.utils import to_frames, from_frames -class MyObj(object): +class MyObj: def __init__(self, data): self.data = data @@ -151,7 +151,7 @@ def test_inter_worker_comms(c, s, a, b): assert o2.data == 123 -class Empty(object): +class Empty: def __getstate__(self): raise Exception("Not picklable") @@ -213,7 +213,7 @@ class BadException(Exception): def __setstate__(self): return Exception("Sneaky deserialization code") - class MyClass(object): + class MyClass: def __getstate__(self): raise BadException() @@ -258,7 +258,7 @@ def test_err_on_bad_deserializer(): yield from_frames(frames, deserializers=["msgpack"]) -class MyObject(object): +class MyObject: def __init__(self, **kwargs): self.__dict__.update(kwargs) @@ -348,7 +348,7 @@ def check(dask_worker): def test_serialize_raises(): - class Foo(object): + class Foo: pass @dask_serialize.register(Foo) diff --git a/distributed/publish.py b/distributed/publish.py index c899b9fbaaa..758e5ccc34b 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -3,7 +3,7 @@ from .utils import log_errors, tokey -class PublishExtension(object): +class PublishExtension: """ An extension for the scheduler to manage collections * publish-list diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 9de133ddb47..3c8b140b362 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class PubSubSchedulerExtension(object): +class PubSubSchedulerExtension: """ Extend Dask's scheduler with routes to handle PubSub machinery """ def __init__(self, scheduler): @@ -117,7 +117,7 @@ def handle_message(self, name=None, msg=None, worker=None, client=None): ) -class PubSubWorkerExtension(object): +class PubSubWorkerExtension: """ Extend Dask's Worker with routes to handle PubSub machinery """ def __init__(self, worker): @@ -170,7 +170,7 @@ def cleanup(self): del self.publish_to_scheduler[name] -class PubSubClientExtension(object): +class PubSubClientExtension: """ Extend Dask's Client with handlers to handle PubSub machinery """ def __init__(self, client): @@ -199,7 +199,7 @@ def cleanup(self): self.client.scheduler_comm.send(msg) -class Pub(object): +class Pub: """ Publish data with Publish-Subscribe pattern This allows clients and workers to directly communicate data between each @@ -349,7 +349,7 @@ def __repr__(self): __str__ = __repr__ -class Sub(object): +class Sub: """ Subscribe to a Publish/Subscribe topic See Also diff --git a/distributed/pytest_resourceleaks.py 
b/distributed/pytest_resourceleaks.py index 0119a425722..348472892d6 100644 --- a/distributed/pytest_resourceleaks.py +++ b/distributed/pytest_resourceleaks.py @@ -92,7 +92,7 @@ def decorate(cls): return decorate -class ResourceChecker(object): +class ResourceChecker: def on_start_test(self): pass @@ -260,7 +260,7 @@ def format(self, before, after): return "\n".join(lines) -class LeakChecker(object): +class LeakChecker: def __init__(self, checkers, grace_delay, mark_failed, max_retries): self.checkers = checkers self.grace_delay = grace_delay diff --git a/distributed/queues.py b/distributed/queues.py index 6d1fc76571b..9f5db0af68e 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class QueueExtension(object): +class QueueExtension: """ An extension for the scheduler to manage queues This adds the following routes to the scheduler @@ -139,7 +139,7 @@ def qsize(self, stream=None, name=None, client=None): return self.queues[name].qsize() -class Queue(object): +class Queue: """ Distributed Queue This allows multiple clients to share futures or small bits of data between diff --git a/distributed/recreate_exceptions.py b/distributed/recreate_exceptions.py index 9138c1fca5a..4aaa851ee23 100644 --- a/distributed/recreate_exceptions.py +++ b/distributed/recreate_exceptions.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) -class ReplayExceptionScheduler(object): +class ReplayExceptionScheduler: """ A plugin for the scheduler to recreate exceptions locally This adds the following routes to the scheduler @@ -50,7 +50,7 @@ def cause_of_failure(self, *args, keys=(), **kwargs): } -class ReplayExceptionClient(object): +class ReplayExceptionClient: """ A plugin for the client allowing replay of remote exceptions locally diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ab2e615cef2..89c938f0dd5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -92,7 +92,7 @@ ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"} -class ClientState(object): +class ClientState: """ A simple object holding information about a client. @@ -128,7 +128,7 @@ def __str__(self): return self.client_key -class WorkerState(object): +class WorkerState: """ A simple object holding information about a worker. @@ -324,7 +324,7 @@ def ncores(self): return self.nthreads -class TaskState(object): +class TaskState: """ A simple object holding information about a task. @@ -683,7 +683,7 @@ def validate(self): pdb.set_trace() -class TaskGroup(object): +class TaskGroup: """ Collection tracking all tasks within a group Keys often have a structure like ``("x-123", 0)`` @@ -754,7 +754,7 @@ def __len__(self): return sum(self.states.values()) -class TaskPrefix(object): +class TaskPrefix: """ Collection tracking all tasks within a group Keys often have a structure like ``("x-123", 0)`` diff --git a/distributed/security.py b/distributed/security.py index 6b7d87b2715..f3430ac7b3e 100644 --- a/distributed/security.py +++ b/distributed/security.py @@ -13,7 +13,7 @@ __all__ = ("Security",) -class Security(object): +class Security: """Security configuration for a Dask cluster. 
Default values are loaded from Dask's configuration files, and can be diff --git a/distributed/system_monitor.py b/distributed/system_monitor.py index 5b3bed3f98d..cf305869a8c 100644 --- a/distributed/system_monitor.py +++ b/distributed/system_monitor.py @@ -5,7 +5,7 @@ from .metrics import time -class SystemMonitor(object): +class SystemMonitor: def __init__(self, n=10000): self.proc = psutil.Process() diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index fd6bf0335e1..de69db5685a 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -11,7 +11,7 @@ from distributed.metrics import time -class Counter(object): +class Counter: n = 0 def __init__(self): @@ -26,7 +26,7 @@ def add(self, x): return self.n -class List(object): +class List: L = [] def __init__(self, dummy=None): @@ -36,7 +36,7 @@ def append(self, x): self.L.append(x) -class ParameterServer(object): +class ParameterServer: def __init__(self): self.data = {} @@ -156,7 +156,7 @@ def test_linear_access(c, s, a, b): @gen_cluster(client=True) def test_exceptions_create(c, s, a, b): - class Foo(object): + class Foo: x = 0 def __init__(self): @@ -170,7 +170,7 @@ def __init__(self): @gen_cluster(client=True) def test_exceptions_method(c, s, a, b): - class Foo(object): + class Foo: def throw(self): 1 / 0 @@ -349,7 +349,7 @@ def add(n, counter): @gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2) def test_thread_safety(c, s, a, b): - class Unsafe(object): + class Unsafe: def __init__(self): self.n = 0 @@ -378,7 +378,7 @@ def test_Actors_create_dependencies(c, s, a, b): @gen_cluster(client=True) def test_load_balance(c, s, a, b): - class Foo(object): + class Foo: def __init__(self, x): pass @@ -396,7 +396,7 @@ def __init__(self, x): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 5) def test_load_balance_map(c, s, *workers): - class Foo(object): + class Foo: def __init__(self, x, y=None): pass @@ -510,7 +510,7 @@ def check(dask_worker): config={"distributed.worker.profile.interval": "1ms"}, ) def test_actors_in_profile(c, s, a): - class Sleeper(object): + class Sleeper: def sleep(self, time): sleep(time) @@ -530,7 +530,7 @@ def sleep(self, time): def test_waiter(c, s, a, b): from tornado.locks import Event - class Waiter(object): + class Waiter: def __init__(self): self.event = Event() diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index 07dd32f4c68..74efba810d3 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -12,7 +12,7 @@ from distributed.protocol import to_serialize -class EchoServer(object): +class EchoServer: count = 0 async def handle_comm(self, comm): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index a205452b2f8..16e660492f8 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1130,7 +1130,7 @@ def test_scatter_hash(c, s, a, b): def test_scatter_tokenize_local(c, s, a, b): from dask.base import normalize_token - class MyObj(object): + class MyObj: pass L = [] @@ -1846,7 +1846,7 @@ def f(x, y=10): assert result == 100 + 1 + 200 -class BadlySerializedObject(object): +class BadlySerializedObject: def __getstate__(self): return 1 @@ -1854,7 +1854,7 @@ def __setstate__(self, state): raise TypeError("hello!") -class FatallySerializedObject(object): +class FatallySerializedObject: def __getstate__(self): return 1 @@ -3018,7 +3018,7 @@ def test_replicate_workers(c, s, *workers): s.validate_state() -class 
CountSerialization(object): +class CountSerialization: def __init__(self): self.n = 0 @@ -4472,7 +4472,7 @@ class MyException(Exception): @gen_cluster(client=True) def test_robust_unserializable(c, s, a, b): - class Foo(object): + class Foo: def __getstate__(self): raise MyException() @@ -4488,7 +4488,7 @@ def __getstate__(self): @gen_cluster(client=True) def test_robust_undeserializable(c, s, a, b): - class Foo(object): + class Foo: def __getstate__(self): return 1 @@ -4508,7 +4508,7 @@ def __setstate__(self, state): @gen_cluster(client=True) def test_robust_undeserializable_function(c, s, a, b): - class Foo(object): + class Foo: def __getstate__(self): return 1 diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index d423c6ab6c3..0a9c48bc870 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -49,7 +49,7 @@ def echo(comm, x): return x -class CountedObject(object): +class CountedObject: """ A class which counts the number of live instances. """ diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index de63e542807..9c4fef57d2a 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -623,7 +623,7 @@ def long(delay): @gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2) def test_cleanup_repeated_tasks(c, s, a, b): - class Foo(object): + class Foo: pass s.extensions["stealing"]._pc.callback_time = 20 diff --git a/distributed/tests/test_utils_perf.py b/distributed/tests/test_utils_perf.py index 4256548900c..a1591df0280 100644 --- a/distributed/tests/test_utils_perf.py +++ b/distributed/tests/test_utils_perf.py @@ -11,7 +11,7 @@ from distributed.utils_test import captured_logger, run_for -class RandomTimer(object): +class RandomTimer: """ A mock timer producing random (but monotonic) values. """ diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 11ab461ae25..6c1a0805817 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -86,7 +86,7 @@ def test_identity(): @gen_cluster(client=True) def test_worker_bad_args(c, s, a, b): - class NoReprObj(object): + class NoReprObj: """ This object cannot be properly represented as a string. """ def __str__(self): @@ -833,7 +833,7 @@ def test_worker_dir(c, s, a, b): @gen_cluster(client=True) def test_dataframe_attribute_error(c, s, a, b): - class BadSize(object): + class BadSize: def __init__(self, data): self.data = data @@ -847,7 +847,7 @@ def __sizeof__(self): @gen_cluster(client=True) def test_fail_write_to_disk(c, s, a, b): - class Bad(object): + class Bad: def __getstate__(self): raise TypeError() @@ -876,7 +876,7 @@ def test_fail_write_many_to_disk(c, s, a): yield gen.sleep(0.1) assert not a.paused - class Bad(object): + class Bad: def __init__(self, x): pass @@ -1094,7 +1094,7 @@ def test_robust_to_bad_sizeof_estimates(c, s, a): memory = psutil.Process().memory_info().rss a.memory_limit = memory / 0.7 + 400e6 - class BadAccounting(object): + class BadAccounting: def __init__(self, data): self.data = data diff --git a/distributed/utils.py b/distributed/utils.py index efc948cff19..086555643ea 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -349,7 +349,7 @@ def f(): return result[0] -class LoopRunner(object): +class LoopRunner: """ A helper to start and stop an IO loop in a controlled way. Several loop runners can associate safely to the same IO loop. 
@@ -1061,7 +1061,7 @@ def import_file(path): return loaded -class itemgetter(object): +class itemgetter: """A picklable itemgetter. Examples diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 792e73227a9..3d10ba51038 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -96,7 +96,7 @@ async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=No return (results, bad_keys, list(missing_workers)) -class WrappedKey(object): +class WrappedKey: """ Interface for a key in a dask graph. Subclasses must have .key attribute that refers to a key in a dask graph. diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index f21e96d7353..3b97dd46327 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -12,7 +12,7 @@ logger = _logger = logging.getLogger(__name__) -class ThrottledGC(object): +class ThrottledGC: """Wrap gc.collect to protect against excessively repeated calls. Allows to run throttled garbage collection in the workers as a @@ -67,7 +67,7 @@ def collect(self): ) -class FractionalTimer(object): +class FractionalTimer: """ An object that measures runtimes, accumulates them and computes a running fraction of the recent runtimes over the corresponding @@ -128,7 +128,7 @@ def running_fraction(self): return self._running_fraction -class GCDiagnosis(object): +class GCDiagnosis: """ An object that hooks itself into the gc callbacks to collect timing and memory statistics, and log interesting info. diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 1650e0426ef..e16983b1879 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -343,7 +343,7 @@ def run_for(duration, timer=time): _varying_key_gen = itertools.count() -class _ModuleSlot(object): +class _ModuleSlot: def __init__(self, modname, slotname): self.modname = modname self.slotname = slotname diff --git a/distributed/variable.py b/distributed/variable.py index 677e2997b32..bfd6ca250e2 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -class VariableExtension(object): +class VariableExtension: """ An extension for the scheduler to manage queues This adds the following routes to the scheduler @@ -114,7 +114,7 @@ async def delete(self, stream=None, name=None, client=None): del self.variables[name] -class Variable(object): +class Variable: """ Distributed Global Variable This allows multiple clients to share futures and data between each other diff --git a/distributed/worker.py b/distributed/worker.py index e429cb75a4b..77711058682 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3417,7 +3417,7 @@ def get_msg_safe_str(msg): ignoring them. """ - class Repr(object): + class Repr: def __init__(self, f, val): self._f = f self._val = val diff --git a/docs/source/actors.rst b/docs/source/actors.rst index b8bbebc743a..85d1300e502 100644 --- a/docs/source/actors.rst +++ b/docs/source/actors.rst @@ -197,7 +197,7 @@ will run on the Worker's event loop thread rather than a separate thread. .. code-block:: python - def Waiter(object): + def Waiter: def __init__(self): self.event = tornado.locks.Event() diff --git a/docs/source/adaptive.rst b/docs/source/adaptive.rst index 774ae21e4a6..f07246588cd 100644 --- a/docs/source/adaptive.rst +++ b/docs/source/adaptive.rst @@ -59,7 +59,7 @@ the correct times. .. 
code-block:: python - class MyCluster(object): + class MyCluster: async def scale_up(self, n, **kwargs): """ Bring the total count of workers up to ``n`` @@ -110,7 +110,7 @@ We reproduce the full body of the implementation below as an example: from marathon import MarathonClient, MarathonApp from marathon.models.container import MarathonContainer - class MarathonCluster(object): + class MarathonCluster: def __init__(self, scheduler, executable='dask-worker', docker_image='mrocklin/dask-distributed', diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 25b2ae49476..ec315cc321e 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -154,7 +154,7 @@ register them with Dask. .. code-block:: python - class Human(object): + class Human: def __init__(self, name): self.name = name From 7c9da106b2edd543d3b420ad8270261c6a04ec5e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 3 Feb 2020 09:09:37 -0800 Subject: [PATCH 0654/1550] Replace tornado.locks with asyncio for Events/Locks/Conditions/Semaphore (#3397) --- distributed/client.py | 66 ++++++++++++++-------------- distributed/comm/inproc.py | 3 +- distributed/comm/tests/test_comms.py | 4 +- distributed/core.py | 3 +- distributed/deploy/spec.py | 3 +- distributed/lock.py | 6 +-- distributed/nanny.py | 5 +-- distributed/pubsub.py | 43 +++++++++--------- distributed/queues.py | 35 ++++----------- distributed/tests/test_nanny.py | 3 +- distributed/tests/test_pubsub.py | 2 + distributed/tests/test_variable.py | 32 ++++++++++---- distributed/variable.py | 40 ++++++++++------- docs/source/actors.rst | 2 +- 14 files changed, 124 insertions(+), 123 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 706101d2fe4..93f395e2dc4 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1,3 +1,4 @@ +import asyncio import atexit from collections import defaultdict from collections.abc import Iterator @@ -8,6 +9,7 @@ import errno from functools import partial import html +from inspect import isawaitable import itertools import json import logging @@ -37,12 +39,7 @@ except ImportError: single_key = first from tornado import gen -from tornado.locks import Event, Condition, Semaphore from tornado.ioloop import IOLoop -from tornado.queues import Queue - -import asyncio -from asyncio import iscoroutine from .batched import BatchedSend from .utils_comm import ( @@ -431,7 +428,7 @@ def _get_event(self): # (https://github.com/tornadoweb/tornado/issues/2189) event = self._event if event is None: - event = self._event = Event() + event = self._event = asyncio.Event() return event def cancel(self): @@ -470,7 +467,7 @@ def reset(self): self._event.clear() async def wait(self, timeout=None): - await self._get_event().wait(timeout) + await asyncio.wait_for(self._get_event().wait(), timeout) def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.status) @@ -630,10 +627,6 @@ def __init__( self._deserializers = deserializers self.direct_to_workers = direct_to_workers - self._gather_semaphore = Semaphore(5) - self._gather_keys = None - self._gather_future = None - # Communication self.scheduler_comm = None @@ -678,6 +671,10 @@ def __init__( self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop + self._gather_semaphore = asyncio.Semaphore(5, loop=self.loop.asyncio_loop) + self._gather_keys = None + self._gather_future = None + if heartbeat_interval is None: heartbeat_interval = 
dask.config.get("distributed.client.heartbeat") heartbeat_interval = parse_timedelta(heartbeat_interval, default="ms") @@ -1190,7 +1187,9 @@ async def _handle_report(self): try: handler = self._stream_handlers[op] - handler(**msg) + result = handler(**msg) + if isawaitable(result): + await result except Exception as e: logger.exception(e) if breakout: @@ -1246,6 +1245,8 @@ def _handle_error(self, exception=None): async def _close(self, fast=False): """ Send close signal and wait until scheduler completes """ + if self.status == "closed": + return self.status = "closing" for pc in self._periodic_callbacks.values(): @@ -1260,8 +1261,6 @@ async def _close(self, fast=False): pass if self.get == dask.config.get("get", None): del dask.config.config["get"] - if self.status == "closed": - return if ( self.scheduler_comm @@ -1353,7 +1352,7 @@ def close(self, timeout=no_default): if self._start_arg is None: with ignoring(AttributeError): f = self.cluster.close() - if iscoroutine(f): + if asyncio.iscoroutine(f): async def _(): await f @@ -1373,6 +1372,7 @@ async def _shutdown(self): await self.cluster.close() else: with ignoring(CommClosedError): + self.status = "closing" await self.scheduler.terminate(close_workers=True) def shutdown(self): @@ -1808,12 +1808,11 @@ async def _gather_remote(self, direct, local_worker): few. In controls access using a Tornado semaphore, and picks up keys from other requests made recently. """ - await self._gather_semaphore.acquire() - keys = list(self._gather_keys) - self._gather_keys = None # clear state, these keys are being sent off - self._gather_future = None + async with self._gather_semaphore: + keys = list(self._gather_keys) + self._gather_keys = None # clear state, these keys are being sent off + self._gather_future = None - try: if direct or local_worker: # gather directly from workers who_has = await retry_operation(self.scheduler.who_has, keys=keys) data2, missing_keys, missing_workers = await gather_from_workers( @@ -1828,8 +1827,6 @@ async def _gather_remote(self, direct, local_worker): else: # ask scheduler to gather data for us response = await retry_operation(self.scheduler.gather, keys=keys) - finally: - self._gather_semaphore.release() return response @@ -2919,10 +2916,12 @@ async def _restart(self, timeout=no_default): if timeout == no_default: timeout = self._timeout * 2 self._send_to_scheduler({"op": "restart", "timeout": timeout}) - self._restart_event = Event() + self._restart_event = asyncio.Event() try: - await self._restart_event.wait(self.loop.time() + timeout) - except gen.TimeoutError: + await asyncio.wait_for( + self._restart_event.wait(), self.loop.time() + timeout + ) + except TimeoutError: logger.error("Restart timed out after %f seconds", timeout) pass self.generation += 1 @@ -4136,7 +4135,7 @@ async def _first_completed(futures): See Also: _as_completed """ - q = Queue() + q = asyncio.Queue() await _as_completed(futures, q) result = await q.get() return result @@ -4207,7 +4206,7 @@ def __init__(self, futures=None, loop=None, with_results=False, raise_errors=Tru self.queue = pyQueue() self.lock = threading.Lock() self.loop = loop or default_client().loop - self.condition = Condition() + self.condition = asyncio.Condition(loop=self.loop.asyncio_loop) self.thread_condition = threading.Condition() self.with_results = with_results self.raise_errors = raise_errors @@ -4215,11 +4214,6 @@ def __init__(self, futures=None, loop=None, with_results=False, raise_errors=Tru if futures: self.update(futures) - def _notify(self): - 
self.condition.notify() - with self.thread_condition: - self.thread_condition.notify() - async def _track_future(self, future): try: await _wait(future) @@ -4238,7 +4232,10 @@ async def _track_future(self, future): self.queue.put_nowait((future, result)) else: self.queue.put_nowait(future) - self._notify() + async with self.condition: + self.condition.notify() + with self.thread_condition: + self.thread_condition.notify() def update(self, futures): """ Add multiple futures to the collection. @@ -4305,7 +4302,8 @@ async def __anext__(self): while self.queue.empty(): if not self.futures: raise StopAsyncIteration - await self.condition.wait() + async with self.condition: + await self.condition.wait() return self._get_and_raise() diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index c0191f024f6..0642cce7381 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -7,7 +7,6 @@ import weakref import warnings -from tornado import locks from tornado.concurrent import Future from tornado.ioloop import IOLoop @@ -298,7 +297,7 @@ async def connect(self, address, deserialize=True, **connection_args): s2c_q=Queue(), c_loop=IOLoop.current(), c_addr=self.manager.new_address(), - conn_event=locks.Event(), + conn_event=asyncio.Event(), ) listener.connect_threadsafe(conn_req) # Wait for connection acknowledgement diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index b486912f281..150251f3d59 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -9,7 +9,7 @@ import pkg_resources import pytest -from tornado import ioloop, locks, queues +from tornado import ioloop, queues from tornado.concurrent import Future import distributed @@ -901,7 +901,7 @@ async def handle_comm(comm): async def check_connector_deserialize(addr, deserialize, in_value, check_out): - done = locks.Event() + done = asyncio.Event() async def handle_comm(comm): await comm.write(in_value) diff --git a/distributed/core.py b/distributed/core.py index 3dad1223030..78dd618e8e2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -13,7 +13,6 @@ from toolz import merge from tornado import gen from tornado.ioloop import IOLoop -from tornado.locks import Event from .comm import ( connect, @@ -134,7 +133,7 @@ def __init__( self.events = None self.event_counts = None self._ongoing_coroutines = weakref.WeakSet() - self._event_finished = Event() + self._event_finished = asyncio.Event() self.listeners = [] self.io_loop = io_loop or IOLoop.current() diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 537fa3201f4..17b1af28148 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -6,7 +6,6 @@ import weakref import dask -from tornado.locks import Event from tornado import gen from .adaptive import Adaptive @@ -42,7 +41,7 @@ def __init__(self, scheduler=None, name=None): self.external_address = None self.lock = asyncio.Lock() self.status = "created" - self._event_finished = Event() + self._event_finished = asyncio.Event() def __await__(self): async def _(): diff --git a/distributed/lock.py b/distributed/lock.py index 48d538915f0..3eceba5ce7a 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -1,9 +1,7 @@ +import asyncio from collections import defaultdict, deque import logging import uuid -import asyncio - -import tornado.locks from .client import _get_global_client from .utils import log_errors, TimeoutError @@ -40,7 +38,7 @@ async def acquire(self, stream=None, name=None, 
id=None, timeout=None): result = True else: while name in self.ids: - event = tornado.locks.Event() + event = asyncio.Event() self.events[name].append(event) future = event.wait() if timeout is not None: diff --git a/distributed/nanny.py b/distributed/nanny.py index 9c95dd4a07a..5e67b51b52b 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -12,7 +12,6 @@ import dask from dask.system import CPU_COUNT from tornado.ioloop import IOLoop -from tornado.locks import Event from tornado import gen from .comm import get_address_host, unparse_host_port @@ -507,8 +506,8 @@ async def start(self): ) self.process.daemon = dask.config.get("distributed.worker.daemon", default=True) self.process.set_exit_callback(self._on_exit) - self.running = Event() - self.stopped = Event() + self.running = asyncio.Event() + self.stopped = asyncio.Event() self.status = "starting" try: await self.process.start() diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 3c8b140b362..355fee7ae59 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -1,14 +1,12 @@ +import asyncio from collections import defaultdict, deque -import datetime import logging import threading import weakref -import tornado.locks -from tornado import gen - from .core import CommClosedError -from .utils import sync, TimeoutError +from .metrics import time +from .utils import sync, TimeoutError, ignoring from .protocol.serialize import to_serialize logger = logging.getLogger(__name__) @@ -148,9 +146,9 @@ def remove_subscriber(self, name=None, address=None): def publish_scheduler(self, name=None, publish=None): self.publish_to_scheduler[name] = publish - def handle_message(self, name=None, msg=None): + async def handle_message(self, name=None, msg=None): for sub in self.subscribers.get(name, []): - sub._put(msg) + await sub._put(msg) def trigger_cleanup(self): self.worker.loop.add_callback(self.cleanup) @@ -180,9 +178,9 @@ def __init__(self, client): self.subscribers = defaultdict(weakref.WeakSet) self.client.extensions["pubsub"] = self # TODO: circular reference - def handle_message(self, name=None, msg=None): + async def handle_message(self, name=None, msg=None): for sub in self.subscribers[name]: - sub._put(msg) + await sub._put(msg) if not self.subscribers[name]: self.client.scheduler_comm.send( @@ -374,7 +372,7 @@ def __init__(self, name, worker=None, client=None): self.loop = self.client.loop self.name = name self.buffer = deque() - self.condition = tornado.locks.Condition() + self.condition = asyncio.Condition(loop=self.loop.asyncio_loop) if self.worker: pubsub = self.worker.extensions["pubsub"] @@ -393,20 +391,24 @@ def __init__(self, name, worker=None, client=None): weakref.finalize(self, pubsub.trigger_cleanup) async def _get(self, timeout=None): - if timeout is not None: - timeout = datetime.timedelta(seconds=timeout) - start = datetime.datetime.now() + start = time() while not self.buffer: if timeout is not None: - timeout2 = timeout - (datetime.datetime.now() - start) - if timeout2.total_seconds() < 0: + timeout2 = timeout - (time() - start) + if timeout2 < 0: raise TimeoutError() else: timeout2 = None + + async def _(): + await self.condition.acquire() + await self.condition.wait() + try: - await self.condition.wait(timeout=timeout2) - except gen.TimeoutError: - raise TimeoutError("Timed out waiting on Sub") + await asyncio.wait_for(_(), timeout2) + finally: + with ignoring(RuntimeError): # Python 3.6 fails here sometimes + self.condition.release() return self.buffer.popleft() @@ -431,9 +433,10 @@ def 
__iter__(self): def __aiter__(self): return self - def _put(self, msg): + async def _put(self, msg): self.buffer.append(msg) - self.condition.notify() + async with self.condition: + self.condition.notify() def __repr__(self): return "".format(self.name) diff --git a/distributed/queues.py b/distributed/queues.py index 9f5db0af68e..6cdab880aa6 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -1,14 +1,10 @@ +import asyncio from collections import defaultdict -import datetime import logging import uuid -import tornado.queues -from tornado.locks import Event -from tornado import gen - from .client import Future, _get_global_client, Client -from .utils import tokey, sync, thread_state, TimeoutError +from .utils import tokey, sync, thread_state from .worker import get_client logger = logging.getLogger(__name__) @@ -50,7 +46,7 @@ def __init__(self, scheduler): def create(self, stream=None, name=None, client=None, maxsize=0): logger.debug("Queue name: {}".format(name)) if name not in self.queues: - self.queues[name] = tornado.queues.Queue(maxsize=maxsize) + self.queues[name] = asyncio.Queue(maxsize=maxsize) self.client_refcount[name] = 1 else: self.client_refcount[name] += 1 @@ -77,12 +73,7 @@ async def put( self.scheduler.client_desires_keys(keys=[key], client="queue-%s" % name) else: record = {"type": "msgpack", "value": data} - if timeout is not None: - timeout = datetime.timedelta(seconds=timeout) - try: - await self.queues[name].put(record, timeout=timeout) - except gen.TimeoutError: - raise TimeoutError("Timed out waiting for Queue") + await asyncio.wait_for(self.queues[name].put(record), timeout=timeout) def future_release(self, name=None, key=None, client=None): self.future_refcount[name, key] -= 1 @@ -126,12 +117,7 @@ def process(record): out = [process(o) for o in out] return out else: - if timeout is not None: - timeout = datetime.timedelta(seconds=timeout) - try: - record = await self.queues[name].get(timeout=timeout) - except gen.TimeoutError: - raise TimeoutError("Timed out waiting for Queue") + record = await asyncio.wait_for(self.queues[name].get(), timeout=timeout) record = process(record) return record @@ -171,7 +157,7 @@ class Queue: def __init__(self, name=None, client=None, maxsize=0): self.client = client or _get_global_client() self.name = name or "queue-" + uuid.uuid4().hex - self._event_started = Event() + self._event_started = asyncio.Event() if self.client.asynchronous or getattr( thread_state, "on_event_loop_thread", False ): @@ -232,12 +218,9 @@ def qsize(self, **kwargs): return self.client.sync(self._qsize, **kwargs) async def _get(self, timeout=None, batch=False): - try: - resp = await self.client.scheduler.queue_get( - timeout=timeout, name=self.name, batch=batch - ) - except gen.TimeoutError: - raise TimeoutError("Timed out waiting for Queue") + resp = await self.client.scheduler.queue_get( + timeout=timeout, name=self.name, batch=batch + ) def process(d): if d["type"] == "Future": diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 2ddc3b7e5db..0091a6126f1 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -12,7 +12,6 @@ from toolz import valmap, first from tornado import gen from tornado.ioloop import IOLoop -from tornado.locks import Event import dask from distributed.diagnostics import SchedulerPlugin @@ -453,7 +452,7 @@ async def test_nanny_closes_cleanly(cleanup): @pytest.mark.asyncio async def test_lifetime(cleanup): counter = 0 - event = Event() + event = asyncio.Event() 
class Plugin(SchedulerPlugin): def add_worker(self, **kwargs): diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 847b0b88bf0..2e372dea88b 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -124,6 +124,8 @@ def test_timeouts(c, s, a, b): yield sub.get(timeout=0.1) stop = time() assert stop - start < 1 + with pytest.raises(TimeoutError): + yield sub.get(timeout=0.01) @gen_cluster(client=True) diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 962b7a40e42..6e3b3bcdad6 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -5,9 +5,11 @@ import pytest from tornado import gen +from tornado.ioloop import IOLoop from distributed import Client, Variable, worker_client, Nanny, wait, TimeoutError from distributed.metrics import time +from distributed.compatibility import WINDOWS from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -83,20 +85,34 @@ def test_hold_futures(s, a, b): def test_timeout(c, s, a, b): v = Variable("v") - start = time() + start = IOLoop.current().time() + with pytest.raises(TimeoutError): + yield v.get(timeout=0.2) + stop = IOLoop.current().time() + + if WINDOWS: # timing is weird with asyncio and Windows + assert 0.1 < stop - start < 2.0 + else: + assert 0.2 < stop - start < 2.0 + with pytest.raises(TimeoutError): - yield v.get(timeout=0.1) - stop = time() - assert 0.1 < stop - start < 2.0 + yield v.get(timeout=0.01) def test_timeout_sync(client): v = Variable("v") - start = time() + start = IOLoop.current().time() + with pytest.raises(TimeoutError): + v.get(timeout=0.2) + stop = IOLoop.current().time() + + if WINDOWS: + assert 0.1 < stop - start < 2.0 + else: + assert 0.2 < stop - start < 2.0 + with pytest.raises(TimeoutError): - v.get(timeout=0.1) - stop = time() - assert 0.1 < stop - start < 2.0 + yield v.get(timeout=0.01) @gen_cluster(client=True) diff --git a/distributed/variable.py b/distributed/variable.py index bfd6ca250e2..9024ab03d8b 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -3,17 +3,13 @@ import logging import uuid -import tornado.locks -from tornado import gen - try: from cytoolz import merge except ImportError: from toolz import merge from .client import Future, _get_global_client, Client -from .metrics import time -from .utils import tokey, log_errors, TimeoutError +from .utils import tokey, log_errors, TimeoutError, ignoring from .worker import get_client logger = logging.getLogger(__name__) @@ -33,8 +29,8 @@ def __init__(self, scheduler): self.scheduler = scheduler self.variables = dict() self.waiting = defaultdict(set) - self.waiting_conditions = defaultdict(tornado.locks.Condition) - self.started = tornado.locks.Condition() + self.waiting_conditions = defaultdict(asyncio.Condition) + self.started = asyncio.Condition() self.scheduler.handlers.update( {"variable_set": self.set, "variable_get": self.get} @@ -45,7 +41,7 @@ def __init__(self, scheduler): self.scheduler.extensions["variables"] = self - def set(self, stream=None, name=None, key=None, data=None, client=None): + async def set(self, stream=None, name=None, key=None, data=None, client=None): if key is not None: record = {"type": "Future", "value": key} self.scheduler.client_desires_keys(keys=[key], client="variable-%s" % name) @@ -59,34 +55,44 @@ def set(self, stream=None, name=None, key=None, data=None, client=None): if old["type"] == "Future" 
and old["value"] != key: asyncio.ensure_future(self.release(old["value"], name)) if name not in self.variables: - self.started.notify_all() + async with self.started: + self.started.notify_all() self.variables[name] = record async def release(self, key, name): while self.waiting[key, name]: - await self.waiting_conditions[name].wait() + async with self.waiting_conditions[name]: + await self.waiting_conditions[name].wait() self.scheduler.client_releases_keys(keys=[key], client="variable-%s" % name) del self.waiting[key, name] - def future_release(self, name=None, key=None, token=None, client=None): + async def future_release(self, name=None, key=None, token=None, client=None): self.waiting[key, name].remove(token) if not self.waiting[key, name]: - self.waiting_conditions[name].notify_all() + async with self.waiting_conditions[name]: + self.waiting_conditions[name].notify_all() async def get(self, stream=None, name=None, client=None, timeout=None): - start = time() + start = self.scheduler.loop.time() while name not in self.variables: if timeout is not None: - left = timeout - (time() - start) + left = timeout - (self.scheduler.loop.time() - start) else: left = None if left and left < 0: raise TimeoutError() try: - await self.started.wait(timeout=left) - except gen.TimeoutError: - raise TimeoutError("Timed out waiting for Variable.get") + + async def _(): # Python 3.6 is odd and requires special help here + await self.started.acquire() + await self.started.wait() + + await asyncio.wait_for(_(), timeout=left) + finally: + with ignoring(RuntimeError): # Python 3.6 loses lock on finally clause + self.started.release() + record = self.variables[name] if record["type"] == "Future": key = record["value"] diff --git a/docs/source/actors.rst b/docs/source/actors.rst index 85d1300e502..370837629f3 100644 --- a/docs/source/actors.rst +++ b/docs/source/actors.rst @@ -199,7 +199,7 @@ will run on the Worker's event loop thread rather than a separate thread. def Waiter: def __init__(self): - self.event = tornado.locks.Event() + self.event = asyncio.Event() async def set(self): self.event.set() From d9c481815e2dbfe4f4691a0f3c6071be6cf3a471 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Fuglede=20J=C3=B8rgensen?= Date: Mon, 3 Feb 2020 19:06:18 +0100 Subject: [PATCH 0655/1550] Add documentation of parameters in coordination primitives (#3434) --- distributed/lock.py | 8 ++++++-- distributed/pubsub.py | 8 +++++++- distributed/queues.py | 12 ++++++++++++ distributed/variable.py | 9 +++++++++ 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/distributed/lock.py b/distributed/lock.py index 3eceba5ce7a..3c893a419c2 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -76,9 +76,13 @@ class Lock: Parameters ---------- - name: string + name: string (optional) Name of the lock to acquire. Choosing the same name allows two - disconnected processes to coordinate a lock. + disconnected processes to coordinate a lock. If not given, a random + name will be generated. + client: Client (optional) + Client to use for communication with the scheduler. If not given, the + default global client will be used. Examples -------- diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 355fee7ae59..ca4e06c44d1 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -238,7 +238,13 @@ class Pub: Parameters ---------- name: object (msgpack serializable) - The name of the group of Pubs and Subs on which to participate + The name of the group of Pubs and Subs on which to participate. 
+ worker: Worker (optional) + The worker to be used for publishing data. Defaults to the value of + ```get_worker()```. If given, ``client`` must be ``None``. + client: Client (optional) + Client used for communication with the scheduler. Defaults to + the value of ``get_client()``. If given, ``worker`` must be ``None``. Examples -------- diff --git a/distributed/queues.py b/distributed/queues.py index 6cdab880aa6..81262703ad4 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -141,6 +141,18 @@ class Queue: This object is experimental and has known issues in Python 2 + Parameters + ---------- + name: string (optional) + Name used by other clients and the scheduler to identify the queue. If + not given, a random name will be generated. + client: Client (optional) + Client used for communication with the scheduler. Defaults to the + value of ``_get_global_client()``. + maxsize: int (optional) + Number of items allowed in the queue. If 0 (the default), the queue + size is unbounded. + Examples -------- >>> from dask.distributed import Client, Queue # doctest: +SKIP diff --git a/distributed/variable.py b/distributed/variable.py index 9024ab03d8b..fc4cc396dab 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -136,6 +136,15 @@ class Variable: This object is experimental and has known issues in Python 2 + Parameters + ---------- + name: string (optional) + Name used by other clients and the scheduler to identify the variable. + If not given, a random name will be generated. + client: Client (optional) + Client used for communication with the scheduler. Defaults to the + value of ``_get_global_client()``. + Examples -------- >>> from dask.distributed import Client, Variable # doctest: +SKIP From eb8de64072a2ada03e195b5c2ea218e46abb1068 Mon Sep 17 00:00:00 2001 From: Cyril Shcherbin Date: Mon, 3 Feb 2020 19:11:21 +0100 Subject: [PATCH 0656/1550] Call pip as a module to avoid warnings (#3433) (#3436) --- .github/workflows/ci-docs.yaml | 4 ++-- .github/workflows/ci-windows.yaml | 2 +- .travis.yml | 6 ++--- continuous_integration/travis/install.sh | 24 +++++++++---------- distributed/client.py | 2 +- distributed/dashboard/proxy.py | 4 ++-- .../dashboard/tests/test_scheduler_bokeh.py | 2 +- distributed/protocol/keras.py | 2 +- distributed/worker.py | 4 +++- docs/source/install.rst | 2 +- docs/source/quickstart.rst | 2 +- 11 files changed, 28 insertions(+), 26 deletions(-) diff --git a/.github/workflows/ci-docs.yaml b/.github/workflows/ci-docs.yaml index e80e07b9c33..780e2a251fd 100644 --- a/.github/workflows/ci-docs.yaml +++ b/.github/workflows/ci-docs.yaml @@ -17,10 +17,10 @@ jobs: - name: Install Distributed run: | python -m pip install --upgrade pip - pip install -e . + python -m pip install -e . - name: Install doc dependencies - run: pip install -r docs/requirements.txt + run: python -m pip install -r docs/requirements.txt - name: Build docs run: | diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index ecbf29f5a3d..992f65d6435 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -24,7 +24,7 @@ jobs: - name: Install distributed from source shell: bash -l {0} - run: pip install -q --no-deps -e . + run: python -m pip install -q --no-deps -e . 
- name: Run tests shell: bash -l {0} diff --git a/.travis.yml b/.travis.yml index c0c30316c9a..d00894dd3d6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,11 +29,11 @@ install: script: - if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi - - if [[ $LINT == true ]]; then pip install flake8 ; flake8 distributed ; fi - - if [[ $LINT == true ]]; then pip install black ; black distributed --check; fi + - if [[ $LINT == true ]]; then python -m pip install flake8 ; flake8 distributed ; fi + - if [[ $LINT == true ]]; then python -m pip install black ; black distributed --check; fi after_success: - - if [[ $COVERAGE == true ]]; then coverage report; pip install -q coveralls ; coveralls ; fi + - if [[ $COVERAGE == true ]]; then coverage report; python -m pip install -q coveralls ; coveralls ; fi notifications: email: false diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 8c34f38d276..8eaed19df81 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -58,28 +58,28 @@ conda install -q \ conda install -c defaults -c conda-forge libunwind zstandard asyncssh conda install --no-deps -c defaults -c numba -c conda-forge stacktrace -pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio +python -m pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio -pip install -q git+https://github.com/dask/dask.git --upgrade --no-deps -pip install -q git+https://github.com/joblib/joblib.git --upgrade --no-deps -pip install -q git+https://github.com/intake/filesystem_spec.git --upgrade --no-deps -pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps -pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps -pip install -q sortedcollections msgpack --no-deps -pip install -q keras --upgrade --no-deps -pip install -q asyncssh +python -m pip install -q git+https://github.com/dask/dask.git --upgrade --no-deps +python -m pip install -q git+https://github.com/joblib/joblib.git --upgrade --no-deps +python -m pip install -q git+https://github.com/intake/filesystem_spec.git --upgrade --no-deps +python -m pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps +python -m pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps +python -m pip install -q sortedcollections msgpack --no-deps +python -m pip install -q keras --upgrade --no-deps +python -m pip install -q asyncssh if [[ $CRICK == true ]]; then conda install -q cython - pip install -q git+https://github.com/jcrist/crick.git + python -m pip install -q git+https://github.com/jcrist/crick.git fi; # Install distributed -pip install --no-deps -e . +python -m pip install --no-deps -e . 
# For debugging echo -e "--\n--Conda Environment\n--" conda list echo -e "--\n--Pip Environment\n--" -pip list --format=columns +python -m pip list --format=columns diff --git a/distributed/client.py b/distributed/client.py index 93f395e2dc4..fee50564963 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4513,7 +4513,7 @@ class get_task_stream: A common way to do this is to upload the file as a gist, and then serve it on https://raw.githack.com :: - $ pip install gist + $ python -m pip install gist $ gist task-stream.html https://gist.github.com/8a5b3c74b10b413f612bb5e250856ceb diff --git a/distributed/dashboard/proxy.py b/distributed/dashboard/proxy.py index 89f9f87aae6..3e76ba11c0e 100644 --- a/distributed/dashboard/proxy.py +++ b/distributed/dashboard/proxy.py @@ -70,7 +70,7 @@ def proxy(self, port, proxied_path): logger.info( "To route to workers diagnostics web server " "please install jupyter-server-proxy: " - "pip install jupyter-server-proxy" + "python -m pip install jupyter-server-proxy" ) class GlobalProxyHandler(web.RequestHandler): @@ -94,7 +94,7 @@ def get(self, port, host, proxied_path):
         conda install jupyter-server-proxy -c conda-forge  -         pip install jupyter-server-proxy  +         python -m pip install jupyter-server-proxy
        The link above should work though if your workers are on a diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 6594ce2142f..4ef90e48b8e 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -607,7 +607,7 @@ def test_proxy_to_workers(c, s, a, b): if proxy_exists: assert b"Crossfilter" in response_proxy.body else: - assert b"pip install jupyter-server-proxy" in response_proxy.body + assert b"python -m pip install jupyter-server-proxy" in response_proxy.body assert response_direct.code == 200 assert b"Crossfilter" in response_direct.body diff --git a/distributed/protocol/keras.py b/distributed/protocol/keras.py index 020ce1cae3b..121aa0c4700 100644 --- a/distributed/protocol/keras.py +++ b/distributed/protocol/keras.py @@ -9,7 +9,7 @@ def serialize_keras_model(model): if keras.__version__ < "1.2.0": raise ImportError( - "Need Keras >= 1.2.0. Try pip install keras --upgrade --no-deps" + "Need Keras >= 1.2.0. Try python -m pip install keras --upgrade --no-deps" ) header = model._updated_config() diff --git a/distributed/worker.py b/distributed/worker.py index 77711058682..e1ae8317148 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -547,7 +547,9 @@ def __init__( try: from zict import Buffer, File, Func except ImportError: - raise ImportError("Please `pip install zict` for spill-to-disk workers") + raise ImportError( + "Please `python -m pip install zict` for spill-to-disk workers" + ) path = os.path.join(self.local_directory, "storage") storage = Func( partial(serialize_bytelist, on_error="raise"), diff --git a/docs/source/install.rst b/docs/source/install.rst index 7cf4199eecd..db1bf316400 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -18,7 +18,7 @@ Pip Or install distributed with ``pip``:: - pip install dask distributed --upgrade + python -m pip install dask distributed --upgrade Source ------ diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 4437f77a1ea..0172c376746 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -6,7 +6,7 @@ Install :: - $ pip install dask distributed --upgrade + $ python -m pip install dask distributed --upgrade See :doc:`installation ` document for more information. 
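A small illustrative sketch (not part of the patch above) of the ``python -m pip`` form this commit adopts: invoking pip as a module guarantees it runs under the same interpreter as the calling ``python``, and the same idea can be mirrored from Python code.

    import subprocess
    import sys

    # Run pip as a module of this interpreter, mirroring the "python -m pip"
    # form used throughout the diff above.
    subprocess.check_call([sys.executable, "-m", "pip", "--version"])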
From 812847960a40c2f31bfdb9b4b9af3a79a0c8c443 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 3 Feb 2020 16:57:03 -0800 Subject: [PATCH 0657/1550] Avoid loop= keyword in asyncio coordination primitives (#3437) --- distributed/client.py | 12 ++++++++++-- distributed/pubsub.py | 9 ++++++++- distributed/tests/test_client.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index fee50564963..516185cef23 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -671,7 +671,6 @@ def __init__( self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop - self._gather_semaphore = asyncio.Semaphore(5, loop=self.loop.asyncio_loop) self._gather_keys = None self._gather_future = None @@ -981,6 +980,8 @@ async def _start(self, timeout=no_default, **kwargs): address = self.cluster.scheduler_address + self._gather_semaphore = asyncio.Semaphore(5) + if self.scheduler is None: self.scheduler = self.rpc(address) self.scheduler_comm = None @@ -4206,7 +4207,6 @@ def __init__(self, futures=None, loop=None, with_results=False, raise_errors=Tru self.queue = pyQueue() self.lock = threading.Lock() self.loop = loop or default_client().loop - self.condition = asyncio.Condition(loop=self.loop.asyncio_loop) self.thread_condition = threading.Condition() self.with_results = with_results self.raise_errors = raise_errors @@ -4214,6 +4214,14 @@ def __init__(self, futures=None, loop=None, with_results=False, raise_errors=Tru if futures: self.update(futures) + @property + def condition(self): + try: + return self._condition + except AttributeError: + self._condition = asyncio.Condition() + return self._condition + async def _track_future(self, future): try: await _wait(future) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index ca4e06c44d1..cdff73ffeca 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -378,7 +378,6 @@ def __init__(self, name, worker=None, client=None): self.loop = self.client.loop self.name = name self.buffer = deque() - self.condition = asyncio.Condition(loop=self.loop.asyncio_loop) if self.worker: pubsub = self.worker.extensions["pubsub"] @@ -396,6 +395,14 @@ def __init__(self, name, worker=None, client=None): weakref.finalize(self, pubsub.trigger_cleanup) + @property + def condition(self): + try: + return self._condition + except AttributeError: + self._condition = asyncio.Condition() + return self._condition + async def _get(self, timeout=None): start = time() while not self.buffer: diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 16e660492f8..ae03edd4faa 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5920,3 +5920,17 @@ async def test_performance_report(c, s, a, b): assert "bokeh" in data assert "random" in data assert "Dask Performance Report" in data + + +@pytest.mark.asyncio +async def test_client_gather_semaphor_loop(cleanup): + async with Scheduler(port=0) as s: + async with Client(s.address, asynchronous=True) as c: + assert c._gather_semaphore._loop is c.loop.asyncio_loop + + +@gen_cluster(client=True) +def test_as_completed_condition_loop(c, s, a, b): + seq = c.map(inc, range(5)) + ac = as_completed(seq) + assert ac.condition._loop == c.loop.asyncio_loop From d134345c5c026c7472fbc0dcb0ce907ac2075e05 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 4 Feb 2020 00:19:34 -0600 Subject: [PATCH 0658/1550] Ensure scheduler updates task and worker 
states after successful worker data deletion (#3401) --- distributed/scheduler.py | 55 +++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 89c938f0dd5..8e5fdd17a34 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2952,6 +2952,30 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None): ) return d[worker] + async def _delete_worker_data(self, worker_address, keys): + """ Delete data from a worker and update the corresponding worker/task states + + Parameters + ---------- + worker_address: str + Worker address to delete keys from + keys: List[str] + List of keys to delete on the specified worker + """ + await retry_operation( + self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False, + ) + + ws = self.workers[worker_address] + tasks = {self.tasks[key] for key in keys} + ws.has_what -= tasks + for ts in tasks: + ts.who_has.remove(ws) + ws.nbytes -= ts.get_nbytes() + self.log_event( + ws.address, {"action": "remove-worker-data", "keys": keys}, + ) + async def rebalance(self, comm=None, keys=None, workers=None): """ Rebalance keys so that each worker stores roughly equal bytes @@ -3068,19 +3092,9 @@ async def rebalance(self, comm=None, keys=None, workers=None): ) await asyncio.gather( - *( - retry_operation( - self.rpc(addr=r).delete_data, keys=v, report=False - ) - for r, v in to_senders.items() - ) + *(self._delete_worker_data(r, v) for r, v in to_senders.items()) ) - for sender, recipient, ts in msgs: - ts.who_has.remove(sender) - sender.has_what.remove(ts) - sender.nbytes -= ts.get_nbytes() - return {"status": "OK"} async def replicate( @@ -3142,28 +3156,11 @@ async def replicate( await asyncio.gather( *( - retry_operation( - self.rpc(addr=ws.address).delete_data, - keys=[ts.key for ts in tasks], - report=False, - ) + self._delete_worker_data(ws.address, [t.key for t in tasks]) for ws, tasks in del_worker_tasks.items() ) ) - for ws, tasks in del_worker_tasks.items(): - ws.has_what -= tasks - for ts in tasks: - ts.who_has.remove(ws) - ws.nbytes -= ts.get_nbytes() - self.log_event( - ws.address, - { - "action": "replicate-remove", - "keys": [ts.key for ts in tasks], - }, - ) - # Copy not-yet-filled data while tasks: gathers = defaultdict(dict) From 3d454d719cda99e2fc0808d820a1a02e96a6260d Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 4 Feb 2020 12:01:12 -0600 Subject: [PATCH 0659/1550] Update worker_kwargs description in LocalCluster constructor [skip ci] (#3438) --- distributed/deploy/local.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 8eb55c54997..d1744ed32c0 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -49,8 +49,6 @@ class LocalCluster(SpecCluster): asynchronous: bool (False by default) Set to True if using this cluster within async/await functions or within Tornado gen.coroutines. This should remain False for normal use. - worker_kwargs: dict - Extra worker arguments, will be passed to the Worker constructor. blocked_handlers: List[str] A list of strings specifying a blacklist of handlers to disallow on the Scheduler, like ``['feed', 'run_function']`` @@ -68,6 +66,9 @@ class LocalCluster(SpecCluster): Network interface to use. Defaults to lo/localhost worker_class: Worker Worker class used to instantiate workers from. + **worker_kwargs: + Extra worker arguments. 
Any additional keyword arguments will be passed + to the ``Worker`` class constructor. Examples -------- From f87e802102da07fce7c9958e5926d18a143b0fd8 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 5 Feb 2020 12:05:00 +0100 Subject: [PATCH 0660/1550] Checks for command parameters in ssh2 (#3078) * Allow specification of worker type in SSHCLuster * Default the worker_module to None and check further down * Avoid ssh2 parameters from getting overwritten by superclass * Check for a command's parameters in cli_keywords * Pass ssh2's worker_module to cli_keywords * Removed duplicate worker_module from ssh2 This was a result of a merge conflict that passed unnoticed * Fix utils.py formatting * Add tests for cli_keywords command options * Add new command_has_keyword function, simplify cli_keywords * Avoid ssh attributes from getting overwritten by superclass * Pass ssh command parameters to cli_keywords * Add test for cli_keywords command options * Remove check for non-Worker class attribute "nprocs" The "nprocs" is a command-line argument, and not an attribute from the Worker class, thus it cannot be asserted for * Improve documentation for cmd= in cli_keywords Co-authored-by: Jacob Tomlinson --- distributed/deploy/ssh.py | 10 ++--- distributed/deploy/tests/test_ssh.py | 25 ++++++++++++- distributed/utils.py | 55 +++++++++++++++++++++++++--- 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 1f49f187a14..a7b3526bcba 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -73,6 +73,8 @@ def __init__( loop=None, name=None, ): + super().__init__() + self.address = address self.scheduler = scheduler self.worker_module = worker_module @@ -80,8 +82,6 @@ def __init__( self.kwargs = kwargs self.name = name - super().__init__() - async def start(self): import asyncssh # import now to avoid adding to module startup time @@ -98,7 +98,7 @@ async def start(self): "--name", str(self.name), ] - + cli_keywords(self.kwargs, cls=_Worker) + + cli_keywords(self.kwargs, cls=_Worker, cmd=self.worker_module) ) ) @@ -131,12 +131,12 @@ class Scheduler(Process): """ def __init__(self, address: str, connect_options: dict, kwargs: dict): + super().__init__() + self.address = address self.kwargs = kwargs self.connect_options = connect_options - super().__init__() - async def start(self): import asyncssh # import now to avoid adding to module startup time diff --git a/distributed/deploy/tests/test_ssh.py b/distributed/deploy/tests/test_ssh.py index 376b0eae3a4..eff7bf05a11 100644 --- a/distributed/deploy/tests/test_ssh.py +++ b/distributed/deploy/tests/test_ssh.py @@ -32,7 +32,12 @@ async def test_keywords(): ["127.0.0.1"] * 3, connect_options=dict(known_hosts=None), asynchronous=True, - worker_options={"nthreads": 2, "memory_limit": "2 GiB", "death_timeout": "5s"}, + worker_options={ + "nprocs": 2, # nprocs checks custom arguments with cli_keywords + "nthreads": 2, + "memory_limit": "2 GiB", + "death_timeout": "5s", + }, scheduler_options={"idle_timeout": "5s", "port": 0}, ) as cluster: async with Client(cluster, asynchronous=True) as client: @@ -74,3 +79,21 @@ def f(x): async with Client(cluster, asynchronous=True) as client: result = await client.submit(f, 1) assert result == 101 + + +@pytest.mark.asyncio +async def test_unimplemented_options(): + with pytest.raises(Exception): + async with SSHCluster( + ["127.0.0.1"] * 3, + connect_kwargs=dict(known_hosts=None), + asynchronous=True, + worker_kwargs={ + "nthreads": 
2, + "memory_limit": "2 GiB", + "death_timeout": "5s", + "unimplemented_option": 2, + }, + scheduler_kwargs={"idle_timeout": "5s", "port": 0}, + ) as cluster: + assert cluster diff --git a/distributed/utils.py b/distributed/utils.py index 086555643ea..09c9c62cd1e 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,6 +1,7 @@ import asyncio from asyncio import TimeoutError import atexit +import click from collections import deque, OrderedDict, UserDict from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -1223,6 +1224,32 @@ def has_keyword(func, keyword): return keyword in inspect.signature(func).parameters +@functools.lru_cache(1000) +def command_has_keyword(cmd, k): + if cmd is not None: + if isinstance(cmd, str): + try: + from importlib import import_module + + cmd = import_module(cmd) + except ImportError: + raise ImportError("Module for command %s is not available" % cmd) + + if isinstance(getattr(cmd, "main"), click.core.Command): + cmd = cmd.main + if isinstance(cmd, click.core.Command): + cmd_params = set( + [ + p.human_readable_name + for p in cmd.params + if isinstance(p, click.core.Option) + ] + ) + return k in cmd_params + + return False + + # from bokeh.palettes import viridis # palette = viridis(18) palette = [ @@ -1324,7 +1351,7 @@ def _repr_html_(self): return "\n".join(summaries) -def cli_keywords(d: dict, cls=None): +def cli_keywords(d: dict, cls=None, cmd=None): """ Convert a kwargs dictionary into a list of CLI keywords Parameters @@ -1333,6 +1360,12 @@ def cli_keywords(d: dict, cls=None): The keywords to convert cls: callable The callable that consumes these terms to check them for validity + cmd: string or object + A string with the name of a module, or the module containing a + click-generated command with a "main" function, or the function itself. + It may be used to parse a module's custom arguments (i.e., arguments that + are not part of Worker class), such as nprocs from dask-worker CLI or + enable_nvlink from dask-cuda-worker CLI. Examples -------- @@ -1345,12 +1378,22 @@ def cli_keywords(d: dict, cls=None): ... 
ValueError: Class distributed.worker.Worker does not support keyword x """ - if cls: + if cls or cmd: for k in d: - if not has_keyword(cls, k): - raise ValueError( - "Class %s does not support keyword %s" % (typename(cls), k) - ) + if not has_keyword(cls, k) and not command_has_keyword(cmd, k): + if cls and cmd: + raise ValueError( + "Neither class %s or module %s support keyword %s" + % (typename(cls), typename(cmd), k) + ) + elif cls: + raise ValueError( + "Class %s does not support keyword %s" % (typename(cls), k) + ) + else: + raise ValueError( + "Module %s does not support keyword %s" % (typename(cmd), k) + ) def convert_value(v): out = str(v) From ef5317feeae72cd23dd25f22ce278a3a91212e49 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Wed, 5 Feb 2020 07:45:12 -0800 Subject: [PATCH 0661/1550] Fix name of Numba serialization test (#3447) --- distributed/protocol/tests/test_numba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/protocol/tests/test_numba.py b/distributed/protocol/tests/test_numba.py index 69ea73310d4..78e2b1859b4 100644 --- a/distributed/protocol/tests/test_numba.py +++ b/distributed/protocol/tests/test_numba.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) -def test_serialize_cupy(dtype): +def test_serialize_numba(dtype): if not cuda.is_available(): pytest.skip("CUDA is not available") From 88c354f68761b60c11ba5613806891f78810da0c Mon Sep 17 00:00:00 2001 From: jakirkham Date: Wed, 5 Feb 2020 10:58:31 -0800 Subject: [PATCH 0662/1550] Adjust `numba.cuda` import and add check (#3446) * Import `numba.cuda` instead of just `numba` Appears that an error will be raised when accessing `numba.cuda` unless `numba.cuda` is imported as well. So go ahead and import `numba.cuda` too. 
* Check that CUDA is available in CuPy/Numba test --- distributed/protocol/tests/test_cupy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index 57d26ae679b..f7feb4da5e2 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -17,12 +17,15 @@ def test_serialize_cupy(size, dtype): @pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) def test_serialize_cupy_from_numba(dtype): - numba = pytest.importorskip("numba") + cuda = pytest.importorskip("numba.cuda") np = pytest.importorskip("numpy") + if not cuda.is_available(): + pytest.skip("CUDA is not available") + size = 10 x_np = np.arange(size, dtype=dtype) - x = numba.cuda.to_device(x_np) + x = cuda.to_device(x_np) header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) header["type-serialized"] = pickle.dumps(cupy.ndarray) From 105d040fb9cabfa3581dcdce73165c77cb22de9a Mon Sep 17 00:00:00 2001 From: Alex Adamson Date: Wed, 5 Feb 2020 17:07:15 -0500 Subject: [PATCH 0663/1550] Ensure __causes__s of exceptions raised on workers are serialized (#3430) --- distributed/core.py | 10 ++++++++++ distributed/tests/test_worker.py | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/distributed/core.py b/distributed/core.py index 78dd618e8e2..ac9be6728fc 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -10,6 +10,7 @@ import weakref import dask +import tblib from toolz import merge from tornado import gen from tornado.ioloop import IOLoop @@ -981,6 +982,14 @@ def coerce_to_address(o): return normalize_address(o) +def collect_causes(e): + causes = [] + while e.__cause__ is not None: + causes.append(e.__cause__) + e = e.__cause__ + return causes + + def error_message(e, status="error"): """ Produce message to send back given an exception has occurred @@ -997,6 +1006,7 @@ def error_message(e, status="error"): clean_exception: deserialize and unpack message into exception/traceback """ MAX_ERROR_LEN = dask.config.get("distributed.admin.max-error-length") + tblib.pickling_support.install(e, *collect_causes(e)) tb = get_traceback() e2 = truncate_exception(e, MAX_ERROR_LEN) try: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6c1a0805817..55dc7faf417 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -372,6 +372,32 @@ def __str__(self): assert len(msg["text"]) > 10100 # default + 100 +@gen_cluster(client=True) +def test_chained_error_message(c, s, a, b): + def chained_exception_fn(): + class MyException(Exception): + def __init__(self, msg): + self.msg = msg + + def __str__(self): + return "MyException(%s)" % self.msg + + exception = MyException("Foo") + inner_exception = MyException("Bar") + + try: + raise inner_exception + except Exception as e: + raise exception from e + + f = c.submit(chained_exception_fn) + try: + yield f + except Exception as e: + assert e.__cause__ is not None + assert "Bar" in str(e.__cause__) + + @gen_cluster() def test_gather(s, a, b): b.data["x"] = 1 From 9af811d8f9858c63b9586bcfb78ce2dec8f5d6b3 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Thu, 6 Feb 2020 21:19:02 -0800 Subject: [PATCH 0664/1550] Rerun `black` on the code base (#3444) --- distributed/scheduler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8e5fdd17a34..5cfc2d7b9cb 100644 --- 
a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2963,7 +2963,7 @@ async def _delete_worker_data(self, worker_address, keys): List of keys to delete on the specified worker """ await retry_operation( - self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False, + self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False ) ws = self.workers[worker_address] @@ -2972,9 +2972,7 @@ async def _delete_worker_data(self, worker_address, keys): for ts in tasks: ts.who_has.remove(ws) ws.nbytes -= ts.get_nbytes() - self.log_event( - ws.address, {"action": "remove-worker-data", "keys": keys}, - ) + self.log_event(ws.address, {"action": "remove-worker-data", "keys": keys}) async def rebalance(self, comm=None, keys=None, workers=None): """ Rebalance keys so that each worker stores roughly equal bytes From 7bb23d5f3c586afdd5e03d9754ebe194b8491bb0 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 7 Feb 2020 12:56:24 -0600 Subject: [PATCH 0665/1550] Always add new TaskGroup to TaskPrefix (#3322) --- distributed/dashboard/scheduler.py | 6 ++- .../tests/test_scheduler_bokeh_html.py | 2 +- distributed/scheduler.py | 28 +++++++----- distributed/tests/test_scheduler.py | 45 ++++++++++++++++++- distributed/utils.py | 4 +- 5 files changed, 69 insertions(+), 16 deletions(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index a030ba434f7..17e150e8df9 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -292,7 +292,11 @@ def collect(self): yield CounterMetricFamily( "dask_scheduler_tasks_forgotten", - "Total number of processed tasks no longer in memory and already removed from the scheduler job queue.", + ( + "Total number of processed tasks no longer in memory and already " + "removed from the scheduler job queue. Note task groups on the " + "scheduler which have all tasks in the forgotten state are not included." 
+ ), value=task_counter.get("forgotten", 0.0), ) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/dashboard/tests/test_scheduler_bokeh_html.py index 39da4730a28..de71b12a0d1 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh_html.py @@ -171,7 +171,7 @@ async def fetch_metrics(): active_metrics, forgotten_tasks = await fetch_metrics() assert active_metrics.keys() == expected assert sum(active_metrics.values()) == 0.0 - assert sum(forgotten_tasks) == 1.0 + assert sum(forgotten_tasks) == 0.0 @gen_cluster( diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5cfc2d7b9cb..9dcaae16397 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -835,7 +835,7 @@ def duration(self): @property def types(self): - return set.union(*[tg.types for tg in self.groups]) + return set().union(*[tg.types for tg in self.groups]) class _StateLegacyMapping(Mapping): @@ -1976,20 +1976,21 @@ def new_task(self, key, spec, state): """ Create a new task, and associated states """ ts = TaskState(key, spec) ts._state = state - try: - tg = self.task_groups[ts.group_key] - except KeyError: - tg = self.task_groups[ts.group_key] = TaskGroup(ts.group_key) - tg.add(ts) prefix_key = key_split(key) try: tp = self.task_prefixes[prefix_key] except KeyError: - tp = TaskPrefix(prefix_key) - tp.groups.append(tg) - self.task_prefixes[prefix_key] = tp + tp = self.task_prefixes[prefix_key] = TaskPrefix(prefix_key) ts.prefix = tp - tg.prefix = tp + + group_key = ts.group_key + try: + tg = self.task_groups[group_key] + except KeyError: + tg = self.task_groups[group_key] = TaskGroup(group_key) + tg.prefix = tp + tp.groups.append(tg) + tg.add(ts) self.tasks[key] = ts return ts @@ -4642,6 +4643,13 @@ def transition(self, key, finish, *args, **kwargs): if ts.state == "forgotten": del self.tasks[ts.key] + if ts.state == "forgotten": + # Remove TaskGroup if all tasks are in the forgotten state + tg = ts.group + if not any(tg.states.get(s) for s in ALL_TASK_STATES): + ts.prefix.groups.remove(tg) + del self.task_groups[tg.name] + return recommendations except Exception as e: logger.exception("Error transitioning %r from %r to %r", key, start, finish) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 3ce681ea546..5c4d8cbc23e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1774,6 +1774,8 @@ async def test_task_groups(c, s, a, b): await c.replicate(y) assert tg.nbytes_in_memory == y.nbytes + assert "array" in str(tg.types) + assert "array" in str(tp.types) del y @@ -1782,8 +1784,9 @@ async def test_task_groups(c, s, a, b): assert tg.nbytes_in_memory == 0 assert tg.states["forgotten"] == 5 - assert "array" in str(tg.types) - assert "array" in str(tp.types) + # Ensure TaskGroup is removed once all tasks are in forgotten state + assert tg.name not in s.task_groups + assert sys.getrefcount(tg) == 2 @gen_cluster(client=True) @@ -1795,6 +1798,44 @@ async def test_task_prefix(c, s, a, b): assert s.task_prefixes["sum-aggregate"].states["memory"] == 1 + a = da.arange(101, chunks=(20,)) + b = (a + 1).sum().persist() + b = await b + + assert s.task_prefixes["sum-aggregate"].states["memory"] == 2 + + +@gen_cluster(client=True) +async def test_task_group_non_tuple_key(c, s, a, b): + da = pytest.importorskip("dask.array") + np = pytest.importorskip("numpy") + x = da.arange(100, chunks=(20,)) + y = (x + 1).sum().persist() + y = await y + 
+ assert s.task_prefixes["sum"].states["released"] == 4 + assert "sum" not in s.task_groups + + f = c.submit(np.sum, [1, 2, 3]) + await f + + assert s.task_prefixes["sum"].states["released"] == 4 + assert s.task_prefixes["sum"].states["memory"] == 1 + assert "sum" in s.task_groups + + +@gen_cluster(client=True) +async def test_task_unique_groups(c, s, a, b): + """ This test ensure that task groups remain unique when using submit + """ + x = c.submit(sum, [1, 2]) + y = c.submit(len, [1, 2]) + z = c.submit(sum, [3, 4]) + await asyncio.wait([x, y, z]) + + assert s.task_prefixes["len"].states["memory"] == 1 + assert s.task_prefixes["sum"].states["memory"] == 2 + class BrokenComm(Comm): peer_address = None diff --git a/distributed/utils.py b/distributed/utils.py index 09c9c62cd1e..a771a3b280d 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -650,11 +650,11 @@ def key_split_group(x): elif x[0] == "<": return x.strip("<>").split()[0].split(".")[-1] else: - return "" + return key_split(x) elif typ is bytes: return key_split_group(x.decode()) else: - return "" + return key_split(x) @contextmanager From fdf1ece3081f716f880e6b1aa7765211a40821af Mon Sep 17 00:00:00 2001 From: jakirkham Date: Fri, 7 Feb 2020 11:54:22 -0800 Subject: [PATCH 0666/1550] Support serializing/deserializing `rmm.DeviceBuffer`s (#3442) * Serialize and deserialize RMM `DeviceBuffer`'s * Register RMM serializers and deserializers * Test RMM serialization/deserialization * Test serializing RMM `DeviceBuffer` from Numba * Test deserializing a CuPy array from RMM data * Test deserializing a Numba array from RMM data * Fix some minor formatting issues * Drop unneeded `cuda.as_cuda_array` call This should already be deserialized as a Numba array. So go ahead and treat it as such. * Use namespace with `DeviceBuffer` * Assert that `arr` is a `rmm.DeviceBuffer` If we got here, we should already have an `rmm.DeviceBuffer` as that is what we would have allocated. If that's not the case, something very wrong has happened. So just `assert` that is true (erroring otherwise). * Drop unneeded Numba-backed RMM serialization test As RMM is used preferentially for allocations when available, there shouldn't be a case where we need to serialize an RMM `DeviceBuffer` from a Numba array. So drop this test. * Drop unused import --- distributed/protocol/__init__.py | 6 ++++++ distributed/protocol/rmm.py | 23 +++++++++++++++++++++++ distributed/protocol/tests/test_cupy.py | 20 ++++++++++++++++++++ distributed/protocol/tests/test_numba.py | 24 ++++++++++++++++++++++++ distributed/protocol/tests/test_rmm.py | 22 ++++++++++++++++++++++ 5 files changed, 95 insertions(+) create mode 100644 distributed/protocol/rmm.py create mode 100644 distributed/protocol/tests/test_rmm.py diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 30ae3935498..84ee9420c78 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -80,6 +80,12 @@ def _register_numba(): from . import numba +@cuda_serialize.register_lazy("rmm") +@cuda_deserialize.register_lazy("rmm") +def _register_rmm(): + from . 
import rmm + + @cuda_serialize.register_lazy("cudf") @cuda_deserialize.register_lazy("cudf") def _register_cudf(): diff --git a/distributed/protocol/rmm.py b/distributed/protocol/rmm.py new file mode 100644 index 00000000000..cdf22f8218f --- /dev/null +++ b/distributed/protocol/rmm.py @@ -0,0 +1,23 @@ +import rmm +from .cuda import cuda_serialize, cuda_deserialize + + +# Used for RMM 0.11.0+ otherwise Numba serializers used +if hasattr(rmm, "DeviceBuffer"): + + @cuda_serialize.register(rmm.DeviceBuffer) + def serialize_rmm_device_buffer(x): + header = x.__cuda_array_interface__.copy() + frames = [x] + return header, frames + + @cuda_deserialize.register(rmm.DeviceBuffer) + def deserialize_rmm_device_buffer(header, frames): + (arr,) = frames + + # We should already have `DeviceBuffer` + # as RMM is used preferably for allocations + # when it is available (as it is here). + assert isinstance(arr, rmm.DeviceBuffer) + + return arr diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index f7feb4da5e2..d2965d3af3f 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -32,3 +32,23 @@ def test_serialize_cupy_from_numba(dtype): y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) assert (x_np == cupy.asnumpy(y)).all() + + +@pytest.mark.parametrize("size", [0, 3, 10]) +def test_serialize_cupy_from_rmm(size): + np = pytest.importorskip("numpy") + rmm = pytest.importorskip("rmm") + + x_np = np.arange(size, dtype="u1") + + x_np_desc = x_np.__array_interface__ + (x_np_ptr, _) = x_np_desc["data"] + (x_np_size,) = x_np_desc["shape"] + x = rmm.DeviceBuffer(ptr=x_np_ptr, size=x_np_size) + + header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) + header["type-serialized"] = pickle.dumps(cupy.ndarray) + + y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + + assert (x_np == cupy.asnumpy(y)).all() diff --git a/distributed/protocol/tests/test_numba.py b/distributed/protocol/tests/test_numba.py index 78e2b1859b4..21b722fb1b0 100644 --- a/distributed/protocol/tests/test_numba.py +++ b/distributed/protocol/tests/test_numba.py @@ -1,4 +1,5 @@ from distributed.protocol import serialize, deserialize +import pickle import pytest cuda = pytest.importorskip("numba.cuda") @@ -20,3 +21,26 @@ def test_serialize_numba(dtype): x.copy_to_host(hx) y.copy_to_host(hy) assert (hx == hy).all() + + +@pytest.mark.parametrize("size", [0, 3, 10]) +def test_serialize_numba_from_rmm(size): + np = pytest.importorskip("numpy") + rmm = pytest.importorskip("rmm") + + if not cuda.is_available(): + pytest.skip("CUDA is not available") + + x_np = np.arange(size, dtype="u1") + + x_np_desc = x_np.__array_interface__ + (x_np_ptr, _) = x_np_desc["data"] + (x_np_size,) = x_np_desc["shape"] + x = rmm.DeviceBuffer(ptr=x_np_ptr, size=x_np_size) + + header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) + header["type-serialized"] = pickle.dumps(cuda.devicearray.DeviceNDArray) + + y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + + assert (x_np == y.copy_to_host()).all() diff --git a/distributed/protocol/tests/test_rmm.py b/distributed/protocol/tests/test_rmm.py new file mode 100644 index 00000000000..eff3325289e --- /dev/null +++ b/distributed/protocol/tests/test_rmm.py @@ -0,0 +1,22 @@ +from distributed.protocol import serialize, deserialize +import pytest + +numpy = pytest.importorskip("numpy") +cuda = 
pytest.importorskip("numba.cuda") +rmm = pytest.importorskip("rmm") + + +@pytest.mark.parametrize("size", [0, 3, 10]) +def test_serialize_rmm_device_buffer(size): + if not hasattr(rmm, "DeviceBuffer"): + pytest.skip("RMM pre-0.11.0 does not have DeviceBuffer") + + x_np = numpy.arange(size, dtype="u1") + x = rmm.DeviceBuffer(size=size) + cuda.to_device(x_np, to=cuda.as_cuda_array(x)) + + header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) + y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + y_np = y.copy_to_host() + + assert (x_np == y_np).all() From 2a1ed3831c69c6acd256336c6f439f15d77fee95 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 9 Feb 2020 16:27:46 -0800 Subject: [PATCH 0667/1550] Skip test_open_close_many_workers on Python 3.6 (#3459) This test has caused intermittent failures in the past. Previously we had marked it as xfail, but it would still cause CI to break because it would cause things to hang. In #3419 it was observed that the failure seems to only occur on Python 3.6. This commit changes the universal xfail to a skipif for Python 3.6 and below. We still don't know what causes the failure, other than that GC seems to take up all of the CPU time Fixes #3419 --- distributed/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index ae03edd4faa..9879c9ff408 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3581,7 +3581,7 @@ def test_reconnect_timeout(c, s): @pytest.mark.slow @pytest.mark.skipif(WINDOWS, reason="num_fds not supported on windows") -@pytest.mark.xfail(reason="TODO: intermittent failures") +@pytest.mark.skipif(sys.version_info < (3, 7), reason="TODO: intermittent failures") @pytest.mark.parametrize("worker,count,repeat", [(Worker, 100, 5), (Nanny, 10, 20)]) def test_open_close_many_workers(loop, worker, count, repeat): psutil = pytest.importorskip("psutil") From 0af1f4ff4f7c49cb365045011e0f565c0516b1e1 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 10 Feb 2020 07:21:27 -0800 Subject: [PATCH 0668/1550] Include code and summary in performance report (#3462) --- distributed/client.py | 17 +++++++++----- distributed/scheduler.py | 38 ++++++++++++++++++++++++++++++-- distributed/tests/test_client.py | 1 + 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 516185cef23..36c487dc334 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -9,7 +9,7 @@ import errno from functools import partial import html -from inspect import isawaitable +import inspect import itertools import json import logging @@ -1189,7 +1189,7 @@ async def _handle_report(self): try: handler = self._stream_handlers[op] result = handler(**msg) - if isawaitable(result): + if inspect.isawaitable(result): await result except Exception as e: logger.exception(e) @@ -4593,8 +4593,13 @@ async def __aenter__(self): self.start = time() await get_client().get_task_stream(start=0, stop=0) # ensure plugin - async def __aexit__(self, typ, value, traceback): - data = await get_client().scheduler.performance_report(start=self.start) + async def __aexit__(self, typ, value, traceback, code=None): + if not code: + frame = inspect.currentframe().f_back + code = inspect.getsource(frame) + data = await get_client().scheduler.performance_report( + start=self.start, code=code + ) with open(self.filename, "w") as f: f.write(data) @@ -4602,7 
+4607,9 @@ def __enter__(self): get_client().sync(self.__aenter__) def __exit__(self, typ, value, traceback): - get_client().sync(self.__aexit__, type, value, traceback) + frame = inspect.currentframe().f_back + code = inspect.getsource(frame) + get_client().sync(self.__aexit__, type, value, traceback, code=code) @contextmanager diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9dcaae16397..b332b92349d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -59,6 +59,8 @@ key_split_group, empty_context, tmpfile, + format_bytes, + format_time, TimeoutError, ) from .utils_comm import scatter_to_workers, gather_from_workers, retry_operation @@ -4982,7 +4984,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} - async def performance_report(self, comm=None, start=None): + async def performance_report(self, comm=None, start=None, code=""): # Profiles compute, scheduler, workers = await asyncio.gather( *[ @@ -5021,8 +5023,39 @@ def profile_to_figure(state): bandwidth_types = BandwidthTypes(self, sizing_mode="stretch_both") bandwidth_types.update() - from bokeh.models import Panel, Tabs + from bokeh.models import Panel, Tabs, Div + # HTML + html = """ +

+        <h1> Dask Performance Report </h1>
+
+        <i> Select different tabs on the top for additional information </i>
+
+        <h2> Duration: {time} </h2>
+
+        <h2> Scheduler Information </h2>
+        <ul>
+          <li> Address: {address} </li>
+          <li> Workers: {nworkers} </li>
+          <li> Threads: {threads} </li>
+          <li> Memory: {memory} </li>
+        </ul>
+
+        <h2> Calling Code </h2>
+        <pre>
+{code}
+        </pre>
        + """.format( + time=format_time(time() - start), + address=self.address, + nworkers=len(self.workers), + threads=sum(w.nthreads for w in self.workers.values()), + memory=format_bytes(sum(w.memory_limit for w in self.workers.values())), + code=code, + ) + html = Div(text=html) + + html = Panel(child=html, title="Summary") compute = Panel(child=compute, title="Worker Profile (compute)") workers = Panel(child=workers, title="Worker Profile (administrative)") scheduler = Panel(child=scheduler, title="Scheduler Profile (administrative)") @@ -5034,6 +5067,7 @@ def profile_to_figure(state): tabs = Tabs( tabs=[ + html, task_stream, compute, workers, diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 9879c9ff408..8abbd89386d 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5920,6 +5920,7 @@ async def test_performance_report(c, s, a, b): assert "bokeh" in data assert "random" in data assert "Dask Performance Report" in data + assert "x = da.random" in data @pytest.mark.asyncio From 9d79de1abf28e3784ce03f6622fd07ab1131e0be Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 10 Feb 2020 09:36:11 -0600 Subject: [PATCH 0669/1550] Workaround RecursionError on profile data (#3455) --- distributed/comm/utils.py | 8 +++++++- distributed/protocol/tests/test_serialize.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 4862aace207..b75663a14f2 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -1,4 +1,5 @@ import logging +import math import socket import dask @@ -37,7 +38,12 @@ def _to_frames(): logger.exception(e) raise - if FRAME_OFFLOAD_THRESHOLD and sizeof(msg) > FRAME_OFFLOAD_THRESHOLD: + try: + msg_size = sizeof(msg) + except RecursionError: + msg_size = math.inf + + if FRAME_OFFLOAD_THRESHOLD and msg_size > FRAME_OFFLOAD_THRESHOLD: return await offload(_to_frames) else: return _to_frames() diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index b5a202f1520..caf1bbe0ad5 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -359,3 +359,18 @@ def dumps(f): deserialize(*serialize(Foo())) assert "Hello-123" in str(info.value) + + +@pytest.mark.asyncio +async def test_profile_nested_sizeof(): + # https://github.com/dask/distributed/issues/1674 + n = 500 + original = outer = {} + inner = {} + + for i in range(n): + outer["children"] = inner + outer, inner = inner, {} + + msg = {"data": original} + frames = await to_frames(msg) From e9cdc9e9c2ac4fb04bb2f9cf79d88a85043d4650 Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Mon, 10 Feb 2020 17:14:04 -0500 Subject: [PATCH 0670/1550] Add total row to workers plot in dashboard (#3464) --- distributed/dashboard/components/scheduler.py | 15 +++++++++++++-- .../dashboard/tests/test_scheduler_bokeh.py | 11 +++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 2c9953e97e3..c70e41ca436 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1556,7 +1556,7 @@ def __init__(self, scheduler, width=800, **kwargs): point_policy="follow_mouse", tooltips="""
-                    @worker:
+                    Worker (@name):
                     @memory_percent
                 """,
@@ -1585,7 +1585,7 @@ def __init__(self, scheduler, width=800, **kwargs):
             point_policy="follow_mouse",
             tooltips="""
-                    @worker:
+                    Worker (@name):
                     @cpu
        """, @@ -1641,6 +1641,17 @@ def update(self): data["cpu_fraction"][-1] = ws.metrics["cpu"] / 100.0 / ws.nthreads data["nthreads"][-1] = ws.nthreads + for name in self.names + self.extra_names: + if name == "name": + data[name].insert( + 0, "Total ({nworkers})".format(nworkers=len(data[name])) + ) + continue + try: + data[name].insert(0, sum(data[name])) + except TypeError: + data[name].insert(0, None) + self.source.data.update(data) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 4ef90e48b8e..4977ee8fa76 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -300,10 +300,13 @@ def test_WorkerTable(c, s, a, b): for L in wt.source.data.values() for v in L ), {type(v).__name__ for L in wt.source.data.values() for v in L} - assert all(len(v) == 2 for v in wt.source.data.values()) + + assert all(len(v) == 3 for v in wt.source.data.values()) + assert wt.source.data["name"][0] == "Total (2)" nthreads = wt.source.data["nthreads"] assert all(nthreads) + assert nthreads[0] == nthreads[1] + nthreads[2] @gen_cluster(client=True) @@ -334,7 +337,7 @@ def metric_address(worker): assert name in data assert all(data.values()) - assert all(len(v) == 2 for v in data.values()) + assert all(len(v) == 3 for v in data.values()) my_index = data["address"].index(a.address), data["address"].index(b.address) assert [data["metric_port"][i] for i in my_index] == [a.port, b.port] assert [data["metric_address"][i] for i in my_index] == [a.address, b.address] @@ -359,7 +362,7 @@ def metric_port(worker): assert "metric_a" in data assert "metric_b" in data assert all(data.values()) - assert all(len(v) == 2 for v in data.values()) + assert all(len(v) == 3 for v in data.values()) my_index = data["address"].index(a.address), data["address"].index(b.address) assert [data["metric_a"][i] for i in my_index] == [a.port, None] assert [data["metric_b"][i] for i in my_index] == [None, b.port] @@ -379,7 +382,7 @@ def metric_port(worker): assert "metric_a" in data assert all(data.values()) - assert all(len(v) == 2 for v in data.values()) + assert all(len(v) == 3 for v in data.values()) my_index = data["address"].index(a.address), data["address"].index(b.address) assert [data["metric_a"][i] for i in my_index] == [a.port, None] From e3c97aec94c6df505c59c8a69210339f4a008ac6 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 11 Feb 2020 07:25:14 -0600 Subject: [PATCH 0671/1550] Update minimum tblib version to 1.6.0 (#3451) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 545eba40c4d..87e148bb244 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ dask >= 2.9.0 msgpack psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 -tblib +tblib >= 1.6.0 toolz >= 0.7.4 tornado >= 5 zict >= 0.1.3 From f561e9646a121f74d1aecc6d1ee31baeabffad49 Mon Sep 17 00:00:00 2001 From: rockwellw Date: Tue, 11 Feb 2020 08:09:48 -0800 Subject: [PATCH 0672/1550] Update comparison logic for worker state (#3321) --- distributed/scheduler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b332b92349d..b56c08eb38b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -275,6 +275,12 @@ def __init__( self.extra = extra or {} + def __hash__(self): + return hash((self.name, self.host)) + + def __eq__(self, other): + return 
type(self) == type(other) and hash(self) == hash(other) + @property def host(self): return get_address_host(self.address) @@ -2603,7 +2609,7 @@ def handle_release_data(self, key=None, worker=None, client=None, **msg): if ts is None: return ws = self.workers[worker] - if ts.processing_on is not ws: + if ts.processing_on != ws: return r = self.stimulus_missing_data(key=key, ensure=False, **msg) self.transitions(r) @@ -4062,7 +4068,7 @@ def transition_processing_memory( if ws is None: return {key: "released"} - if ws is not ts.processing_on: # someone else has this task + if ws != ts.processing_on: # someone else has this task logger.info( "Unexpected worker completed task, likely due to" " work stealing. Expected: %s, Got: %s, Key: %s", From 386a9d836b05b2e3f8daacdeb97f8081e987272b Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 12 Feb 2020 09:45:24 -0600 Subject: [PATCH 0673/1550] Minor gen.Return cleanup (#3469) --- distributed/actor.py | 2 -- distributed/client.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index 37f43b69358..69172bf23ec 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -1,5 +1,4 @@ import asyncio -from tornado import gen import functools import threading from queue import Queue @@ -169,7 +168,6 @@ async def get_actor_attribute_from_worker(): attribute=key, actor=self.key ) return x["result"] - raise gen.Return(x["result"]) return self._sync(get_actor_attribute_from_worker) diff --git a/distributed/client.py b/distributed/client.py index 36c487dc334..dbc4cd5bc11 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1909,7 +1909,7 @@ async def _scatter( isinstance(k, (bytes, str)) for k in data ): d = await self._scatter(keymap(tokey, data), workers, broadcast) - raise gen.Return({k: d[tokey(k)] for k in data}) + return {k: d[tokey(k)] for k in data} if isinstance(data, type(range(0))): data = list(data) From 51f1a22fb8742c90f9870d224a9336235957a3df Mon Sep 17 00:00:00 2001 From: Dustin Tindall Date: Wed, 12 Feb 2020 09:47:53 -0600 Subject: [PATCH 0674/1550] Update locality.rst (#3470) --- docs/source/locality.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/locality.rst b/docs/source/locality.rst index 4ee6568395d..caab601f191 100644 --- a/docs/source/locality.rst +++ b/docs/source/locality.rst @@ -73,7 +73,7 @@ used. allow_other_workers=True) Additionally the ``scatter`` function supports a ``broadcast=`` keyword -argument to enforce that the all data is sent to all workers rather than +argument to enforce that all the data is sent to all workers rather than round-robined. If new workers arrive they will not automatically receive this data. 
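A short usage sketch of the ``scatter`` keywords that the locality.rst change above describes (not part of the patch; the scheduler and worker addresses are illustrative):

    from distributed import Client

    client = Client("tcp://scheduler:8786")  # illustrative address

    # pin the scattered pieces to a specific worker
    futures = client.scatter([1, 2, 3], workers=["tcp://10.0.0.1:8789"])

    # or send a copy of the data to every worker rather than round-robining it
    everywhere = client.scatter({"params": [0.1, 0.2]}, broadcast=True)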
From 346b2dbc56d682aa96d0b3f072f7411ec7d695a0 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Wed, 12 Feb 2020 17:42:22 +0000 Subject: [PATCH 0675/1550] Split dashboard host on additional slashes to handle inproc (#3466) --- distributed/deploy/cluster.py | 2 +- distributed/tests/test_client.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index ad071a214be..c616f13c826 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -199,7 +199,7 @@ def dashboard_link(self): except KeyError: return "" else: - host = self.scheduler_address.split("://")[1].split(":")[0] + host = self.scheduler_address.split("://")[1].split("/")[0].split(":")[0] return format_dashboard_link(host, port) def _widget_status(self): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8abbd89386d..4f075d582f3 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5255,6 +5255,13 @@ def test_dashboard_link(loop, monkeypatch): assert link in text +@pytest.mark.asyncio +async def test_dashboard_link_inproc(cleanup): + async with Client(processes=False, asynchronous=True) as c: + with dask.config.set({"distributed.dashboard.link": "{host}"}): + assert "/" not in c.dashboard_link + + @gen_test() def test_client_timeout_2(): with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): From 7c1b4dfdc1dc9d690ee1731a8fc94f5665620de7 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 12 Feb 2020 09:43:09 -0800 Subject: [PATCH 0676/1550] Change default multiprocessing behavior to spawn (#3461) --- distributed/deploy/tests/test_ssh.py | 4 ++-- distributed/distributed.yaml | 2 +- distributed/nanny.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/deploy/tests/test_ssh.py b/distributed/deploy/tests/test_ssh.py index eff7bf05a11..af6bf1566f2 100644 --- a/distributed/deploy/tests/test_ssh.py +++ b/distributed/deploy/tests/test_ssh.py @@ -38,14 +38,14 @@ async def test_keywords(): "memory_limit": "2 GiB", "death_timeout": "5s", }, - scheduler_options={"idle_timeout": "5s", "port": 0}, + scheduler_options={"idle_timeout": "10s", "port": 0}, ) as cluster: async with Client(cluster, asynchronous=True) as client: assert ( await client.run_on_scheduler( lambda dask_scheduler: dask_scheduler.idle_timeout ) - ) == 5 + ) == 10 d = client.scheduler_info()["workers"] assert all(v["nthreads"] == 2 for v in d.values()) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index ee38750f8ee..487e72e215e 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -39,7 +39,7 @@ distributed: worker: blocked-handlers: [] - multiprocessing-method: forkserver + multiprocessing-method: spawn use-file-locking: True connections: # Maximum concurrent connections for data outgoing: 50 # This helps to control network saturation diff --git a/distributed/nanny.py b/distributed/nanny.py index 5e67b51b52b..ff653ba096c 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -597,6 +597,7 @@ async def kill(self, timeout=2, executor_wait=True): "executor_wait": executor_wait, } ) + await asyncio.sleep(0) # otherwise we get broken pipe errors self.child_stop_q.close() while process.is_alive() and loop.time() < deadline: From cf051bf7050a00ff7c5f738283b2634fbb5b5178 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 12 Feb 2020 09:43:32 -0800 Subject: [PATCH 0677/1550] Clear old docs 
(#3458) * Remove local-cluster doc * remove adaptive and prometheus docs * remove dashboard docs --- docs/source/adaptive.rst | 175 -------------------------- docs/source/conf.py | 7 ++ docs/source/index.rst | 5 - docs/source/local-cluster.rst | 50 -------- docs/source/prometheus.rst | 42 ------- docs/source/web.rst | 229 ---------------------------------- 6 files changed, 7 insertions(+), 501 deletions(-) delete mode 100644 docs/source/adaptive.rst delete mode 100644 docs/source/local-cluster.rst delete mode 100644 docs/source/prometheus.rst delete mode 100644 docs/source/web.rst diff --git a/docs/source/adaptive.rst b/docs/source/adaptive.rst deleted file mode 100644 index f07246588cd..00000000000 --- a/docs/source/adaptive.rst +++ /dev/null @@ -1,175 +0,0 @@ -Adaptive Deployments -==================== - -It is possible to grow and shrink Dask clusters based on current use. This -allows you to run Dask permanently on your cluster and have it only take up -resources when necessary. Dask contains the logic about when to grow and -shrink but relies on external cluster managers to launch and kill -``dask-worker`` jobs. This page describes the policies about adaptively -resizing Dask clusters based on load, how to connect these policies to a -particular job scheduler, and an example implementation. - -Dynamically scaling a Dask cluster up and down requires tight integration with -an external cluster management system that can deploy ``dask-worker`` jobs -throughout the cluster. Several systems are in wide use today, including -common examples like SGE, SLURM, Torque, Condor, LSF, Yarn, Mesos, Marathon, -Kubernetes, etc... These systems can be quite different from each other, but -all are used to manage distributed services throughout different kinds of -clusters. - -The large number of relevant systems, the challenges of rigorously testing -each, and finite development time precludes the systematic inclusion of all -solutions within the dask/distributed repository. Instead, we include a -generic interface that can be extended by someone with basic understanding of -their cluster management tool. We encourage these as third party modules. - - -Policies --------- - -We control the number of workers based on current load and memory use. The -scheduler checks itself periodically to determine if more or fewer workers are -needed. - -If there are excess unclaimed tasks, or if the memory of the current workers is -more nearing full then the scheduler tries to increase the number of workers by -a fixed factor, defaulting to 2. This causes exponential growth while growth -is useful. - -If there are idle workers and if the memory of the current workers is nearing -empty then we gracefully retire the idle workers with the least amount of data -in memory. We first move these results to the surviving workers and then -remove the idle workers from the cluster. This shrinks the cluster while -gracefully preserving intermediate results, shrinking the cluster when excess -size is not useful. - - -Adaptive class interface ------------------------- - -The ``distributed.deploy.Adaptive`` class contains the logic about when to ask -for new workers, and when to close idle ones. This class requires both a -scheduler and a cluster object. - -The cluster object must support two methods, ``scale_up(n, **kwargs)``, which -takes in a target number of total workers for the cluster and -``scale_down(workers)``, which takes in a list of addresses to remove from the -cluster. 
The Adaptive class will call these methods with the correct values at -the correct times. - -.. code-block:: python - - class MyCluster: - async def scale_up(self, n, **kwargs): - """ - Bring the total count of workers up to ``n`` - - This function/coroutine should bring the total number of workers up to - the number ``n``. - - This can be implemented either as a function or as a Tornado coroutine. - """ - raise NotImplementedError() - - async def scale_down(self, workers): - """ - Remove ``workers`` from the cluster - - Given a list of worker addresses this function should remove those - workers from the cluster. This may require tracking which jobs are - associated to which worker address. - - This can be implemented either as a function or as a Tornado coroutine. - """ - - from distributed.deploy import Adaptive - - scheduler = Scheduler() - cluster = MyCluster() - adapative_cluster = Adaptive(scheduler, cluster) - scheduler.start() - -Implementing these ``scale_up`` and ``scale_down`` functions depends strongly -on the cluster management system. See :doc:`LocalCluster ` for -an example. - - -Marathon: an example --------------------- - -We now present an example project that implements this cluster interface backed -by the Marathon cluster management tool on Mesos. Full source code and testing -apparatus is available here: http://github.com/mrocklin/dask-marathon - -The implementation is small. It uses the Marathon HTTP API through the -`marathon Python client library `_. -We reproduce the full body of the implementation below as an example: - -.. code-block:: python - - from marathon import MarathonClient, MarathonApp - from marathon.models.container import MarathonContainer - - class MarathonCluster: - def __init__(self, scheduler, - executable='dask-worker', - docker_image='mrocklin/dask-distributed', - marathon_address='http://localhost:8080', - name=None, **kwargs): - self.scheduler = scheduler - - # Create Marathon App to run dask-worker - args = [executable, scheduler.address, - '--name', '$MESOS_TASK_ID'] # use Mesos task ID as worker name - if 'mem' in kwargs: - args.extend(['--memory-limit', - str(int(kwargs['mem'] * 0.6 * 1e6))]) - kwargs['cmd'] = ' '.join(args) - container = MarathonContainer({'image': docker_image}) - - app = MarathonApp(instances=0, container=container, **kwargs) - - # Connect and register app - self.client = MarathonClient(marathon_address) - self.app = self.client.create_app(name or 'dask-%s' % uuid.uuid4(), app) - - def scale_up(self, instances): - self.marathon_client.scale_app(self.app.id, instances=instances) - - def scale_down(self, workers): - for w in workers: - self.marathon_client.kill_task(self.app.id, - self.scheduler.worker_info[w]['name'], - scale=True) - -Subclassing Adaptive --------------------- - -The default behaviors of ``Adaptive`` controlling when to scale up or down, and -by how much, may not be appropriate for your cluster manager or workload. For -example, you may have tasks that require a worker with more memory than usual. -This means we need to pass through some additional keyword arguments to -``cluster.scale_up`` call. - -.. 
code-block:: python - - from distributed.deploy import Adaptive - - class MyAdaptive(Adaptive): - def get_scale_up_kwargs(self): - kwargs = super(Adaptive, self).get_scale_up_kwargs() - # resource_restrictions maps task keys to a dict of restrictions - restrictions = self.scheduler.resource_restrictions.values() - memory_restrictions = [x.get('memory') for x in restrictions - if 'memory' in x] - - if memory_restrictions: - kwargs['memory'] = max(memory_restrictions) - - return kwargs - - -So if there are any tasks that are waiting to be run on a worker with enough -memory, the ``kwargs`` dictionary passed to ``cluster.scale_up`` will include -a key and value for ``'memory'`` (your ``Cluster.scale_up`` method needs to be -able to support this). diff --git a/docs/source/conf.py b/docs/source/conf.py index 9bd0ce6867e..f8ab5a31797 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -393,6 +393,13 @@ ("setup.html", "https://docs.dask.org/en/latest/setup.html"), ("ec2.html", "https://docs.dask.org/en/latest/setup/cloud.html"), ("configuration.html", "https://docs.dask.org/en/latest/configuration.html"), + ( + "local-cluster.html", + "https://docs.dask.org/en/latest/setup/single-distributed.html", + ), + ("adaptive.html", "https://docs.dask.org/en/latest/setup/adaptive.html"), + ("prometheus.html", "https://docs.dask.org/en/latest/setup/prometheus.html"), + ("web.html", "https://docs.dask.org/en/latest/diagnostics-distributed.html"), ] diff --git a/docs/source/index.rst b/docs/source/index.rst index 47419e014ec..3cbdd18792a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -106,19 +106,14 @@ Contents :caption: Additional Features actors - adaptive asynchronous configuration - local-cluster ipython - Joblib Integration prometheus publish - queues resources task-launch tls - web .. toctree:: :maxdepth: 1 diff --git a/docs/source/local-cluster.rst b/docs/source/local-cluster.rst deleted file mode 100644 index d596ccaed24..00000000000 --- a/docs/source/local-cluster.rst +++ /dev/null @@ -1,50 +0,0 @@ -Local Cluster -============= - -For convenience you can start a local cluster from your Python session. - -.. code-block:: python - - >>> from distributed import Client, LocalCluster - >>> cluster = LocalCluster() - LocalCluster("127.0.0.1:8786", workers=8, nthreads=8) - >>> client = Client(cluster) - - -You can dynamically scale this cluster up and down: - -.. code-block:: python - - >>> worker = cluster.add_worker() - >>> cluster.remove_worker(worker) - -Alternatively, a ``LocalCluster`` is made for you automatically if you create -an ``Client`` with no arguments: - -.. code-block:: python - - >>> from distributed import Client - >>> client = Client() - >>> client - - -.. note:: - - Within a Python script you need to start a local cluster in the - ``if __name__ == '__main__'`` block: - - .. code-block:: python - - if __name__ == '__main__': - cluster = LocalCluster() - client = Client(cluster) - # Your code follows here - -API ---- - -.. currentmodule:: distributed.deploy.local - -.. autoclass:: LocalCluster - :members: - diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst deleted file mode 100644 index 097335ee0d7..00000000000 --- a/docs/source/prometheus.rst +++ /dev/null @@ -1,42 +0,0 @@ -Prometheus Monitoring ------------------------ - -Prometheus_ is a widely popular tool for monitoring and alerting a wide variety of systems. Dask.distributed exposes -scheduler and worker metrics in a prometheus text based format. 
Metrics are available at ``http://scheduler-address:8787/metrics``. - -.. _Prometheus: https://prometheus.io - -Available metrics are as following - -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| Metric name | Description | Scheduler | Worker | -+=========================+===================+================================================+===========+========+ -| python_gc_objects_collected_total | Objects collected during gc. | Yes | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| python_gc_objects_uncollectable_total | Uncollectable object found during GC. | Yes | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| python_gc_collections_total | Number of times this generation was collected. | Yes | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| python_info | Python platform information. | Yes | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_scheduler_workers | Number of workers connected. | Yes | | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_scheduler_clients | Number of clients connected. | Yes | | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_scheduler_tasks | Number of tasks at scheduler. | Yes | | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_worker_tasks | Number of tasks at worker. | | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_worker_connections | Number of task connections to other workers. | | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_worker_threads | Number of worker threads. | | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_worker_latency_seconds | Latency of worker connection. | | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_worker_tick_duration_median_seconds | Median tick duration at worker. | | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_worker_task_duration_median_seconds | Median task runtime at worker. | | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ -| dask_worker_transfer_bandwidth_median_bytes | Bandwidth for transfer at worker in Bytes. | | Yes | -+---------------------------------------------+------------------------------------------------+-----------+--------+ - diff --git a/docs/source/web.rst b/docs/source/web.rst deleted file mode 100644 index 6a5b58fac5e..00000000000 --- a/docs/source/web.rst +++ /dev/null @@ -1,229 +0,0 @@ -Web Interface -============= - -.. raw:: html - - - -Information about the current state of the network helps to track progress, -identify performance issues, and debug failures. 
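For reference, a minimal sketch of reading the Prometheus endpoint documented in the prometheus.rst section removed above (not part of the patch; assumes a scheduler dashboard listening on localhost:8787 and the ``requests`` package):

    import requests

    text = requests.get("http://localhost:8787/metrics").text
    for line in text.splitlines():
        if line.startswith("dask_"):  # keep only the dask_* samples
            print(line)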
- -Dask.distributed includes a web interface to help deliver this information over -a normal web page in real time. This web interface is launched by default -wherever the scheduler is launched if the scheduler machine has Bokeh_ -installed (``conda install bokeh -c bokeh``). - -These diagnostic pages are: - -* Main Scheduler pages at ``http://scheduler-address:8787``. These pages, - particularly the ``/status`` page are the main page that most people - associate with Dask. These pages are served from a separate standalone - Bokeh server application running in a separate process. - -The available pages are ``http://scheduler-address:8787//`` where ```` is one of - -- ``status``: a stream of recently run tasks, progress bars, resource use -- ``tasks``: a larger stream of the last 100k tasks -- ``workers``: basic information about workers and their current load -- ``health``: basic health check, returns ``ok`` if service is running - -.. _Bokeh: http://bokeh.pydata.org/en/latest/ - -Plots ------ - -Example Computation -~~~~~~~~~~~~~~~~~~~ - -The following plots show a trace of the following computation: - -.. code-block:: python - - from distributed import Client - from time import sleep - import random - - def inc(x): - sleep(random.random() / 10) - return x + 1 - - def dec(x): - sleep(random.random() / 10) - return x - 1 - - def add(x, y): - sleep(random.random() / 10) - return x + y - - - client = Client('127.0.0.1:8786') - - incs = client.map(inc, range(100)) - decs = client.map(dec, range(100)) - adds = client.map(add, incs, decs) - total = client.submit(sum, adds) - - del incs, decs, adds - total.result() - -Progress -~~~~~~~~ - -The interface shows the progress of the various computations as well as the -exact number completed. - -.. image:: https://raw.githubusercontent.com/dask/dask-org/master/images/bokeh-progress.gif - :alt: Resources view of Dask web interface - -Each bar is assigned a color according to the function being run. Each bar -has a few components. On the left the lighter shade is the number of tasks -that have both completed and have been released from memory. The darker shade -to the right corresponds to the tasks that are completed and whose data still -reside in memory. If errors occur then they appear as a black colored block -to the right. - -Typical computations may involve dozens of kinds of functions. We handle this -visually with the following approaches: - -1. Functions are ordered by the number of total tasks -2. The colors are assigned in a round-robin fashion from a standard palette -3. The progress bars shrink horizontally to make space for more functions -4. Only the largest functions (in terms of number of tasks) are displayed - -.. image:: https://raw.githubusercontent.com/dask/dask-org/master/images/bokeh-progress-large.gif - :alt: Progress bar plot of Dask web interface - -Counts of tasks processing, waiting for dependencies, processing, etc.. are -displayed in the title bar. - -Memory Use -~~~~~~~~~~ - -The interface shows the relative memory use of each function with a horizontal -bar sorted by function name. - -.. image:: https://raw.githubusercontent.com/dask/dask-org/master/images/bokeh-memory-use.gif - :alt: Memory use plot of Dask web interface - -The title shows the number of total bytes in use. Hovering over any bar -tells you the specific function and how many bytes its results are actively -taking up in memory. This does not count data that has been released. 
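The progress and memory information described above can also be followed outside the dashboard; a hedged sketch using the public ``progress`` helper (not part of the patch; the scheduler address matches the example computation above):

    from distributed import Client, progress

    def add(x, y):
        return x + y

    client = Client("127.0.0.1:8786")
    adds = client.map(add, range(100), range(100))
    progress(adds)  # renders the same progress information in a console or notebook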
- -Task Stream -~~~~~~~~~~~ - -The task stream plot shows when tasks complete on which workers. Worker cores -are on the y-axis and time is on the x-axis. As a worker completes a task its -start and end times are recorded and a rectangle is added to this plot -accordingly. - -.. image:: https://raw.githubusercontent.com/dask/dask-org/master/images/bokeh-task-stream.gif - :alt: Task stream plot of Dask web interface - -The colors signifying the following: - -1. Serialization (gray) -2. Communication between workers (red) -3. Disk I/O (orange) -4. Error (black) -5. Execution times (colored by task: purple, green, yellow, etc) - - -If data transfer occurs between workers a *red* bar appears preceding the -task bar showing the duration of the transfer. If an error occurs than a -*black* bar replaces the normal color. This plot show the last 1000 tasks. -It resets if there is a delay greater than 10 seconds. - -For a full history of the last 100,000 tasks see the ``tasks/`` page. - -Resources -~~~~~~~~~ - -The resources plot show the average CPU and Memory use over time as well as -average network traffic. More detailed information on a per-worker basis is -available in the ``workers/`` page. - -.. image:: https://raw.githubusercontent.com/dask/dask-org/master/images/bokeh-resources.gif - :alt: Resources view of Dask web interface - -Per-worker resources -~~~~~~~~~~~~~~~~~~~~ - -The ``workers/`` page shows per-worker resources, the main ones being CPU and -memory use. Custom metrics can be registered and displayed in this page. Here -is an example showing how to display GPU utilization and GPU memory use: - -.. code-block:: python - - import subprocess - - def nvidia_data(name): - def dask_function(dask_worker): - cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format(name) - result = subprocess.check_output(cmd.split()) - return result.strip().decode() - return dask_function - - def register_metrics(dask_worker): - for name in ['utilization.gpu', 'utilization.memory']: - dask_worker.metrics[name] = nvidia_data(name) - - client.run(register_metrics) - -Connecting to Web Interface ---------------------------- - -Default -~~~~~~~ - -By default, ``dask-scheduler`` prints out the address of the web interface:: - - INFO - Bokeh UI at: http://10.129.39.91:8787/status - ... - INFO - Starting Bokeh server on port 8787 with applications at paths ['/status', '/tasks'] - -The machine hosting the scheduler runs an HTTP server serving at that address. - - -Troubleshooting ---------------- - -Some clusters restrict the ports that are visible to the outside world. These -ports may include the default port for the web interface, ``8787``. There are -a few ways to handle this: - -1. Open port ``8787`` to the outside world. Often this involves asking your - cluster administrator. -2. Use a different port that is publicly accessible using the - ``--dashboard-address :8787`` option on the ``dask-scheduler`` command. -3. Use fancier techniques, like `Port Forwarding`_ - -Running distributed on a remote machine can cause issues with viewing the web -UI -- this depends on the remote machines network configuration. - -.. _`Port Forwarding`: https://en.wikipedia.org/wiki/Port_forwarding - - -Port Forwarding -~~~~~~~~~~~~~~~ - -If you have SSH access then one way to gain access to a blocked port is through -SSH port forwarding. A typical use case looks like the following: - -.. 
code:: bash - - local$ ssh -L 8000:localhost:8787 user@remote - remote$ dask-scheduler # now, the web UI is visible at localhost:8000 - remote$ # continue to set up dask if needed -- add workers, etc - -It is then possible to go to ``localhost:8000`` and see Dask Web UI. This same approach is -not specific to dask.distributed, but can be used by any service that operates over a -network, such as Jupyter notebooks. For example, if we chose to do this we could -forward port 8888 (the default Jupyter port) to port 8001 with -``ssh -L 8001:localhost:8888 user@remote``. From 49328dcd0e13556c7b3bae8c3d542786cbb07703 Mon Sep 17 00:00:00 2001 From: condoratberlin <49398997+condoratberlin@users.noreply.github.com> Date: Wed, 12 Feb 2020 18:50:37 +0100 Subject: [PATCH 0678/1550] Change default value of local_directory from empty string to None (#3441) Fixes #3440 --- distributed/cli/dask_worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index e76bed2e9bc..5188333b75c 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -151,12 +151,12 @@ ) @click.option("--pid-file", type=str, default="", help="File to write the process PID") @click.option( - "--local-directory", default="", type=str, help="Directory to place worker files" + "--local-directory", default=None, type=str, help="Directory to place worker files" ) @click.option( "--resources", type=str, - default="", + default=None, help='Resources for task constraints like "GPU=2 MEM=10e9". ' "Resources are applied separately to each worker process " "(only relevant when starting multiple worker processes with '--nprocs').", @@ -164,7 +164,7 @@ @click.option( "--scheduler-file", type=str, - default="", + default=None, help="Filename to JSON encoded scheduler information. " "Use with dask-scheduler --scheduler-file", ) @@ -180,7 +180,7 @@ @click.option( "--lifetime", type=str, - default="", + default=None, help="If provided, shut down the worker after this duration.", ) @click.option( From ca88aa7327820fc934b2254b884721cea94991e5 Mon Sep 17 00:00:00 2001 From: kaelgreco Date: Thu, 13 Feb 2020 08:43:26 -0800 Subject: [PATCH 0679/1550] Add last seen column to worker table and highlight errant workers (#3468) --- distributed/dashboard/scheduler.py | 4 +++- distributed/dashboard/templates/task.html | 4 ++-- distributed/dashboard/templates/worker-table.html | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 17e150e8df9..836cefbbd6c 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -50,12 +50,14 @@ from .proxy import GlobalProxyHandler from .utils import RequestHandler, redirect from ..diagnostics.websocket import WebsocketPlugin +from ..metrics import time from ..utils import log_errors, format_time from ..scheduler import ALL_TASK_STATES ns = { - func.__name__: func for func in [format_bytes, format_time, datetime.fromtimestamp] + func.__name__: func + for func in [format_bytes, format_time, datetime.fromtimestamp, time] } rel_path_statics = {"rel_path_statics": "../../"} diff --git a/distributed/dashboard/templates/task.html b/distributed/dashboard/templates/task.html index 8c292da4e43..bcc0d17c0a8 100644 --- a/distributed/dashboard/templates/task.html +++ b/distributed/dashboard/templates/task.html @@ -122,9 +122,9 @@

        Transition Log

        Recommended Action - {% for key, start, finish, recommendations, time in scheduler.story(Task) %} + {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %} - {{ fromtimestamp(time) }} + {{ fromtimestamp(transition_time) }} {{key}} {{ start }} {{ finish }} diff --git a/distributed/dashboard/templates/worker-table.html b/distributed/dashboard/templates/worker-table.html index c12061fab46..87512ee3860 100644 --- a/distributed/dashboard/templates/worker-table.html +++ b/distributed/dashboard/templates/worker-table.html @@ -1,4 +1,4 @@ - +
        @@ -10,9 +10,10 @@ + {% for ws in worker_list %} - + 60 else ""}}> @@ -27,6 +28,7 @@ {% end %} + {% end %}
        Worker Name In-memory Services Logs Last seen
        {{ws.address}} {{ ws.name if ws.name is not None else "" }} {{ ws.nthreads }} logs {{ format_time(time() - ws.last_seen) }}
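The worker-table change above derives the new "Last seen" column from the
worker's ``last_seen`` timestamp and highlights workers that have not been seen
for more than 60 seconds. As a rough illustrative sketch only (``ws`` stands for
the scheduler-side worker state used by the template; this helper is not a
public API):

    from distributed.metrics import time

    def last_seen_age(ws, threshold=60):
        # ws.last_seen records when the scheduler last heard from this worker;
        # a large gap marks the worker as errant in the dashboard table
        age = time() - ws.last_seen
        return age, age > threshold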
        From 04de4b2adc1f0c94ab86e9aa46a4c67382a7eaea Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 13 Feb 2020 16:10:34 -0600 Subject: [PATCH 0680/1550] Add support for Python 3.8 (#3249) --- .github/workflows/ci-windows.yaml | 15 ++++++++++++- .travis.yml | 1 + continuous_integration/environment.yml | 1 - continuous_integration/travis/install.sh | 19 +++++++++++------ distributed/__init__.py | 2 +- distributed/client.py | 8 ++++--- distributed/comm/ucx.py | 8 ++----- distributed/core.py | 2 +- distributed/deploy/tests/test_local.py | 7 ++++++ distributed/protocol/tests/test_pickle.py | 6 ++++++ distributed/tests/test_as_completed.py | 2 +- distributed/tests/test_client.py | 2 +- distributed/tests/test_client_executor.py | 2 +- distributed/tests/test_failed_workers.py | 3 +-- distributed/tests/test_steal.py | 5 +++++ distributed/tests/test_stress.py | 3 +-- distributed/tests/test_worker.py | 5 +++++ distributed/utils.py | 26 ++++++++++++++++++----- requirements.txt | 3 ++- setup.py | 1 + 20 files changed, 88 insertions(+), 33 deletions(-) diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index 992f65d6435..707c50ffc76 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -7,7 +7,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python-version: ["3.6", "3.7"] + python-version: ["3.6", "3.7", "3.8"] steps: - name: Checkout source @@ -22,6 +22,19 @@ jobs: activate-environment: testenv auto-activate-base: false + - name: Install tornado + shell: bash -l {0} + run: | + if [[ "${{ matrix.python-version }}" = "3.8" ]]; then + conda install -c conda-forge tornado=6 + else + conda install -c conda-forge tornado=5 + fi + + - name: List packages in environment + shell: bash -l {0} + run: conda list + - name: Install distributed from source shell: bash -l {0} run: python -m pip install -q --no-deps -e . 
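The conditional tornado install above mirrors the requirement gate that this
patch also adds to ``requirements.txt`` below. Expressed as plain Python, purely
for illustration:

    import sys

    # Python 3.8 requires tornado >= 6.0.3 (event-loop policy changes);
    # earlier interpreters continue to work with tornado >= 5
    pin = "tornado>=6.0.3" if sys.version_info >= (3, 8) else "tornado>=5"
    print(pin)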
diff --git a/.travis.yml b/.travis.yml index d00894dd3d6..e8f2afc5057 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,7 @@ env: matrix: - PYTHON=3.6 TESTS=true COVERAGE=true PACKAGES="scikit-learn lz4" TORNADO=5 CRICK=true - PYTHON=3.7 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 + - PYTHON=3.8 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 matrix: fast_finish: true diff --git a/continuous_integration/environment.yml b/continuous_integration/environment.yml index f6651254af2..f69d919e879 100644 --- a/continuous_integration/environment.yml +++ b/continuous_integration/environment.yml @@ -21,7 +21,6 @@ dependencies: - requests - toolz - tblib - - tornado=5 - zict - fsspec - pip diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 8eaed19df81..b631ac3bc6c 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -27,7 +27,8 @@ conda create -q -n test-environment python=$PYTHON source activate test-environment # Install dependencies -conda install -q \ +conda install -c conda-forge -q \ + asyncssh \ bokeh \ click \ coverage \ @@ -48,15 +49,20 @@ conda install -q \ python=$PYTHON \ requests \ scipy \ - tblib \ + tblib>=1.5.0 \ toolz \ tornado=$TORNADO \ + zstandard \ $PACKAGES -# For low-level profiler, install libunwind and stacktrace from conda-forge -# For stacktrace we use --no-deps to avoid upgrade of python -conda install -c defaults -c conda-forge libunwind zstandard asyncssh -conda install --no-deps -c defaults -c numba -c conda-forge stacktrace +# stacktrace is not currently avaiable for Python 3.8. +# Remove the version check block below when it is avaiable. +if [[ $PYTHON != 3.8 ]]; then + # For low-level profiler, install libunwind and stacktrace from conda-forge + # For stacktrace we use --no-deps to avoid upgrade of python + conda install -c defaults -c conda-forge libunwind + conda install --no-deps -c defaults -c numba -c conda-forge stacktrace +fi; python -m pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio @@ -67,7 +73,6 @@ python -m pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-dep python -m pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps python -m pip install -q sortedcollections msgpack --no-deps python -m pip install -q keras --upgrade --no-deps -python -m pip install -q asyncssh if [[ $CRICK == true ]]; then conda install -q cython diff --git a/distributed/__init__.py b/distributed/__init__.py index 9238d57ccc9..be750f9daed 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -24,7 +24,7 @@ from .queues import Queue from .scheduler import Scheduler from .threadpoolexecutor import rejoin -from .utils import sync, TimeoutError +from .utils import sync, TimeoutError, CancelledError from .variable import Variable from .worker import Worker, get_worker, get_client, secede, Reschedule from .worker_client import local_client, worker_client diff --git a/distributed/client.py b/distributed/client.py index dbc4cd5bc11..ce820d2c6e6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2,7 +2,7 @@ import atexit from collections import defaultdict from collections.abc import Iterator -from concurrent.futures import ThreadPoolExecutor, CancelledError +from concurrent.futures import ThreadPoolExecutor from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import contextmanager import copy @@ -82,6 +82,7 @@ 
has_keyword, format_dashboard_link, TimeoutError, + CancelledError, ) from . import versions as version_module @@ -1248,6 +1249,7 @@ async def _close(self, fast=False): """ Send close signal and wait until scheduler completes """ if self.status == "closed": return + self.status = "closing" for pc in self._periodic_callbacks.values(): @@ -1273,7 +1275,7 @@ async def _close(self, fast=False): # Give the scheduler 'stream-closed' message 100ms to come through # This makes the shutdown slightly smoother and quieter - with ignoring(AttributeError, CancelledError, TimeoutError): + with ignoring(AttributeError, asyncio.CancelledError, TimeoutError): await asyncio.wait_for( asyncio.shield(self._handle_scheduler_coroutine), 0.1 ) @@ -1310,7 +1312,7 @@ async def _close(self, fast=False): del self.coroutines[:] if not fast: - with ignoring(TimeoutError): + with ignoring(TimeoutError, asyncio.CancelledError): await asyncio.wait_for(asyncio.gather(*coroutines), 2) with ignoring(AttributeError): diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 175d628a0f6..629a179e43e 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -8,7 +8,6 @@ import ucp import logging -import concurrent import dask import numpy as np @@ -17,7 +16,7 @@ from .core import Comm, Connector, Listener, CommClosedError from .registry import Backend, backends from .utils import ensure_concrete_host, to_frames, from_frames -from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors +from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors, CancelledError import dask import numpy as np @@ -170,10 +169,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): await self.ep.recv(is_cudas) sizes = np.empty(nframes[0], dtype=np.uint64) await self.ep.recv(sizes) - except ( - ucp.exceptions.UCXBaseException, - concurrent.futures._base.CancelledError, - ): + except (ucp.exceptions.UCXBaseException, CancelledError): self.abort() raise CommClosedError("While reading, the connection was closed") else: diff --git a/distributed/core.py b/distributed/core.py index ac9be6728fc..5768f0f4d8e 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1,6 +1,5 @@ import asyncio from collections import defaultdict, deque -from concurrent.futures import CancelledError from functools import partial from inspect import isawaitable import logging @@ -35,6 +34,7 @@ PeriodicCallback, parse_timedelta, has_keyword, + CancelledError, ) from . import protocol diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 98d04c78d17..8ca780a4eb2 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -7,9 +7,11 @@ from threading import Lock import unittest import weakref +from distutils.version import LooseVersion from tornado.ioloop import IOLoop from tornado import gen +import tornado import pytest from dask.system import CPU_COUNT @@ -451,6 +453,11 @@ async def test_scale_up_and_down(): assert len(cluster.workers) == 1 +@pytest.mark.xfail( + sys.version_info >= (3, 8) and LooseVersion(tornado.version) < "6.0.3", + reason="Known issue with Python 3.8 and Tornado < 6.0.3. 
See https://github.com/tornadoweb/tornado/pull/2683.", + strict=True, +) def test_silent_startup(): code = """if 1: from time import sleep diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 0ba776e2758..681992ef844 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -2,6 +2,7 @@ import gc from operator import add import weakref +import sys import pytest @@ -23,6 +24,11 @@ def test_pickle_numpy(): assert (loads(dumps(x)) == x).all() +@pytest.mark.xfail( + sys.version_info[:2] == (3, 8), + reason="Sporadic failure on Python 3.8", + strict=False, +) def test_pickle_functions(): def make_closure(): value = 1 diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index 45833b302e1..d0249b121d6 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -1,4 +1,3 @@ -from concurrent.futures import CancelledError from collections.abc import Iterator from operator import add import queue @@ -9,6 +8,7 @@ from tornado import gen from distributed.client import _as_completed, as_completed, _first_completed +from distributed.utils import CancelledError from distributed.utils_test import gen_cluster, inc, throws from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 4f075d582f3..392aec73be8 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1,6 +1,5 @@ import asyncio from collections import deque -from concurrent.futures import CancelledError import gc import logging from operator import add @@ -38,6 +37,7 @@ profile, performance_report, TimeoutError, + CancelledError, ) from distributed.comm import CommClosedError from distributed.client import ( diff --git a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index 40639998852..1024990216d 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -2,7 +2,6 @@ import time from concurrent.futures import ( - CancelledError, TimeoutError, Future, wait, @@ -15,6 +14,7 @@ from toolz import take from distributed import Client +from distributed.utils import CancelledError from distributed.utils_test import ( slowinc, slowadd, diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 3cc055b5246..cf0387c1cd2 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -1,4 +1,3 @@ -from concurrent.futures import CancelledError import os import random from time import sleep @@ -12,7 +11,7 @@ from distributed.comm import CommClosedError from distributed.client import wait from distributed.metrics import time -from distributed.utils import sync, ignoring +from distributed.utils import sync, ignoring, CancelledError from distributed.utils_test import ( gen_cluster, cluster, diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 9c4fef57d2a..a6a19332f5f 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -621,6 +621,11 @@ def long(delay): ) <= 1 +@pytest.mark.xfail( + sys.version_info[:2] == (3, 8), + reason="Sporadic failure on Python 3.8", + strict=False, +) @gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2) def test_cleanup_repeated_tasks(c, s, a, b): class Foo: diff --git 
a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 5275bc47fd8..ab996e2b30d 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -1,4 +1,3 @@ -from concurrent.futures import CancelledError from operator import add import random import sys @@ -12,7 +11,7 @@ from distributed import Client, wait, Nanny from distributed.config import config from distributed.metrics import time -from distributed.utils import All, ignoring +from distributed.utils import All, ignoring, CancelledError from distributed.utils_test import ( gen_cluster, cluster, diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 55dc7faf417..0bc2cf10988 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1141,6 +1141,11 @@ def f(n): @pytest.mark.slow +@pytest.mark.xfail( + sys.version_info[:2] == (3, 8), + reason="Sporadic failure on Python 3.8", + strict=False, +) @gen_cluster( nthreads=[("127.0.0.1", 2)], client=True, diff --git a/distributed/utils.py b/distributed/utils.py index a771a3b280d..429a53cddde 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -3,7 +3,7 @@ import atexit import click from collections import deque, OrderedDict, UserDict -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, CancelledError # noqa: F401 from contextlib import contextmanager import functools from hashlib import md5 @@ -1212,11 +1212,27 @@ def reset_logger_locks(): is_kernel_and_no_running_loop = True if not is_kernel_and_no_running_loop: - import tornado.platform.asyncio - asyncio.set_event_loop_policy( - tornado.platform.asyncio.AnyThreadEventLoopPolicy() - ) + # TODO: Use tornado's AnyThreadEventLoopPolicy, instead of class below, + # once tornado > 6.0.3 is available. 
+ if WINDOWS and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # WindowsProactorEventLoopPolicy is not compatible with tornado 6 + # fallback to the pre-3.8 default of Selector + # https://github.com/tornadoweb/tornado/issues/2608 + BaseEventLoopPolicy = asyncio.WindowsSelectorEventLoopPolicy + else: + BaseEventLoopPolicy = asyncio.DefaultEventLoopPolicy + + class AnyThreadEventLoopPolicy(BaseEventLoopPolicy): + def get_event_loop(self): + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) @functools.lru_cache(1000) diff --git a/requirements.txt b/requirements.txt index 87e148bb244..49b4d21940b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,8 @@ psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 tblib >= 1.6.0 toolz >= 0.7.4 -tornado >= 5 +tornado >= 5;python_version<'3.8' +tornado >= 6.0.3;python_version>='3.8' zict >= 0.1.3 pyyaml setuptools diff --git a/setup.py b/setup.py index e8c419cb147..155ae0c0274 100755 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ "Programming Language :: Python", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Topic :: Scientific/Engineering", "Topic :: System :: Distributed Computing", ], From 3f23aa30a50c1a5d382e467a17b74d0c80709a17 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Fri, 14 Feb 2020 12:45:11 -0800 Subject: [PATCH 0681/1550] Register Dask cuDF serializers (#3478) As cuDF is gaining support for serializing using the Dask protocol as well, make sure to register it's serializers there as well. cuDF versions lacking this support will continue to behave the same (falling back to pickle). However cuDF versions with this support will bypass pickle. So only pay the cost of a host-to-device transfer. --- distributed/protocol/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 84ee9420c78..6830f375e35 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -88,5 +88,7 @@ def _register_rmm(): @cuda_serialize.register_lazy("cudf") @cuda_deserialize.register_lazy("cudf") +@dask_serialize.register_lazy("cudf") +@dask_deserialize.register_lazy("cudf") def _register_cudf(): from cudf.comm import serialize From 288c9577eb7a51a6a662499b686b948059da86ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ber=C3=A1nek?= Date: Fri, 14 Feb 2020 21:50:30 +0100 Subject: [PATCH 0682/1550] Do not duplicate messages in scheduler report (#3477) --- distributed/scheduler.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b56c08eb38b..e6e6adf7ced 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2439,13 +2439,10 @@ def report(self, msg, ts=None, client=None): If the message contains a key then we only send the message to those comms that care about the key. 
""" + comms = set() if client is not None: try: - comm = self.client_comms[client] - comm.send(msg) - except CommClosedError: - if self.status == "running": - logger.critical("Tried writing to closed comm: %s", msg) + comms.add(self.client_comms[client]) except KeyError: pass @@ -2453,14 +2450,14 @@ def report(self, msg, ts=None, client=None): ts = self.tasks.get(msg["key"]) if ts is None: # Notify all clients - comms = self.client_comms.values() + comms |= set(self.client_comms.values()) else: # Notify clients interested in key - comms = [ + comms |= { self.client_comms[c.client_key] for c in ts.who_wants if c.client_key in self.client_comms - ] + } for c in comms: try: c.send(msg) From c1f265125ab068a54a4c67501250550dcfff688e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 15 Feb 2020 17:49:44 -0800 Subject: [PATCH 0683/1550] Remove --verbose flag from CI runs (#3484) Historically we included the verbose flag because if things hung it was useful to see which test in particular was causing the hang. This hasn't been so important recently. It's somewhat annoying to scroll through all of the tests one by one. This seems more important today than the hung test case. * Give test_workspace_concurrency more time This is more important now that we're using spawn --- .github/workflows/ci-windows.yaml | 2 +- continuous_integration/travis/run_tests.sh | 2 +- distributed/tests/test_diskutils.py | 3 ++- setup.cfg | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index 707c50ffc76..75c4b294e88 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -43,4 +43,4 @@ jobs: shell: bash -l {0} env: PYTHONFAULTHANDLER: 1 - run: py.test -m "not avoid_travis" distributed --verbose -r s --timeout-method=thread --timeout=300 --durations=20 + run: py.test -m "not avoid_travis" distributed -r s --timeout-method=thread --timeout=300 --durations=20 diff --git a/continuous_integration/travis/run_tests.sh b/continuous_integration/travis/run_tests.sh index dbc0b21ff03..14c3db7750a 100644 --- a/continuous_integration/travis/run_tests.sh +++ b/continuous_integration/travis/run_tests.sh @@ -1,4 +1,4 @@ -export PYTEST_OPTIONS="--verbose -r s --timeout-method=thread --timeout=300 --durations=20" +export PYTEST_OPTIONS="-r s --timeout-method=thread --timeout=300 --durations=20" if [[ $RUNSLOW != false ]]; then export PYTEST_OPTIONS="$PYTEST_OPTIONS --runslow" fi diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index 86b472e184a..f69485cfa46 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -272,12 +272,13 @@ def _test_workspace_concurrency(tmpdir, timeout, max_procs): return n_created, n_purged +@pytest.mark.slow def test_workspace_concurrency(tmpdir): if WINDOWS: raise pytest.xfail.Exception("TODO: unknown failure on windows") if sys.version_info < (3, 7): raise pytest.xfail.Exception("TODO: unknown failure on Python 3.6") - _test_workspace_concurrency(tmpdir, 2.0, 6) + _test_workspace_concurrency(tmpdir, 5.0, 6) @pytest.mark.slow diff --git a/setup.cfg b/setup.cfg index 042a8b86f35..764ac7ad02c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ tag_prefix = parentdir_prefix = distributed- [tool:pytest] -addopts = -rsx -v --durations=10 +addopts = -rsx --durations=10 minversion = 3.2 markers = slow: marks tests as slow (deselect with '-m "not slow"') From ba8dad3fc76c54e788e54b8b425d7debf08d12a6 Mon Sep 17 
00:00:00 2001 From: Chris Roat <1053153+chrisroat@users.noreply.github.com> Date: Sun, 16 Feb 2020 09:21:21 -0800 Subject: [PATCH 0684/1550] Propose fix for collection based resources docs (#3480) Besides a missing end brace, the dask keys need to be flattened before they can be turned into a tuple (they can contain lists, which are not hashable as dict keys). --- docs/source/resources.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/resources.rst b/docs/source/resources.rst index bd4d8b5e81c..f9449dbd8be 100644 --- a/docs/source/resources.rst +++ b/docs/source/resources.rst @@ -96,12 +96,13 @@ delayed objects. You can pass a dictionary mapping keys of the collection to resource requirements during compute or persist calls. .. code-block:: python - + from dask import core + x = dd.read_csv(...) y = x.map_partitions(func1) z = y.map_partitions(func2) - z.compute(resources={tuple(y.__dask_keys__()): {'GPU': 1}) + z.compute(resources={tuple(core.flatten(y.__dask_keys__())): {'GPU': 1}}) In some cases (such as the case above) the keys for ``y`` may be optimized away before execution. You can avoid that either by requiring them as an explicit @@ -110,4 +111,4 @@ output, or by passing the ``optimize_graph=False`` keyword. .. code-block:: python - z.compute(resources={tuple(y.__dask_keys__()): {'GPU': 1}, optimize_graph=False) + z.compute(resources={tuple(core.flatten(y.__dask_keys__())): {'GPU': 1}}, optimize_graph=False) From c9b8fc35e2209855c45eb7b70f585193b65516c1 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Sun, 16 Feb 2020 11:24:06 -0600 Subject: [PATCH 0685/1550] Update NumPy array serialization to handle non-contiguous slices (#3474) --- distributed/protocol/numpy.py | 17 +++++++++++++---- distributed/protocol/tests/test_numpy.py | 1 + 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index 9a1f493c333..a2c9c2933e6 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -44,12 +44,21 @@ def serialize_numpy_ndarray(x): else: dt = (0, x.dtype.str) - # Only serialize non-broadcasted data for arrays with zero strided axes + # Only serialize broadcastable data for arrays with zero strided axes + broadcast_to = None if 0 in x.strides: broadcast_to = x.shape - x = x[tuple(slice(None) if s != 0 else slice(1) for s in x.strides)] - else: - broadcast_to = None + strides = x.strides + writeable = x.flags.writeable + x = x[tuple(slice(None) if s != 0 else slice(1) for s in strides)] + if not x.flags.c_contiguous and not x.flags.f_contiguous: + # Broadcasting can only be done with contiguous arrays + x = np.ascontiguousarray(x) + x = np.lib.stride_tricks.as_strided( + x, + strides=[j if i != 0 else i for i, j in zip(strides, x.strides)], + writeable=writeable, + ) if not x.shape: # 0d array diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 432b749e27e..e6dfd9764e2 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -268,6 +268,7 @@ def test_large_numpy_array(): np.broadcast_to(np.arange(10), (20, 10)), # Some strides are 0 np.broadcast_to(1, (3, 4, 2)), # All strides are 0 np.broadcast_to(np.arange(100)[:1], 5), # x.base is larger than x + np.broadcast_to(np.arange(5), (4, 5))[:, ::-1], ], ) @pytest.mark.parametrize("writeable", [True, False]) From 9408ebcb3f9ef1420640febf67b48c48fdcd2dd8 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Sun, 16 Feb 2020 
11:38:54 -0800 Subject: [PATCH 0686/1550] Serialize 1-D, contiguous, `uint8` CUDA frames (#3475) * Space things out to improve readability * Assign `frames` for clarity * Check CUDA contiguous arrays the same way * Use keyword arguments with CUDA array constructors Should make it a little clearer what the arguments relate to. Also makes them less dependent on changes in the signature of the constructors. * Always serialize the CUDA array's `strides` * Make sure CuPy always uses the `strides` provided * Drop workarounds for unknown `strides` As we always include the actual `strides` in the `header`, we can rely on this when deserializing the data. So drop the workarounds added for C-contiguous data where `strides` could be `None` as we still have the exact `strides` in that case. * Convert CUDA arrays into 1-D contiguous arrays * Avoid copying C/F contiguous CUDA arrays * Cast CUDA array data to `uint8` before serializing As we will ultimately read any data in as `uint8`, go ahead and cast to `uint8` as part of serialization. This is a good first order check to make sure that we are able to serialize the dat To simplify handling of the data a bit, go ahead and cast it to `uint8` before serializing. This makes contiguity checks trivial. * Test serialization of F-contiguous CUDA arrays --- distributed/protocol/cupy.py | 28 ++++++++++++++---------- distributed/protocol/numba.py | 24 ++++++++++---------- distributed/protocol/tests/test_cupy.py | 9 +++++--- distributed/protocol/tests/test_numba.py | 7 ++++-- 4 files changed, 40 insertions(+), 28 deletions(-) diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 087de6f9663..3ba5ca51597 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -6,21 +6,15 @@ class PatchedCudaArrayInterface: - """This class do two things: - 1) Makes sure that __cuda_array_interface__['strides'] - behaves as specified in the protocol. - 2) Makes sure that the cuda context is active + """This class does one thing: + 1) Makes sure that the cuda context is active when deallocating the base cuda array. Notice, this is only needed when the array to deserialize isn't a native cupy array. 
""" def __init__(self, ary): - cai = ary.__cuda_array_interface__ - cai_cupy_vsn = cupy.ndarray(0).__cuda_array_interface__["version"] - if cai.get("strides") is None and cai_cupy_vsn < 2: - cai.pop("strides", None) - self.__cuda_array_interface__ = cai + self.__cuda_array_interface__ = ary.__cuda_array_interface__ # Save a ref to ary so it won't go out of scope self.base = ary @@ -39,11 +33,18 @@ def __del__(self): @cuda_serialize.register(cupy.ndarray) def serialize_cupy_ndarray(x): # Making sure `x` is behaving - if not x.flags.c_contiguous: + if not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]): x = cupy.array(x, copy=True) header = x.__cuda_array_interface__.copy() - return header, [x] + header["strides"] = tuple(x.strides) + frames = [ + cupy.ndarray( + shape=(x.nbytes,), dtype=cupy.dtype("u1"), memptr=x.data, strides=(1,) + ) + ] + + return header, frames @cuda_deserialize.register(cupy.ndarray) @@ -52,6 +53,9 @@ def deserialize_cupy_array(header, frames): if not isinstance(frame, cupy.ndarray): frame = PatchedCudaArrayInterface(frame) arr = cupy.ndarray( - header["shape"], dtype=header["typestr"], memptr=cupy.asarray(frame).data + shape=header["shape"], + dtype=header["typestr"], + memptr=cupy.asarray(frame).data, + strides=header["strides"], ) return arr diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py index 9b33660e2bd..3d2b4879c3b 100644 --- a/distributed/protocol/numba.py +++ b/distributed/protocol/numba.py @@ -6,13 +6,21 @@ @cuda_serialize.register(numba.cuda.devicearray.DeviceNDArray) def serialize_numba_ndarray(x): # Making sure `x` is behaving - if not x.is_c_contiguous(): + if not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]): shape = x.shape t = numba.cuda.device_array(shape, dtype=x.dtype) t.copy_to_device(x) x = t + header = x.__cuda_array_interface__.copy() - return header, [x] + header["strides"] = tuple(x.strides) + frames = [ + numba.cuda.cudadrv.devicearray.DeviceNDArray( + shape=(x.nbytes,), strides=(1,), dtype=np.dtype("u1"), gpu_data=x.gpu_data, + ) + ] + + return header, frames @cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray) @@ -21,16 +29,10 @@ def deserialize_numba_ndarray(header, frames): shape = header["shape"] strides = header["strides"] - # Starting with __cuda_array_interface__ version 2, strides can be None, - # meaning the array is C-contiguous, so we have to calculate it. 
- if strides is None: - itemsize = np.dtype(header["typestr"]).itemsize - strides = tuple((np.cumprod((1,) + shape[:0:-1]) * itemsize).tolist()) - arr = numba.cuda.devicearray.DeviceNDArray( - shape, - strides, - np.dtype(header["typestr"]), + shape=shape, + strides=strides, + dtype=np.dtype(header["typestr"]), gpu_data=numba.cuda.as_cuda_array(frame).gpu_data, ) return arr diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index d2965d3af3f..4b3ea27cc9c 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -3,12 +3,15 @@ import pytest cupy = pytest.importorskip("cupy") +numpy = pytest.importorskip("numpy") -@pytest.mark.parametrize("size", [0, 10]) +@pytest.mark.parametrize("shape", [(0,), (5,), (4, 6), (10, 11), (2, 3, 5)]) @pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) -def test_serialize_cupy(size, dtype): - x = cupy.arange(size, dtype=dtype) +@pytest.mark.parametrize("order", ["C", "F"]) +def test_serialize_cupy(shape, dtype, order): + x = cupy.arange(numpy.product(shape), dtype=dtype) + x = cupy.ndarray(shape, dtype=x.dtype, memptr=x.data, order=order) header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) diff --git a/distributed/protocol/tests/test_numba.py b/distributed/protocol/tests/test_numba.py index 21b722fb1b0..4f3a9e7116e 100644 --- a/distributed/protocol/tests/test_numba.py +++ b/distributed/protocol/tests/test_numba.py @@ -6,12 +6,15 @@ np = pytest.importorskip("numpy") +@pytest.mark.parametrize("shape", [(0,), (5,), (4, 6), (10, 11), (2, 3, 5)]) @pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) -def test_serialize_numba(dtype): +@pytest.mark.parametrize("order", ["C", "F"]) +def test_serialize_numba(shape, dtype, order): if not cuda.is_available(): pytest.skip("CUDA is not available") - ary = np.arange(100, dtype=dtype) + ary = np.arange(np.product(shape), dtype=dtype) + ary = np.ndarray(shape, dtype=ary.dtype, buffer=ary.data, order=order) x = cuda.to_device(ary) header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) From cc7ecdf2abb76a7ed21d0cb4c9a9c92559f638c4 Mon Sep 17 00:00:00 2001 From: Brett Naul Date: Sun, 16 Feb 2020 17:32:13 -0800 Subject: [PATCH 0687/1550] Check exact equality for worker state (#3483) --- distributed/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e6e6adf7ced..cba399318cc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -276,10 +276,10 @@ def __init__( self.extra = extra or {} def __hash__(self): - return hash((self.name, self.host)) + return hash(self.address) def __eq__(self, other): - return type(self) == type(other) and hash(self) == hash(other) + return type(self) == type(other) and self.address == other.address @property def host(self): From e11674a987f98571f92f11fe74cc4588adab4af8 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 18 Feb 2020 14:11:50 -0800 Subject: [PATCH 0688/1550] Register cuML serializers (#3485) --- distributed/protocol/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 6830f375e35..b82461ef054 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -92,3 +92,11 @@ def 
_register_rmm(): @dask_deserialize.register_lazy("cudf") def _register_cudf(): from cudf.comm import serialize + + +@cuda_serialize.register_lazy("cuml") +@cuda_deserialize.register_lazy("cuml") +@dask_serialize.register_lazy("cuml") +@dask_deserialize.register_lazy("cuml") +def _register_cuml(): + from cuml.comm import serialize From a04a6321a46870f8432da3ee3789bc7cd7c27bbb Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 18 Feb 2020 16:47:34 -0600 Subject: [PATCH 0689/1550] Msgpack 1.0 compatibility (#3494) --- continuous_integration/environment.yml | 2 +- continuous_integration/travis/install.sh | 3 ++- distributed/protocol/tests/test_numpy.py | 2 +- distributed/protocol/tests/test_protocol.py | 4 +++- distributed/protocol/utils.py | 1 + requirements.txt | 2 +- 6 files changed, 9 insertions(+), 5 deletions(-) diff --git a/continuous_integration/environment.yml b/continuous_integration/environment.yml index f69d919e879..8f8e425dcab 100644 --- a/continuous_integration/environment.yml +++ b/continuous_integration/environment.yml @@ -14,7 +14,7 @@ dependencies: - ipywidgets - joblib - jupyter_client - - msgpack-python + - msgpack-python>=0.6.0 - prometheus_client - psutil - pytest diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index b631ac3bc6c..68b842aa033 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -40,6 +40,7 @@ conda install -c conda-forge -q \ ipywidgets \ joblib \ jupyter_client \ + msgpack-python>=0.6.0 \ netcdf4 \ paramiko \ prometheus_client \ @@ -71,7 +72,7 @@ python -m pip install -q git+https://github.com/joblib/joblib.git --upgrade --no python -m pip install -q git+https://github.com/intake/filesystem_spec.git --upgrade --no-deps python -m pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps python -m pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps -python -m pip install -q sortedcollections msgpack --no-deps +python -m pip install -q sortedcollections --no-deps python -m pip install -q keras --upgrade --no-deps if [[ $CRICK == true ]]; then diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index e6dfd9764e2..99a298d9694 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -196,7 +196,7 @@ def test_compress_numpy(): frames = dumps({"x": to_serialize(x)}) assert sum(map(nbytes, frames)) < x.nbytes - header = msgpack.loads(frames[2], raw=False, use_list=False) + header = msgpack.loads(frames[2], raw=False, use_list=False, strict_map_key=False) try: import blosc # noqa: F401 except ImportError: diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index bf16aecf2f4..d3536933a96 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -166,7 +166,9 @@ def test_loads_without_deserialization_avoids_compression(): def eq_frames(a, b): if b"headers" in a: - return msgpack.loads(a, use_list=False) == msgpack.loads(b, use_list=False) + return msgpack.loads(a, use_list=False, strict_map_key=False) == msgpack.loads( + b, use_list=False, strict_map_key=False + ) else: return a == b diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 68de0bebd32..e5b9247e77f 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -9,6 +9,7 @@ msgpack_opts = { ("max_%s_len" % x): 2 ** 31 - 1 for x in 
["str", "bin", "array", "map", "ext"] } +msgpack_opts["strict_map_key"] = False try: msgpack.loads(msgpack.dumps(""), raw=False, **msgpack_opts) diff --git a/requirements.txt b/requirements.txt index 49b4d21940b..3f827e250e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ click >= 6.6 cloudpickle >= 0.2.2 dask >= 2.9.0 -msgpack +msgpack >= 0.6.0 psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 tblib >= 1.6.0 From 54657393afb49ef3c6cb41ab739f0d2b0d22621f Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 18 Feb 2020 19:54:58 -0800 Subject: [PATCH 0690/1550] Suppress cuML `ImportError` (#3499) * Suppress cuML `ImportError` If cuML is present, but `cuml.comm` does not yet exist, make sure to suppress the `ImportError`. After all there is nothing to do here in this case and we don't want to raise unnecessary errors. * Use `ignoring` instead of `suppress` --- distributed/protocol/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index b82461ef054..c83d33c4868 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -99,4 +99,5 @@ def _register_cudf(): @dask_serialize.register_lazy("cuml") @dask_deserialize.register_lazy("cuml") def _register_cuml(): - from cuml.comm import serialize + with ignoring(ImportError): + from cuml.comm import serialize From b5e95ed77d5df7b07f43f02fc39881d483f061a4 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 18 Feb 2020 19:56:36 -0800 Subject: [PATCH 0691/1550] Add dask serialization of CUDA objects (#3482) * Run `isort` on CUDA protocol `import`s * Align CuPy serialize/deserialize function names * Prefix CUDA serializers with `cuda_` This should make room for Dask serializers to also be specified and added. * Add Dask serializers for RMM `DeviceBuffer`s To make TCP a bit more performant with RMM, provide Dask serializers to allow going to and from host memory. * Add Dask serializers for Numba `DeviceNDArray`s * Add Dask serializers for CuPy `ndarray`s * Parametrize serializers in CUDA object tests To make sure that different CUDA objects can use different serialization protocols, test with each one individual and ensure it completes. In particular test both "cuda" and "dask". Where supported also test "pickle", but skip it when it is not (like with Numba). * Check frames are the expected type To make sure Dask can handle transmission of the frames serialized, test they match the type expected by the protocol used. With "cuda" ensure we get something that supports `__cuda_array_interface__`. With "dask" make sure we get a `memoryview`. --- distributed/protocol/__init__.py | 6 +++++ distributed/protocol/cupy.py | 27 ++++++++++++++++--- distributed/protocol/numba.py | 33 +++++++++++++++++++++--- distributed/protocol/rmm.py | 28 +++++++++++++++++--- distributed/protocol/tests/test_cupy.py | 12 ++++++--- distributed/protocol/tests/test_numba.py | 12 ++++++--- distributed/protocol/tests/test_rmm.py | 12 ++++++--- 7 files changed, 111 insertions(+), 19 deletions(-) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index c83d33c4868..212051427f5 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -70,18 +70,24 @@ def _register_torch(): @cuda_serialize.register_lazy("cupy") @cuda_deserialize.register_lazy("cupy") +@dask_serialize.register_lazy("cupy") +@dask_deserialize.register_lazy("cupy") def _register_cupy(): from . 
import cupy @cuda_serialize.register_lazy("numba") @cuda_deserialize.register_lazy("numba") +@dask_serialize.register_lazy("numba") +@dask_deserialize.register_lazy("numba") def _register_numba(): from . import numba @cuda_serialize.register_lazy("rmm") @cuda_deserialize.register_lazy("rmm") +@dask_serialize.register_lazy("rmm") +@dask_deserialize.register_lazy("rmm") def _register_rmm(): from . import rmm diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 3ba5ca51597..40bf6efda4f 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -2,7 +2,14 @@ Efficient serialization GPU arrays. """ import cupy -from .cuda import cuda_serialize, cuda_deserialize + +from .cuda import cuda_deserialize, cuda_serialize +from .serialize import dask_deserialize, dask_serialize + +try: + from .rmm import dask_deserialize_rmm_device_buffer as dask_deserialize_cuda_buffer +except ImportError: + from .numba import dask_deserialize_numba_array as dask_deserialize_cuda_buffer class PatchedCudaArrayInterface: @@ -31,7 +38,7 @@ def __del__(self): @cuda_serialize.register(cupy.ndarray) -def serialize_cupy_ndarray(x): +def cuda_serialize_cupy_ndarray(x): # Making sure `x` is behaving if not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]): x = cupy.array(x, copy=True) @@ -48,7 +55,7 @@ def serialize_cupy_ndarray(x): @cuda_deserialize.register(cupy.ndarray) -def deserialize_cupy_array(header, frames): +def cuda_deserialize_cupy_ndarray(header, frames): (frame,) = frames if not isinstance(frame, cupy.ndarray): frame = PatchedCudaArrayInterface(frame) @@ -59,3 +66,17 @@ def deserialize_cupy_array(header, frames): strides=header["strides"], ) return arr + + +@dask_serialize.register(cupy.ndarray) +def dask_serialize_cupy_ndarray(x): + header, frames = cuda_serialize_cupy_ndarray(x) + frames = [memoryview(cupy.asnumpy(f)) for f in frames] + return header, frames + + +@dask_deserialize.register(cupy.ndarray) +def dask_deserialize_cupy_ndarray(header, frames): + frames = [dask_deserialize_cuda_buffer(header, frames)] + arr = cuda_deserialize_cupy_ndarray(header, frames) + return arr diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py index 3d2b4879c3b..1070c080e61 100644 --- a/distributed/protocol/numba.py +++ b/distributed/protocol/numba.py @@ -1,10 +1,17 @@ -import numpy as np import numba.cuda -from .cuda import cuda_serialize, cuda_deserialize +import numpy as np + +from .cuda import cuda_deserialize, cuda_serialize +from .serialize import dask_deserialize, dask_serialize + +try: + from .rmm import dask_deserialize_rmm_device_buffer +except ImportError: + dask_deserialize_rmm_device_buffer = None @cuda_serialize.register(numba.cuda.devicearray.DeviceNDArray) -def serialize_numba_ndarray(x): +def cuda_serialize_numba_ndarray(x): # Making sure `x` is behaving if not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]): shape = x.shape @@ -24,7 +31,7 @@ def serialize_numba_ndarray(x): @cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray) -def deserialize_numba_ndarray(header, frames): +def cuda_deserialize_numba_ndarray(header, frames): (frame,) = frames shape = header["shape"] strides = header["strides"] @@ -36,3 +43,21 @@ def deserialize_numba_ndarray(header, frames): gpu_data=numba.cuda.as_cuda_array(frame).gpu_data, ) return arr + + +@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray) +def dask_serialize_numba_ndarray(x): + header, frames = cuda_serialize_numba_ndarray(x) + frames = [memoryview(f.copy_to_host()) for f 
in frames] + return header, frames + + +@dask_deserialize.register(numba.cuda.devicearray.DeviceNDArray) +def dask_deserialize_numba_array(header, frames): + if dask_deserialize_rmm_device_buffer: + frames = [dask_deserialize_rmm_device_buffer(header, frames)] + else: + frames = [numba.cuda.to_device(np.asarray(memoryview(f))) for f in frames] + + arr = cuda_deserialize_numba_ndarray(header, frames) + return arr diff --git a/distributed/protocol/rmm.py b/distributed/protocol/rmm.py index cdf22f8218f..ae2db0d528b 100644 --- a/distributed/protocol/rmm.py +++ b/distributed/protocol/rmm.py @@ -1,18 +1,22 @@ +import numba +import numba.cuda +import numpy import rmm -from .cuda import cuda_serialize, cuda_deserialize +from .cuda import cuda_deserialize, cuda_serialize +from .serialize import dask_deserialize, dask_serialize # Used for RMM 0.11.0+ otherwise Numba serializers used if hasattr(rmm, "DeviceBuffer"): @cuda_serialize.register(rmm.DeviceBuffer) - def serialize_rmm_device_buffer(x): + def cuda_serialize_rmm_device_buffer(x): header = x.__cuda_array_interface__.copy() frames = [x] return header, frames @cuda_deserialize.register(rmm.DeviceBuffer) - def deserialize_rmm_device_buffer(header, frames): + def cuda_deserialize_rmm_device_buffer(header, frames): (arr,) = frames # We should already have `DeviceBuffer` @@ -21,3 +25,21 @@ def deserialize_rmm_device_buffer(header, frames): assert isinstance(arr, rmm.DeviceBuffer) return arr + + @dask_serialize.register(rmm.DeviceBuffer) + def dask_serialize_rmm_device_buffer(x): + header = x.__cuda_array_interface__.copy() + frames = [numba.cuda.as_cuda_array(x).copy_to_host().data] + return header, frames + + @dask_deserialize.register(rmm.DeviceBuffer) + def dask_deserialize_rmm_device_buffer(header, frames): + (frame,) = frames + + arr = numpy.asarray(memoryview(frame)) + ptr = arr.__array_interface__["data"][0] + size = arr.nbytes + + buf = rmm.DeviceBuffer(ptr=ptr, size=size) + + return buf diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index 4b3ea27cc9c..5470266fce5 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -9,11 +9,17 @@ @pytest.mark.parametrize("shape", [(0,), (5,), (4, 6), (10, 11), (2, 3, 5)]) @pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"]) @pytest.mark.parametrize("order", ["C", "F"]) -def test_serialize_cupy(shape, dtype, order): +@pytest.mark.parametrize("serializers", [("cuda",), ("dask",), ("pickle",)]) +def test_serialize_cupy(shape, dtype, order, serializers): x = cupy.arange(numpy.product(shape), dtype=dtype) x = cupy.ndarray(shape, dtype=x.dtype, memptr=x.data, order=order) - header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) - y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + header, frames = serialize(x, serializers=serializers) + y = deserialize(header, frames, deserializers=serializers) + + if serializers[0] == "cuda": + assert all(hasattr(f, "__cuda_array_interface__") for f in frames) + elif serializers[0] == "dask": + assert all(isinstance(f, memoryview) for f in frames) assert (x == y).all() diff --git a/distributed/protocol/tests/test_numba.py b/distributed/protocol/tests/test_numba.py index 4f3a9e7116e..61213640715 100644 --- a/distributed/protocol/tests/test_numba.py +++ b/distributed/protocol/tests/test_numba.py @@ -9,15 +9,21 @@ @pytest.mark.parametrize("shape", [(0,), (5,), (4, 6), (10, 11), (2, 3, 5)]) @pytest.mark.parametrize("dtype", ["u1", 
"u4", "u8", "f4"]) @pytest.mark.parametrize("order", ["C", "F"]) -def test_serialize_numba(shape, dtype, order): +@pytest.mark.parametrize("serializers", [("cuda",), ("dask",)]) +def test_serialize_numba(shape, dtype, order, serializers): if not cuda.is_available(): pytest.skip("CUDA is not available") ary = np.arange(np.product(shape), dtype=dtype) ary = np.ndarray(shape, dtype=ary.dtype, buffer=ary.data, order=order) x = cuda.to_device(ary) - header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) - y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + header, frames = serialize(x, serializers=serializers) + y = deserialize(header, frames, deserializers=serializers) + + if serializers[0] == "cuda": + assert all(hasattr(f, "__cuda_array_interface__") for f in frames) + elif serializers[0] == "dask": + assert all(isinstance(f, memoryview) for f in frames) hx = np.empty_like(ary) hy = np.empty_like(ary) diff --git a/distributed/protocol/tests/test_rmm.py b/distributed/protocol/tests/test_rmm.py index eff3325289e..8176f4d22f7 100644 --- a/distributed/protocol/tests/test_rmm.py +++ b/distributed/protocol/tests/test_rmm.py @@ -7,7 +7,8 @@ @pytest.mark.parametrize("size", [0, 3, 10]) -def test_serialize_rmm_device_buffer(size): +@pytest.mark.parametrize("serializers", [("cuda",), ("dask",), ("pickle",)]) +def test_serialize_rmm_device_buffer(size, serializers): if not hasattr(rmm, "DeviceBuffer"): pytest.skip("RMM pre-0.11.0 does not have DeviceBuffer") @@ -15,8 +16,13 @@ def test_serialize_rmm_device_buffer(size): x = rmm.DeviceBuffer(size=size) cuda.to_device(x_np, to=cuda.as_cuda_array(x)) - header, frames = serialize(x, serializers=("cuda", "dask", "pickle")) - y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + header, frames = serialize(x, serializers=serializers) + y = deserialize(header, frames, deserializers=serializers) y_np = y.copy_to_host() + if serializers[0] == "cuda": + assert all(hasattr(f, "__cuda_array_interface__") for f in frames) + elif serializers[0] == "dask": + assert all(isinstance(f, memoryview) for f in frames) + assert (x_np == y_np).all() From 2a05299a934a2557b985dee93da1e0eff8689178 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 19 Feb 2020 11:45:15 -0600 Subject: [PATCH 0692/1550] bump version to 2.11.0 --- docs/source/changelog.rst | 58 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index cd311309a30..498668d3c88 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,55 @@ Changelog ========= +2.11.0 - 2020-02-19 +------------------- + +- Add dask serialization of CUDA objects (:pr:`3482`) `John Kirkham`_ +- Suppress cuML ``ImportError`` (:pr:`3499`) `John Kirkham`_ +- Msgpack 1.0 compatibility (:pr:`3494`) `James Bourbeau`_ +- Register cuML serializers (:pr:`3485`) `John Kirkham`_ +- Check exact equality for worker state (:pr:`3483`) `Brett Naul`_ +- Serialize 1-D, contiguous, ``uint8`` CUDA frames (:pr:`3475`) `John Kirkham`_ +- Update NumPy array serialization to handle non-contiguous slices (:pr:`3474`) `James Bourbeau`_ +- Propose fix for collection based resources docs (:pr:`3480`) `Chris Roat`_ +- Remove ``--verbose`` flag from CI runs (:pr:`3484`) `Matthew Rocklin`_ +- Do not duplicate messages in scheduler report (:pr:`3477`) `Jakub Beránek`_ +- Register Dask cuDF serializers (:pr:`3478`) `John Kirkham`_ +- Add support for Python 3.8 (:pr:`3249`) 
`James Bourbeau`_ +- Add last seen column to worker table and highlight errant workers (:pr:`3468`) `kaelgreco`_ +- Change default value of ``local_directory`` from empty string to ``None`` (:pr:`3441`) `condoratberlin`_ +- Clear old docs (:pr:`3458`) `Matthew Rocklin`_ +- Change default multiprocessing behavior to spawn (:pr:`3461`) `Matthew Rocklin`_ +- Split dashboard host on additional slashes to handle inproc (:pr:`3466`) `Jacob Tomlinson`_ +- Update ``locality.rst`` (:pr:`3470`) `Dustin Tindall`_ +- Minor ``gen.Return`` cleanup (:pr:`3469`) `James Bourbeau`_ +- Update comparison logic for worker state (:pr:`3321`) `rockwellw`_ +- Update minimum ``tblib`` version to 1.6.0 (:pr:`3451`) `James Bourbeau`_ +- Add total row to workers plot in dashboard (:pr:`3464`) `Julia Signell`_ +- Workaround ``RecursionError`` on profile data (:pr:`3455`) `Tom Augspurger`_ +- Include code and summary in performance report (:pr:`3462`) `Matthew Rocklin`_ +- Skip ``test_open_close_many_workers`` on Python 3.6 (:pr:`3459`) `Matthew Rocklin`_ +- Support serializing/deserializing ``rmm.DeviceBuffer`` s (:pr:`3442`) `John Kirkham`_ +- Always add new ``TaskGroup`` to ``TaskPrefix`` (:pr:`3322`) `James Bourbeau`_ +- Rerun ``black`` on the code base (:pr:`3444`) `John Kirkham`_ +- Ensure ``__causes__`` s of exceptions raised on workers are serialized (:pr:`3430`) `Alex Adamson`_ +- Adjust ``numba.cuda`` import and add check (:pr:`3446`) `John Kirkham`_ +- Fix name of Numba serialization test (:pr:`3447`) `John Kirkham`_ +- Checks for command parameters in ``ssh2`` (:pr:`3078`) `Peter Andreas Entschev`_ +- Update ``worker_kwargs`` description in ``LocalCluster`` constructor (:pr:`3438`) `James Bourbeau`_ +- Ensure scheduler updates task and worker states after successful worker data deletion (:pr:`3401`) `James Bourbeau`_ +- Avoid ``loop=`` keyword in asyncio coordination primitives (:pr:`3437`) `Matthew Rocklin`_ +- Call pip as a module to avoid warnings (:pr:`3436`) `Cyril Shcherbin`_ +- Add documentation of parameters in coordination primitives (:pr:`3434`) `Søren Fuglede Jørgensen`_ +- Replace ``tornado.locks`` with asyncio for Events/Locks/Conditions/Semaphore (:pr:`3397`) `Matthew Rocklin`_ +- Remove object from class hierarchy (:pr:`3432`) `Anderson Banihirwe`_ +- Add ``dashboard_link`` property to ``Client`` (:pr:`3429`) `Jacob Tomlinson`_ +- Allow memory monitor to evict data more aggressively (:pr:`3424`) `fjetter`_ +- Make ``_get_ip`` return an IP address when defaulting (:pr:`3418`) `Pierre Glaser`_ +- Support version checking with older versions of Dask (:pr:`3390`) `Igor Gotlibovych`_ +- Add Mac OS build to CI (:pr:`3358`) `James Bourbeau`_ + + 2.10.0 - 2020-01-28 ------------------- @@ -1512,3 +1561,12 @@ significantly without many new features. .. _`Markus Mohrhard`: https://github.com/mmohrhard .. _`Mana Borwornpadungkitti`: https://github.com/potpath .. _`Chrysostomos Nanakos`: https://github.com/cnanakos +.. _`Chris Roat`: https://github.com/chrisroat +.. _`Jakub Beránek`: https://github.com/Kobzol +.. _`kaelgreco`: https://github.com/kaelgreco +.. _`Dustin Tindall`: https://github.com/dustindall +.. _`Julia Signell`: https://github.com/jsignell +.. _`Alex Adamson`: https://github.com/aadamson +.. _`Cyril Shcherbin`: https://github.com/shcherbin +.. _`Søren Fuglede Jørgensen`: https://github.com/fuglede +.. 
_`Igor Gotlibovych`: https://github.com/ig248 From 83f8febd32d99b58b75ecd3da710dd8a25618867 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Wed, 19 Feb 2020 20:06:48 +0000 Subject: [PATCH 0693/1550] Stop keep alives when worker reconnecting to the scheduler (#3493) --- distributed/worker.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index e1ae8317148..cb32be8111e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -382,7 +382,6 @@ def __init__( self.executed_count = 0 self.long_running = set() - self.batched_stream = None self.recent_messages_log = deque( maxlen=dask.config.get("distributed.comm.recent-messages-log-length") ) @@ -573,6 +572,7 @@ def __init__( self.actor_executor = ThreadPoolExecutor( 1, thread_name_prefix="Dask-Actor-Threads" ) + self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) self.name = name self.scheduler_delay = 0 self.stream_comms = dict() @@ -650,6 +650,13 @@ def __init__( pc = PeriodicCallback(self.heartbeat, 1000, io_loop=self.io_loop) self.periodic_callbacks["heartbeat"] = pc + pc = PeriodicCallback( + lambda: self.batched_stream.send({"op": "keep-alive"}), + 60000, + io_loop=self.io_loop, + ) + self.periodic_callbacks["keep-alive"] = pc + self._address = contact_address if self.memory_limit: @@ -797,6 +804,7 @@ def identity(self, comm=None): ##################### async def _register_with_scheduler(self): + self.periodic_callbacks["keep-alive"].stop() self.periodic_callbacks["heartbeat"].stop() start = time() if self.contact_address is None: @@ -863,15 +871,8 @@ async def _register_with_scheduler(self): logger.info(" Registered to: %26s", self.scheduler.address) logger.info("-" * 49) - self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) self.batched_stream.start(comm) - pc = PeriodicCallback( - lambda: self.batched_stream.send({"op": "keep-alive"}), - 60000, - io_loop=self.io_loop, - ) - self.periodic_callbacks["keep-alive"] = pc - pc.start() + self.periodic_callbacks["keep-alive"].start() self.periodic_callbacks["heartbeat"].start() self.loop.add_callback(self.handle_scheduler, comm) @@ -1112,7 +1113,11 @@ async def close( for k, v in self.services.items(): v.stop() - if self.batched_stream and not self.batched_stream.comm.closed(): + if ( + self.batched_stream + and self.batched_stream.comm + and not self.batched_stream.comm.closed() + ): self.batched_stream.send({"op": "close-stream"}) if self.batched_stream: From 6ea63bdf4bb04ef11fdb06019d84a161681f761b Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 20 Feb 2020 15:36:12 +0000 Subject: [PATCH 0694/1550] Rename logs to get_logs (#3473) * Rename get_logs to logs * Update distributed/node.py Co-Authored-By: James Bourbeau * Revert changes * Rename logs to get_logs Co-authored-by: James Bourbeau --- distributed/deploy/cluster.py | 13 +++++++++---- distributed/scheduler.py | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index c616f13c826..81f3d578fb2 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -1,6 +1,7 @@ import asyncio import logging import threading +import warnings from dask.utils import format_bytes @@ -159,11 +160,11 @@ def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs): else: return sync(self.loop, func, *args, **kwargs) - async def _logs(self, scheduler=True, workers=True): + async def _get_logs(self, 
scheduler=True, workers=True): logs = Logs() if scheduler: - L = await self.scheduler_comm.logs() + L = await self.scheduler_comm.get_logs() logs["Scheduler"] = Log("\n".join(line for level, line in L)) if workers: @@ -173,7 +174,7 @@ async def _logs(self, scheduler=True, workers=True): return logs - def logs(self, scheduler=True, workers=True): + def get_logs(self, scheduler=True, workers=True): """ Return logs for the scheduler and workers Parameters @@ -190,7 +191,11 @@ def logs(self, scheduler=True, workers=True): A dictionary of logs, with one item for the scheduler and one for each worker """ - return self.sync(self._logs, scheduler=scheduler, workers=workers) + return self.sync(self._get_logs, scheduler=scheduler, workers=workers) + + def logs(self, *args, **kwargs): + warnings.warn("logs is deprecated, use get_logs instead", DeprecationWarning) + return self.get_logs(*args, **kwargs) @property def dashboard_link(self): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cba399318cc..c78c4b1b218 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1267,6 +1267,7 @@ def __init__( "call_stack": self.get_call_stack, "profile": self.get_profile, "performance_report": self.performance_report, + "get_logs": self.get_logs, "logs": self.get_logs, "worker_logs": self.get_worker_logs, "nbytes": self.get_nbytes, From 0bed9fe57fa6c0f9416b337aa816024a6bb31acf Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Sat, 22 Feb 2020 13:27:45 +0100 Subject: [PATCH 0695/1550] Remove `import ucp` from the top of `ucx.py` (#3510) This is needed to ensure Dask configurations will be propagated to UCX upon importing. Since `ucx.py` is imported upon `import distributed`, Dask configurations passed to `Nanny(..., config=ucx_config)` won't be read by UCX since it has already been loaded. --- distributed/comm/ucx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 629a179e43e..330dde8a2d3 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -5,8 +5,6 @@ .. _UCX: https://github.com/openucx/ucx """ -import ucp - import logging import dask @@ -26,7 +24,9 @@ # In order to avoid double init when forking/spawning new processes (multiprocess), -# we make sure only to import and initialize UCX once at first use. +# we make sure only to import and initialize UCX once at first use. This is also +# required to ensure Dask configuration gets propagated to UCX, which needs +# variables to be set before being imported. ucp = None cuda_array = None From 1868dfe18e50aea20d3242c6955523b57c4f4e50 Mon Sep 17 00:00:00 2001 From: Darren Weber Date: Mon, 24 Feb 2020 08:53:40 -0800 Subject: [PATCH 0696/1550] Revise develop-docs: conda env example (#3406) --- docs/source/develop.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/source/develop.rst b/docs/source/develop.rst index f7a10c64471..68e8385ca5e 100644 --- a/docs/source/develop.rst +++ b/docs/source/develop.rst @@ -22,6 +22,26 @@ and install it from source:: cd distributed python setup.py install +Using conda, for example:: + + git clone git@github.com:{your-fork}/distributed.git + cd distributed + conda create -y -n distributed python=3.6 + conda activate distributed + python -m pip install -U -r requirements.txt + python -m pip install -U -r dev-requirements.txt + python -m pip install -e . 
+ +To keep a fork in sync with the upstream source:: + + cd distributed + git remote add upstream git@github.com:dask/distributed.git + git remote -v + git fetch -a upstream + git checkout master + git pull upstream master + git push origin master + Test ---- From 56fd9b8eaef22896acca01c97bc7061ef2f114fb Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Tue, 25 Feb 2020 10:25:56 -0500 Subject: [PATCH 0697/1550] RMM/UCX Config Flags (#3515) * initial setup and test of ucx config flags in dask * add rmm, rework ucx config a bit, add keys in distributed.yaml * move ucx conf to solitary file * rework ucx config ingestion * lint * dgx check * simplify dask.ucx flag scrubbing * move scrub function back to ucx * remove tcp-over-ucx flag and do not assume tls when only using net-devices --- distributed/comm/tests/test_ucx.py | 2 +- distributed/comm/tests/test_ucx_config.py | 84 +++++++++++++++++++++++ distributed/comm/ucx.py | 69 +++++++++++++++++-- distributed/distributed.yaml | 7 +- 4 files changed, 154 insertions(+), 8 deletions(-) create mode 100644 distributed/comm/tests/test_ucx_config.py diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 7725bfa2432..ead799f8158 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -10,7 +10,7 @@ from distributed.protocol import to_serialize from distributed.deploy.local import LocalCluster from dask.dataframe.utils import assert_eq -from distributed.utils_test import gen_test, loop, inc, cleanup # noqa: 401 +from distributed.utils_test import gen_test, loop, inc, cleanup, popen # noqa: 401 from .test_comms import check_deserialize diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py new file mode 100644 index 00000000000..5b45a844093 --- /dev/null +++ b/distributed/comm/tests/test_ucx_config.py @@ -0,0 +1,84 @@ +import pytest +from time import sleep + +import dask +from dask.utils import format_bytes +from distributed import Client +from distributed.utils_test import gen_test, loop, inc, cleanup, popen # noqa: 401 +from distributed.utils import get_ip +from distributed.comm.ucx import _scrub_ucx_config + +try: + HOST = get_ip() +except Exception: + HOST = "127.0.0.1" + +ucp = pytest.importorskip("ucp") +rmm = pytest.importorskip("rmm") + + +@pytest.mark.asyncio +async def test_ucx_config(cleanup): + + ucx = {"nvlink": True, "infiniband": True, "net-devices": ""} + + with dask.config.set(ucx=ucx): + ucx_config = _scrub_ucx_config() + assert ucx_config.get("TLS") == "rc,tcp,sockcm,cuda_copy,cuda_ipc" + assert ucx_config.get("NET_DEVICES") is None + + ucx = {"nvlink": False, "infiniband": True, "net-devices": "mlx5_0:1"} + + with dask.config.set(ucx=ucx): + ucx_config = _scrub_ucx_config() + assert ucx_config.get("TLS") == "rc,tcp,sockcm,cuda_copy" + assert ucx_config.get("NET_DEVICES") == "mlx5_0:1" + + ucx = { + "nvlink": False, + "infiniband": True, + "net-devices": "all", + "MEMTYPE_CACHE": "y", + } + + with dask.config.set(ucx=ucx): + ucx_config = _scrub_ucx_config() + assert ucx_config.get("TLS") == "rc,tcp,sockcm,cuda_copy" + assert ucx_config.get("MEMTYPE_CACHE") == "y" + + +def test_ucx_config_w_env_var(cleanup, loop, monkeypatch): + size = "1000.00 MB" + monkeypatch.setenv("DASK_RMM__POOL_SIZE", size) + + dask.config.refresh() + + port = "13339" + sched_addr = "ucx://%s:%s" % (HOST, port) + + with popen( + ["dask-scheduler", "--no-dashboard", "--protocol", "ucx", "--port", port] + ) as sched: + with popen( + [ + 
"dask-worker", + sched_addr, + "--no-dashboard", + "--protocol", + "ucx", + "--no-nanny", + ] + ) as w: + with Client(sched_addr, loop=loop, timeout=10) as c: + while not c.scheduler_info()["workers"]: + sleep(0.1) + + # configured with 1G pool + rmm_usage = c.run_on_scheduler(rmm.get_info) + assert size == format_bytes(rmm_usage.free) + + # configured with 1G pool + worker_addr = list(c.scheduler_info()["workers"])[0] + worker_rmm_usage = c.run(rmm.get_info) + rmm_usage = worker_rmm_usage[worker_addr] + assert size == format_bytes(rmm_usage.free) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 330dde8a2d3..5ea2d16ec45 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -14,11 +14,15 @@ from .core import Comm, Connector, Listener, CommClosedError from .registry import Backend, backends from .utils import ensure_concrete_host, to_frames, from_frames -from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors, CancelledError - -import dask -import numpy as np - +from ..utils import ( + ensure_ip, + get_ip, + get_ipv6, + nbytes, + log_errors, + CancelledError, + parse_bytes, +) logger = logging.getLogger(__name__) @@ -39,7 +43,11 @@ def init_once(): import ucp as _ucp ucp = _ucp - ucp.init(options=dask.config.get("ucx"), env_takes_precedence=True) + + # remove/process dask.ucx flags for valid ucx options + ucx_config = _scrub_ucx_config() + + ucp.init(options=ucx_config, env_takes_precedence=True) # Find the function, `cuda_array()`, to use when allocating new CUDA arrays try: @@ -61,6 +69,13 @@ def cuda_array(n): "In order to send/recv CUDA arrays, Numba or RMM is required" ) + pool_size_str = dask.config.get("rmm.pool-size") + if pool_size_str is not None: + pool_size = parse_bytes(pool_size_str) + rmm.reinitialize( + pool_allocator=True, managed_memory=False, initial_pool_size=pool_size + ) + class UCX(Comm): """Comm object using UCP. 
@@ -328,3 +343,45 @@ def get_local_address_for(self, loc): backends["ucx"] = UCXBackend() + + +def _scrub_ucx_config(): + """Function to scrub dask config options for valid UCX config options""" + + # configuration of UCX can happen in two ways: + # 1) high level on/off flags which correspond to UCX configuration + # 2) explicity defined UCX configuration flags + + # import does not initialize ucp -- this will occur outside this function + from ucp import get_config + + options = {} + + # if any of the high level flags are set, as long as they are not Null/None, + # we assume we should configure basic TLS settings for UCX + if any([dask.config.get("ucx.nvlink"), dask.config.get("ucx.infiniband")]): + tls = "tcp,sockcm,cuda_copy" + tls_priority = "sockcm" + + if dask.config.get("ucx.infiniband"): + tls = "rc," + tls + if dask.config.get("ucx.nvlink"): + tls = tls + ",cuda_ipc" + + options = {"TLS": tls, "SOCKADDR_TLS_PRIORITY": tls_priority} + + net_devices = dask.config.get("ucx.net-devices") + if net_devices is not None and net_devices != "": + options["NET_DEVICES"] = net_devices + + # ANY UCX options defined in config will overwrite high level dask.ucx flags + valid_ucx_keys = list(get_config().keys()) + for k, v in dask.config.get("ucx").items(): + if k in valid_ucx_keys: + options[k] = v + else: + logger.debug( + "Key: %s with value: %s not a valid UCX configuration option" % (k, v) + ) + + return options diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 487e72e215e..417b7bd5be8 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -132,5 +132,10 @@ distributed: log-length: 10000 # default length of logs to keep in memory log-format: '%(name)s - %(levelname)s - %(message)s' pdb-on-err: False # enter debug mode on scheduling error +rmm: + pool-size: null +ucx: + nvlink: null + infiniband: null + net-devices: null -ucx: {} From fc3e8d68e9e867e4cdd0d90cef86ca46f758f48e Mon Sep 17 00:00:00 2001 From: Lucas Rademaker <44430780+lr4d@users.noreply.github.com> Date: Thu, 27 Feb 2020 22:59:08 +0100 Subject: [PATCH 0698/1550] make work stealing callback time configurable (#3523) --- distributed/distributed.yaml | 1 + distributed/stealing.py | 11 +++++++++-- distributed/tests/test_steal.py | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 417b7bd5be8..ed21507e041 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -19,6 +19,7 @@ distributed: idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes" transition-log-length: 100000 work-stealing: True # workers should steal tasks from each other + work-stealing-interval: 100ms # Callback time for work stealing worker-ttl: null # like '60s'. Time to live for workers. 
They must heartbeat faster than this pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] diff --git a/distributed/stealing.py b/distributed/stealing.py index e3537f647bf..b14a2a8de6d 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -6,7 +6,7 @@ import dask from .core import CommClosedError from .diagnostics.plugin import SchedulerPlugin -from .utils import log_errors, PeriodicCallback +from .utils import log_errors, parse_timedelta, PeriodicCallback try: from cytoolz import topk @@ -40,8 +40,15 @@ def __init__(self, scheduler): for worker in scheduler.workers: self.add_worker(worker=worker) + # `callback_time` is in milliseconds + callback_time = 1000 * parse_timedelta( + dask.config.get("distributed.scheduler.work-stealing-interval"), + default="ms", + ) pc = PeriodicCallback( - callback=self.balance, callback_time=100, io_loop=self.scheduler.loop + callback=self.balance, + callback_time=callback_time, + io_loop=self.scheduler.loop, ) self._pc = pc self.scheduler.periodic_callbacks["stealing"] = pc diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index a6a19332f5f..b017bff4371 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -9,6 +9,7 @@ from toolz import sliding_window, concat from tornado import gen +import dask from distributed import Nanny, Worker, wait, worker_client from distributed.config import config from distributed.metrics import time @@ -676,3 +677,20 @@ def test_lose_task(c, s, a, b): out = log.getvalue() assert "Error" not in out + + +@gen_cluster(client=True) +def test_worker_stealing_interval(c, s, a, b): + from distributed.scheduler import WorkStealing + + ws = WorkStealing(s) + assert ws._pc.callback_time == 100 + + with dask.config.set({"distributed.scheduler.work-stealing-interval": "500ms"}): + ws = WorkStealing(s) + assert ws._pc.callback_time == 500 + + # Default unit is `ms` + with dask.config.set({"distributed.scheduler.work-stealing-interval": 2}): + ws = WorkStealing(s) + assert ws._pc.callback_time == 2 From 0d7a31adaabd801a189fa529c6b7670fe98395b1 Mon Sep 17 00:00:00 2001 From: Davis Bennett Date: Fri, 28 Feb 2020 11:15:44 -0500 Subject: [PATCH 0699/1550] fix typo in docstring (#3528) --- distributed/deploy/adaptive_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index 44a708aca38..dfd82ea33ba 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -13,7 +13,7 @@ class AdaptiveCore: The core logic for adaptive deployments, with none of the cluster details This class controls our adaptive scaling behavior. It is intended to be - sued as a super-class or mixin. It expects the following state and methods: + used as a super-class or mixin. 
It expects the following state and methods: **State** From 3b915a2adeddaf991590fd0192178836a7594fcf Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 29 Feb 2020 18:09:47 -0500 Subject: [PATCH 0700/1550] Add try-except around getting source code in performance report (#3505) See https://github.com/dask/distributed/issues/1674#issuecomment-589028369 Co-authored-by: James Bourbeau --- distributed/client.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ce820d2c6e6..4e84ea278d2 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4597,8 +4597,11 @@ async def __aenter__(self): async def __aexit__(self, typ, value, traceback, code=None): if not code: - frame = inspect.currentframe().f_back - code = inspect.getsource(frame) + try: + frame = inspect.currentframe().f_back + code = inspect.getsource(frame) + except Exception: + code = "" data = await get_client().scheduler.performance_report( start=self.start, code=code ) @@ -4609,8 +4612,11 @@ def __enter__(self): get_client().sync(self.__aenter__) def __exit__(self, typ, value, traceback): - frame = inspect.currentframe().f_back - code = inspect.getsource(frame) + try: + frame = inspect.currentframe().f_back + code = inspect.getsource(frame) + except Exception: + code = "" get_client().sync(self.__aexit__, type, value, traceback, code=code) From f49b4b0166699755bcca47fd4c86573a7d4ce72f Mon Sep 17 00:00:00 2001 From: jakirkham Date: Sun, 1 Mar 2020 07:29:01 -0800 Subject: [PATCH 0701/1550] Use 'temporary-directory' from dask.config for Nanny's directory (#3531) Make sure to respect the `temporary-directory` config value in `Nanny` when determining an appropriate temporary directory for things like spilling. 
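A minimal, self-contained sketch of the new fallback behaviour (the scratch path and port below are illustrative placeholders, not part of the patch; the pattern mirrors the test added here):

```python
import asyncio

import dask
from distributed import Nanny, Scheduler


async def main():
    async with await Scheduler(port=0) as s:
        # With no explicit local_directory, the Nanny now falls back to
        # dask's "temporary-directory" config value (rather than the
        # current working directory) and appends "dask-worker-space".
        with dask.config.set(temporary_directory="/tmp/dask-scratch"):
            n = await Nanny(s.address)
            assert n.local_directory.startswith("/tmp/dask-scratch")
            await n.close()


asyncio.run(main())
```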
--- distributed/nanny.py | 8 +++++++- distributed/tests/test_nanny.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index ff653ba096c..676291da3f8 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -66,7 +66,7 @@ def __init__( ncores=None, loop=None, local_dir=None, - local_directory="dask-worker-space", + local_directory=None, services=None, name=None, memory_limit="auto", @@ -150,6 +150,12 @@ def __init__( warnings.warn("The local_dir keyword has moved to local_directory") local_directory = local_dir + if local_directory is None: + local_directory = dask.config.get("temporary-directory") or os.getcwd() + if not os.path.exists(local_directory): + os.mkdir(local_directory) + local_directory = os.path.join(local_directory, "dask-worker-space") + self.local_directory = local_directory self.services = services diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 0091a6126f1..c80974d9970 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -402,6 +402,16 @@ def test_data_types(c, s): yield w.close() +@gen_cluster(nthreads=[]) +def test_local_directory(s): + with tmpfile() as fn: + with dask.config.set(temporary_directory=fn): + w = yield Nanny(s.address) + assert w.local_directory.startswith(fn) + assert "dask-worker-space" in w.local_directory + yield w.close() + + def _noop(x): """Define here because closures aren't pickleable.""" pass From 953314f64d780f68848b42f6478e343129adef11 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Sun, 1 Mar 2020 13:21:12 -0800 Subject: [PATCH 0702/1550] Mark `bool` as MessagePack serializable (#3535) As `bool` values can be serialized by MessagePack (see code below), mark them as such in `_is_msgpack_serializable`. 
```python In [1]: import msgpack In [2]: (msgpack.dumps(False), msgpack.dumps(True)) Out[2]: (b'\xc2', b'\xc3') In [3]: (msgpack.loads(b'\xc2'), msgpack.loads(b'\xc3')) Out[3]: (False, True) ``` --- distributed/protocol/serialize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index ddab6130765..b9c6c33d318 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -548,6 +548,7 @@ def _is_msgpack_serializable(v): typ = type(v) return ( typ is str + or typ is bool or typ is int or typ is float or isinstance(v, dict) From 5fd58327b0cb23c401178f8388548b4ebd74c93f Mon Sep 17 00:00:00 2001 From: jakirkham Date: Sun, 1 Mar 2020 18:19:21 -0800 Subject: [PATCH 0703/1550] Mark `None` as MessagePack serializable (#3537) --- distributed/protocol/serialize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index b9c6c33d318..3f3207ab58f 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -547,7 +547,8 @@ def _deserialize_bytes(header, frames): def _is_msgpack_serializable(v): typ = type(v) return ( - typ is str + v is None + or typ is str or typ is bool or typ is int or typ is float From b9936bfe01d87f89cd60d2ea7abf21749f031781 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Mon, 2 Mar 2020 13:41:39 -0800 Subject: [PATCH 0704/1550] Use `makedirs` when constructing `local_directory` (#3538) --- distributed/nanny.py | 2 +- distributed/worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 676291da3f8..ec5397efb93 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -153,7 +153,7 @@ def __init__( if local_directory is None: local_directory = dask.config.get("temporary-directory") or os.getcwd() if not os.path.exists(local_directory): - os.mkdir(local_directory) + os.makedirs(local_directory) local_directory = os.path.join(local_directory, "dask-worker-space") self.local_directory = local_directory diff --git a/distributed/worker.py b/distributed/worker.py index cb32be8111e..185db2e193b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -492,7 +492,7 @@ def __init__( if local_directory is None: local_directory = dask.config.get("temporary-directory") or os.getcwd() if not os.path.exists(local_directory): - os.mkdir(local_directory) + os.makedirs(local_directory) local_directory = os.path.join(local_directory, "dask-worker-space") with warn_on_duration( From 72213c9b14fe7f828371f187a01b1d8c9b773ae1 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 2 Mar 2020 15:47:43 -0600 Subject: [PATCH 0705/1550] Update heartbeat CommClosedError error handling (#3529) * Update worker heartbeat error catching logic * Run black --- distributed/tests/test_worker.py | 25 ++++++++++++++++++++++++- distributed/worker.py | 6 ++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0bc2cf10988..b6da294c749 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -31,7 +31,7 @@ wait, ) from distributed.compatibility import WINDOWS -from distributed.core import rpc +from distributed.core import rpc, CommClosedError from distributed.scheduler import Scheduler from distributed.metrics import time from distributed.worker import Worker, error_message, logger, parse_memory_limit @@ 
-1629,3 +1629,26 @@ async def test_update_latency(cleanup): if w.digests is not None: assert w.digests["latency"].size() > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("reconnect", [True, False]) +async def test_heartbeat_comm_closed(cleanup, monkeypatch, reconnect): + with captured_logger("distributed.worker", level=logging.WARNING) as logger: + async with await Scheduler() as s: + + def bad_heartbeat_worker(*args, **kwargs): + raise CommClosedError() + + async with await Worker(s.address, reconnect=reconnect) as w: + # Trigger CommClosedError during worker heartbeat + monkeypatch.setattr( + w.scheduler, "heartbeat_worker", bad_heartbeat_worker + ) + + await w.heartbeat() + if reconnect: + assert w.status == "running" + else: + assert w.status == "closed" + assert "Heartbeat to scheduler failed" in logger.getvalue() diff --git a/distributed/worker.py b/distributed/worker.py index 185db2e193b..a5a39fe22b3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -913,14 +913,16 @@ async def heartbeat(self): ) self.bandwidth_workers.clear() self.bandwidth_types.clear() + except CommClosedError: + logger.warning("Heartbeat to scheduler failed") + if not self.reconnect: + await self.close(report=False) except IOError as e: # Scheduler is gone. Respect distributed.comm.timeouts.connect if "Timed out trying to connect" in str(e): await self.close(report=False) else: raise e - except CommClosedError: - logger.warning("Heartbeat to scheduler failed") finally: self.heartbeat_active = False else: From 0140fc61835745ee41305d944ac037649e096059 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Mon, 2 Mar 2020 19:56:41 -0500 Subject: [PATCH 0706/1550] Fix/more ucx config options (#3539) * add tcp and cuda_copy config flags * update tests * raise error if no transport methods are set --- distributed/comm/tests/test_ucx_config.py | 32 ++++++++++++++++++++--- distributed/comm/ucx.py | 20 ++++++++++++-- distributed/distributed.yaml | 8 +++--- 3 files changed, 52 insertions(+), 8 deletions(-) diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py index 5b45a844093..5746cc80454 100644 --- a/distributed/comm/tests/test_ucx_config.py +++ b/distributed/comm/tests/test_ucx_config.py @@ -20,18 +20,30 @@ @pytest.mark.asyncio async def test_ucx_config(cleanup): - ucx = {"nvlink": True, "infiniband": True, "net-devices": ""} + ucx = { + "nvlink": True, + "infiniband": True, + "net-devices": "", + "tcp": True, + "cuda_copy": True, + } with dask.config.set(ucx=ucx): ucx_config = _scrub_ucx_config() assert ucx_config.get("TLS") == "rc,tcp,sockcm,cuda_copy,cuda_ipc" assert ucx_config.get("NET_DEVICES") is None - ucx = {"nvlink": False, "infiniband": True, "net-devices": "mlx5_0:1"} + ucx = { + "nvlink": False, + "infiniband": True, + "net-devices": "mlx5_0:1", + "tcp": True, + "cuda_copy": False, + } with dask.config.set(ucx=ucx): ucx_config = _scrub_ucx_config() - assert ucx_config.get("TLS") == "rc,tcp,sockcm,cuda_copy" + assert ucx_config.get("TLS") == "rc,tcp,sockcm" assert ucx_config.get("NET_DEVICES") == "mlx5_0:1" ucx = { @@ -39,6 +51,8 @@ async def test_ucx_config(cleanup): "infiniband": True, "net-devices": "all", "MEMTYPE_CACHE": "y", + "tcp": True, + "cuda_copy": True, } with dask.config.set(ucx=ucx): @@ -46,6 +60,18 @@ async def test_ucx_config(cleanup): assert ucx_config.get("TLS") == "rc,tcp,sockcm,cuda_copy" assert ucx_config.get("MEMTYPE_CACHE") == "y" + ucx = { + "nvlink": False, + "infiniband": False, + "net-devices": "all", + 
"MEMTYPE_CACHE": "y", + "tcp": False, + "cuda_copy": True, + } + with dask.config.set(ucx=ucx): + with raises(ValueError): + ucx_config = _scrub_ucx_config() + def test_ucx_config_w_env_var(cleanup, loop, monkeypatch): size = "1000.00 MB" diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 5ea2d16ec45..9484c9c08b9 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -359,10 +359,22 @@ def _scrub_ucx_config(): # if any of the high level flags are set, as long as they are not Null/None, # we assume we should configure basic TLS settings for UCX - if any([dask.config.get("ucx.nvlink"), dask.config.get("ucx.infiniband")]): - tls = "tcp,sockcm,cuda_copy" + if any( + [ + dask.config.get("ucx.tcp"), + dask.config.get("ucx.nvlink"), + dask.config.get("ucx.infiniband"), + ] + ): + tls = "tcp,sockcm" tls_priority = "sockcm" + # CUDA COPY can optionally be used with ucx -- we rely on the user + # to define when messages will include CUDA objects. Note: + # defining only the Infiniband flag will not enable cuda_copy + if any([dask.config.get("ucx.nvlink"), dask.config.get("ucx.cuda_copy")]): + tls = tls + ",cuda_copy" + if dask.config.get("ucx.infiniband"): tls = "rc," + tls if dask.config.get("ucx.nvlink"): @@ -373,6 +385,10 @@ def _scrub_ucx_config(): net_devices = dask.config.get("ucx.net-devices") if net_devices is not None and net_devices != "": options["NET_DEVICES"] = net_devices + else: + raise ValueError( + "UCX Dask config not set. Please define at least one: ucx.tcp, ucx.nvlink, ucx.infiniband" + ) # ANY UCX options defined in config will overwrite high level dask.ucx flags valid_ucx_keys = list(get_config().keys()) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index ed21507e041..05f27604328 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -136,7 +136,9 @@ distributed: rmm: pool-size: null ucx: - nvlink: null - infiniband: null - net-devices: null + tcp: null # enable tcp + nvlink: null # enable cuda_ipc + infiniband: null # enable Infiniband + cuda_copy: null # enable cuda-copy + net-devices: null # define which Infiniband device to use From 6a66df019cceec14cbf7397c24cb2766967cd704 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 3 Mar 2020 03:36:50 -0800 Subject: [PATCH 0707/1550] Use `pytest.raises` in `test_ucx_config.py` (#3541) Should fix a linting error we are seeing in `master`. ref: https://travis-ci.org/dask/distributed/jobs/657551806#L460 --- distributed/comm/tests/test_ucx_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py index 5746cc80454..695f2e7575d 100644 --- a/distributed/comm/tests/test_ucx_config.py +++ b/distributed/comm/tests/test_ucx_config.py @@ -69,7 +69,7 @@ async def test_ucx_config(cleanup): "cuda_copy": True, } with dask.config.set(ucx=ucx): - with raises(ValueError): + with pytest.raises(ValueError): ucx_config = _scrub_ucx_config() From b049bd71f8ef28adb96aa0cdd91254242c38ea2c Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 3 Mar 2020 08:26:51 -0600 Subject: [PATCH 0708/1550] DOC: update to async await (#3543) --- docs/source/develop.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/develop.rst b/docs/source/develop.rst index 68e8385ca5e..8d0a02fd73d 100644 --- a/docs/source/develop.rst +++ b/docs/source/develop.rst @@ -108,20 +108,20 @@ using the ``@gen_cluster`` style of test, e.g. 
from distributed import Client, Future, Scheduler, Worker @gen_cluster(client=True) - def test_submit(c, s, a, b): + async def test_submit(c, s, a, b): assert isinstance(c, Client) assert isinstance(s, Scheduler) assert isinstance(a, Worker) assert isinstance(b, Worker) - + future = c.submit(inc, 1) assert isinstance(future, Future) assert future.key in c.futures - + # result = future.result() # This synchronous API call would block - result = yield future + result = await future assert result == 2 - + assert future.key in s.tasks assert future.key in a.data or future.key in b.data @@ -131,8 +131,8 @@ you and cleans them up after the test. It also allows you to directly inspect the state of every element of the cluster directly. However, you can not use the normal synchronous API (doing so will cause the test to wait forever) and instead you need to use the coroutine API, where all blocking functions are -prepended with an underscore (``_``). Beware, it is a common mistake to use -the blocking interface within these tests. +prepended with an underscore (``_``) and awaited with ``await``. +Beware, it is a common mistake to use the blocking interface within these tests. If you want to test the normal synchronous API you can use the ``client`` pytest fixture style test, which sets up a scheduler and workers for you in @@ -166,7 +166,7 @@ also add the ``s, a, b`` fixtures as well. In this style of test you do not have access to the scheduler or workers. The variables ``s, a, b`` are now dictionaries holding a ``multiprocessing.Process`` object and a port integer. However, you can now -use the normal synchronous API (never use yield in this style of test) and you +use the normal synchronous API (never use ``await`` in this style of test) and you can close processes easily by terminating them. Typically for most user-facing functions you will find both kinds of tests. From 52a56c50623b6e95c3ee84cdf094b185043dbe17 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 3 Mar 2020 09:58:14 -0800 Subject: [PATCH 0709/1550] Support using other serializers with `register_generic` (#3536) * Build `ObjectDictSerializer` Makes a class to handle serialization and deserialization of objects with `__dict__`s. Should make it easier to specify other serializers one can use with such objects. Note if we are unable to use the serializer specified, we fallback to pickling things. * Use `ObjectDictSerializer` for "dask" serialization * Make `register_generic` handle other serializers Allow users to call `register_generic` with other serializers in mind. By default still use `dask_serialize` and `dask_deserialize`. Though allow these to be subbed out by other `*_serialize` and `*_deserialize` functions that will also dispatch based on type to serialize and deserialize with the expected serialization mode. * Ensure `dict` can deserialized with "cuda" As nested serialization can run into `dict` objects in the process, make sure that they can be deserialized with "cuda" as well to produce the original object. --- distributed/protocol/cuda.py | 7 +- distributed/protocol/serialize.py | 103 +++++++++++++++++------------- 2 files changed, 63 insertions(+), 47 deletions(-) diff --git a/distributed/protocol/cuda.py b/distributed/protocol/cuda.py index 51cb3ea42fa..aa638f70c0d 100644 --- a/distributed/protocol/cuda.py +++ b/distributed/protocol/cuda.py @@ -1,7 +1,7 @@ import dask from . 
import pickle -from .serialize import register_serialization_family +from .serialize import ObjectDictSerializer, register_serialization_family from dask.utils import typename cuda_serialize = dask.utils.Dispatch("cuda_serialize") @@ -29,3 +29,8 @@ def cuda_loads(header, frames): register_serialization_family("cuda", cuda_dumps, cuda_loads) + + +cuda_object_with_dict_serializer = ObjectDictSerializer("cuda") + +cuda_deserialize.register(dict)(cuda_object_with_dict_serializer.deserialize) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 3f3207ab58f..c462568cc40 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -560,59 +560,69 @@ def _is_msgpack_serializable(v): ) -def serialize_object_with_dict(est): - header = { - "serializer": "dask", - "type-serialized": pickle.dumps(type(est)), - "simple": {}, - "complex": {}, - } - frames = [] - - if isinstance(est, dict): - d = est - else: - d = est.__dict__ +class ObjectDictSerializer: + def __init__(self, serializer): + self.serializer = serializer + + def serialize(self, est): + header = { + "serializer": self.serializer, + "type-serialized": pickle.dumps(type(est)), + "simple": {}, + "complex": {}, + } + frames = [] - for k, v in d.items(): - if _is_msgpack_serializable(v): - header["simple"][k] = v + if isinstance(est, dict): + d = est else: - if isinstance(v, dict): - h, f = serialize_object_with_dict(v) - else: - h, f = serialize(v) - header["complex"][k] = { - "header": h, - "start": len(frames), - "stop": len(frames) + len(f), - } - frames += f - return header, frames + d = est.__dict__ + for k, v in d.items(): + if _is_msgpack_serializable(v): + header["simple"][k] = v + else: + if isinstance(v, dict): + h, f = self.serialize(v) + else: + h, f = serialize(v, serializers=(self.serializer, "pickle")) + header["complex"][k] = { + "header": h, + "start": len(frames), + "stop": len(frames) + len(f), + } + frames += f + return header, frames + + def deserialize(self, header, frames): + cls = pickle.loads(header["type-serialized"]) + if issubclass(cls, dict): + dd = obj = {} + else: + obj = object.__new__(cls) + dd = obj.__dict__ + dd.update(header["simple"]) + for k, d in header["complex"].items(): + h = d["header"] + f = frames[d["start"] : d["stop"]] + v = deserialize(h, f) + dd[k] = v -def deserialize_object_with_dict(header, frames): - cls = pickle.loads(header["type-serialized"]) - if issubclass(cls, dict): - dd = obj = {} - else: - obj = object.__new__(cls) - dd = obj.__dict__ - dd.update(header["simple"]) - for k, d in header["complex"].items(): - h = d["header"] - f = frames[d["start"] : d["stop"]] - v = deserialize(h, f) - dd[k] = v + return obj - return obj +dask_object_with_dict_serializer = ObjectDictSerializer("dask") -dask_deserialize.register(dict)(deserialize_object_with_dict) +dask_deserialize.register(dict)(dask_object_with_dict_serializer.deserialize) -def register_generic(cls): - """ Register dask_(de)serialize to traverse through __dict__ +def register_generic( + cls, + serializer_name="dask", + serialize_func=dask_serialize, + deserialize_func=dask_deserialize, +): + """ Register (de)serialize to traverse through __dict__ Normally when registering new classes for Dask's custom serialization you need to manage headers and frames, which can be tedious. 
If all you want @@ -643,5 +653,6 @@ def register_generic(cls): dask_serialize dask_deserialize """ - dask_serialize.register(cls)(serialize_object_with_dict) - dask_deserialize.register(cls)(deserialize_object_with_dict) + object_with_dict_serializer = ObjectDictSerializer(serializer_name) + serialize_func.register(cls)(object_with_dict_serializer.serialize) + deserialize_func.register(cls)(object_with_dict_serializer.deserialize) From 384080422f5bd54d16ba23161a1c2d18f74ff299 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 3 Mar 2020 18:58:40 +0100 Subject: [PATCH 0710/1550] Use UCX default configuration instead of raising (#3544) * Use UCX default configuration instead of raising * Remove UCX test raising ValueError --- distributed/comm/tests/test_ucx_config.py | 12 ------------ distributed/comm/ucx.py | 7 ++----- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py index 695f2e7575d..c2e86ed0b49 100644 --- a/distributed/comm/tests/test_ucx_config.py +++ b/distributed/comm/tests/test_ucx_config.py @@ -60,18 +60,6 @@ async def test_ucx_config(cleanup): assert ucx_config.get("TLS") == "rc,tcp,sockcm,cuda_copy" assert ucx_config.get("MEMTYPE_CACHE") == "y" - ucx = { - "nvlink": False, - "infiniband": False, - "net-devices": "all", - "MEMTYPE_CACHE": "y", - "tcp": False, - "cuda_copy": True, - } - with dask.config.set(ucx=ucx): - with pytest.raises(ValueError): - ucx_config = _scrub_ucx_config() - def test_ucx_config_w_env_var(cleanup, loop, monkeypatch): size = "1000.00 MB" diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 9484c9c08b9..7295b11bb48 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -358,7 +358,8 @@ def _scrub_ucx_config(): options = {} # if any of the high level flags are set, as long as they are not Null/None, - # we assume we should configure basic TLS settings for UCX + # we assume we should configure basic TLS settings for UCX, otherwise we + # leave UCX to its default configuration if any( [ dask.config.get("ucx.tcp"), @@ -385,10 +386,6 @@ def _scrub_ucx_config(): net_devices = dask.config.get("ucx.net-devices") if net_devices is not None and net_devices != "": options["NET_DEVICES"] = net_devices - else: - raise ValueError( - "UCX Dask config not set. Please define at least one: ucx.tcp, ucx.nvlink, ucx.infiniband" - ) # ANY UCX options defined in config will overwrite high level dask.ucx flags valid_ucx_keys = list(get_config().keys()) From d8d0d4e71023ac6c1507b443b90d7805e2bf7ad2 Mon Sep 17 00:00:00 2001 From: Stan Seibert Date: Tue, 3 Mar 2020 14:56:45 -0600 Subject: [PATCH 0711/1550] Allow tasks with restrictions to be stolen (#3069) Addresses stealing tasks with resource restrictions, as mentioned in #1851. If a task has hard restrictions, do not just give up on stealing. Instead, use the restrictions to determine which workers can steal it before attempting to execute a steal operation. A follow up PR will be needed to address the issue of long-running tasks not being stolen because the scheduler has no information about their runtime. 
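For illustration, a hedged sketch of the kind of workload this change affects (the scheduler address and the "GPU" resource name are placeholders, and workers are assumed to have been started with matching ``--resources`` definitions):

```python
from distributed import Client

client = Client("tcp://127.0.0.1:8786")  # placeholder scheduler address


def inc(x):
    return x + 1


# Tasks pinned to a resource. Previously such tasks were never stolen, even
# when another worker advertising the same resource sat idle; with this
# change they may be moved to any idle worker that satisfies the restriction.
futures = client.map(inc, range(100), resources={"GPU": 1})
```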
Supercedes #2740 --- distributed/stealing.py | 70 ++++++++++++++++++++++++++++----- distributed/tests/test_steal.py | 55 +++++++++++++++++++++++++- distributed/worker.py | 2 +- 3 files changed, 115 insertions(+), 12 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index b14a2a8de6d..4fbb753e131 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -4,6 +4,7 @@ from time import time import dask +from .comm.addressing import get_address_host from .core import CommClosedError from .diagnostics.plugin import SchedulerPlugin from .utils import log_errors, parse_timedelta, PeriodicCallback @@ -128,11 +129,6 @@ def steal_time_ratio(self, ts): For example a result of zero implies a task without dependencies. level: The location within a stealable list to place this value """ - if not ts.loose_restrictions and ( - ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions - ): - return None, None # don't steal - if not ts.dependencies: # no dependencies fast path return 0, 0 @@ -258,7 +254,7 @@ def move_task_confirm(self, key=None, worker=None, state=None): self.scheduler.check_idle_saturated(victim) # Victim was waiting, has given up task, enact steal - elif state in ("waiting", "ready"): + elif state in ("waiting", "ready", "constrained"): self.remove_key_from_stealable(ts) ts.processing_on = thief duration = victim.processing.pop(ts) @@ -360,14 +356,23 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): i += 1 if not idle: break - idl = idle[i % len(idle)] + + if _has_restrictions(ts): + thieves = [ws for ws in idle if _can_steal(ws, ts, sat)] + else: + thieves = idle + if not thieves: + break + thief = thieves[i % len(thieves)] duration = sat.processing.get(ts) if duration is None: stealable.discard(ts) continue - maybe_move_task(level, ts, sat, idl, duration, cost_multiplier) + maybe_move_task( + level, ts, sat, thief, duration, cost_multiplier + ) if self.cost_multipliers[level] < 20: # don't steal from public at cost stealable = self.stealable_all[level] @@ -388,10 +393,18 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): continue i += 1 - idl = idle[i % len(idle)] + if _has_restrictions(ts): + thieves = [ws for ws in idle if _can_steal(ws, ts, sat)] + else: + thieves = idle + if not thieves: + continue + thief = thieves[i % len(thieves)] duration = sat.processing[ts] - maybe_move_task(level, ts, sat, idl, duration, cost_multiplier) + maybe_move_task( + level, ts, sat, thief, duration, cost_multiplier + ) if log: self.log.append(log) @@ -422,4 +435,41 @@ def story(self, *keys): return out +def _has_restrictions(ts): + """Determine whether the given task has restrictions and whether these + restrictions are strict. + """ + return not ts.loose_restrictions and ( + ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions + ) + + +def _can_steal(thief, ts, victim): + """Determine whether worker ``thief`` can steal task ``ts`` from worker + ``victim``. + + Assumes that `ts` has some restrictions. 
+ """ + if ( + ts.host_restrictions + and get_address_host(thief.address) not in ts.host_restrictions + ): + return False + elif ts.worker_restrictions and thief.address not in ts.worker_restrictions: + return False + + if victim.resources is None: + return True + + for resource, value in victim.resources.items(): + try: + supplied = thief.resources[resource] + except KeyError: + return False + else: + if supplied < value: + return False + return True + + fast_tasks = {"shuffle-split"} diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index b017bff4371..71f408749a1 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -224,6 +224,32 @@ def test_dont_steal_worker_restrictions(c, s, a, b): assert len(b.task_state) == 0 +@gen_cluster( + client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2), ("127.0.0.1", 2)] +) +def test_steal_worker_restrictions(c, s, wa, wb, wc): + future = c.submit(slowinc, 1, delay=0.1, workers={wa.address, wb.address}) + yield future + + ntasks = 100 + futures = c.map(slowinc, range(ntasks), delay=0.1, workers={wa.address, wb.address}) + + while sum(len(w.task_state) for w in [wa, wb, wc]) < ntasks: + yield gen.sleep(0.01) + + assert 0 < len(wa.task_state) < ntasks + assert 0 < len(wb.task_state) < ntasks + assert len(wc.task_state) == 0 + + s.extensions["stealing"].balance() + + yield gen.sleep(0.1) + + assert 0 < len(wa.task_state) < ntasks + assert 0 < len(wb.task_state) < ntasks + assert len(wc.task_state) == 0 + + @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @@ -245,6 +271,34 @@ def test_dont_steal_host_restrictions(c, s, a, b): assert len(b.task_state) == 0 +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" +) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 2)]) +def test_steal_host_restrictions(c, s, wa, wb): + future = c.submit(slowinc, 1, delay=0.10, workers=wa.address) + yield future + + ntasks = 100 + futures = c.map(slowinc, range(ntasks), delay=0.1, workers="127.0.0.1") + while len(wa.task_state) < ntasks: + yield gen.sleep(0.01) + assert len(wa.task_state) == ntasks + assert len(wb.task_state) == 0 + + wc = yield Worker(s.address, ncores=1) + + start = time() + while not wc.task_state or len(wa.task_state) == ntasks: + yield gen.sleep(0.01) + assert time() < start + 3 + + yield gen.sleep(0.1) + assert 0 < len(wa.task_state) < ntasks + assert len(wb.task_state) == 0 + assert 0 < len(wc.task_state) < ntasks + + @gen_cluster( client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)] ) @@ -265,7 +319,6 @@ def test_dont_steal_resource_restrictions(c, s, a, b): assert len(b.task_state) == 0 -@pytest.mark.skip(reason="no stealing of resources") @gen_cluster( client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}})], timeout=3 ) diff --git a/distributed/worker.py b/distributed/worker.py index a5a39fe22b3..aa71a16640b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2150,7 +2150,7 @@ def steal_request(self, key): response = {"op": "steal-response", "key": key, "state": state} self.batched_stream.send(response) - if state in ("ready", "waiting"): + if state in ("ready", "waiting", "constrained"): self.release_key(key) def release_key(self, key, cause=None, reason=None, report=True): From 15550952aabfecf9e6a6bbccb7dea82d72857a4c Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 3 Mar 2020 18:20:47 -0800 Subject: [PATCH 
0712/1550] Serialize sparse arrays (#3545) * Serialize SciPy sparse matrices As SciPy sparse matrices just consist of a handful of NumPy arrays as attributes (which Dask already knows how to serialize), we can just ask Dask to comb through them for us and serialize their components. * Special case scipy.sparse's `dok_matrix` As `dok_matrix` is a subclass of `spmatrix` and `dict`, this confuses Dask's `register_generic` machinery. Not to mention this doesn't actually contain any NumPy `ndarray`s. Instead it just stores coordinates as `tuple`s in keys and data in values. Ideally we would just pack the dictionary into the header and move on. However as the data included is not MessagePack serializable, this strategy does not work in practice. So simply convert the `dok_matrix` to a `coo_matrix`, which has very similar layout and is easier for us to serialize. When deserializing, just extract the `coo_matrix` and convert it back to a `dok_matrix`. This let's us bypass the oddities of the `dok_matrix` while still having reasonably efficient serialization. * Register SciPy serialization * Test serializing SciPy sparse matrices * Serialize CuPy sparse matrices As CuPy sparse matrices just consist of a handful of CuPy arrays as attributes (which Dask already knows how to serialize), we can just ask Dask to comb through them for us and serialize their components. * Run `isort` on `test_cupy` * Test serializing CuPy sparse matrices --- distributed/protocol/__init__.py | 6 ++++ distributed/protocol/cupy.py | 19 +++++++++++- distributed/protocol/scipy.py | 30 +++++++++++++++++++ distributed/protocol/tests/test_cupy.py | 35 +++++++++++++++++++++- distributed/protocol/tests/test_scipy.py | 37 ++++++++++++++++++++++++ 5 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 distributed/protocol/scipy.py create mode 100644 distributed/protocol/tests/test_scipy.py diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index 212051427f5..c34f161a1fe 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -30,6 +30,12 @@ def _register_numpy(): from . import numpy +@dask_serialize.register_lazy("scipy") +@dask_deserialize.register_lazy("scipy") +def _register_scipy(): + from . 
import scipy + + @dask_serialize.register_lazy("h5py") @dask_deserialize.register_lazy("h5py") def _register_h5py(): diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 40bf6efda4f..9245412de6e 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -4,7 +4,7 @@ import cupy from .cuda import cuda_deserialize, cuda_serialize -from .serialize import dask_deserialize, dask_serialize +from .serialize import dask_deserialize, dask_serialize, register_generic try: from .rmm import dask_deserialize_rmm_device_buffer as dask_deserialize_cuda_buffer @@ -80,3 +80,20 @@ def dask_deserialize_cupy_ndarray(header, frames): frames = [dask_deserialize_cuda_buffer(header, frames)] arr = cuda_deserialize_cupy_ndarray(header, frames) return arr + + +try: + from cupy.cusparse import MatDescriptor + from cupyx.scipy.sparse import spmatrix + + cupy_sparse_types = [MatDescriptor, spmatrix] +except ImportError: + cupy_sparse_types = [] + + +for t in cupy_sparse_types: + for n, s, d in [ + ("cuda", cuda_serialize, cuda_deserialize), + ("dask", dask_serialize, dask_deserialize), + ]: + register_generic(t, n, s, d) diff --git a/distributed/protocol/scipy.py b/distributed/protocol/scipy.py new file mode 100644 index 00000000000..9ed533bc850 --- /dev/null +++ b/distributed/protocol/scipy.py @@ -0,0 +1,30 @@ +""" +Efficient serialization of SciPy sparse matrices. +""" +import scipy + +from .serialize import dask_deserialize, dask_serialize, register_generic + +register_generic(scipy.sparse.spmatrix, "dask", dask_serialize, dask_deserialize) + + +@dask_serialize.register(scipy.sparse.dok.dok_matrix) +def serialize_scipy_sparse_dok(x): + x_coo = x.tocoo() + coo_header, coo_frames = dask_serialize(x.tocoo()) + + header = {"coo_header": coo_header} + frames = coo_frames + + return header, frames + + +@dask_deserialize.register(scipy.sparse.dok.dok_matrix) +def deserialize_scipy_sparse_dok(header, frames): + coo_header = header["coo_header"] + coo_frames = frames + x_coo = dask_deserialize(coo_header, coo_frames) + + x = x_coo.todok() + + return x diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index 5470266fce5..44d4b80d66d 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -1,8 +1,10 @@ -from distributed.protocol import serialize, deserialize import pickle + import pytest +from distributed.protocol import deserialize, serialize cupy = pytest.importorskip("cupy") +cupy_sparse = pytest.importorskip("cupyx.scipy.sparse") numpy = pytest.importorskip("numpy") @@ -61,3 +63,34 @@ def test_serialize_cupy_from_rmm(size): y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) assert (x_np == cupy.asnumpy(y)).all() + + +@pytest.mark.parametrize( + "sparse_type", + [ + cupy_sparse.coo_matrix, + cupy_sparse.csc_matrix, + cupy_sparse.csr_matrix, + cupy_sparse.dia_matrix, + ], +) +@pytest.mark.parametrize( + "dtype", + [numpy.dtype("f4"), numpy.dtype("f8"),], +) +@pytest.mark.parametrize("serializer", ["cuda", "dask",]) +def test_serialize_cupy_sparse(sparse_type, dtype, serializer): + a_host = numpy.array([[0, 1, 0], [2, 0, 3], [0, 4, 0]], dtype=dtype) + a = cupy.asarray(a_host) + + anz = a.nonzero() + acoo = cupy_sparse.coo_matrix((a[anz], anz)) + asp = sparse_type(acoo) + + header, frames = serialize(asp, serializers=[serializer]) + asp2 = deserialize(header, frames) + + a2 = asp2.todense() + a2_host = cupy.asnumpy(a2) + + assert (a_host == a2_host).all() diff 
--git a/distributed/protocol/tests/test_scipy.py b/distributed/protocol/tests/test_scipy.py new file mode 100644 index 00000000000..2cb5d7477e5 --- /dev/null +++ b/distributed/protocol/tests/test_scipy.py @@ -0,0 +1,37 @@ +import pytest +from distributed.protocol import deserialize, serialize + +numpy = pytest.importorskip("numpy") +scipy = pytest.importorskip("scipy") +scipy_sparse = pytest.importorskip("scipy.sparse") + + +@pytest.mark.parametrize( + "sparse_type", + [ + scipy_sparse.bsr_matrix, + scipy_sparse.coo_matrix, + scipy_sparse.csc_matrix, + scipy_sparse.csr_matrix, + scipy_sparse.dia_matrix, + scipy_sparse.dok_matrix, + scipy_sparse.lil_matrix, + ], +) +@pytest.mark.parametrize( + "dtype", + [numpy.dtype("f4"), numpy.dtype("f8"),], +) +def test_serialize_scipy_sparse(sparse_type, dtype): + a = numpy.array([[0, 1, 0], [2, 0, 3], [0, 4, 0]], dtype=dtype) + + anz = a.nonzero() + acoo = scipy_sparse.coo_matrix((a[anz], anz)) + asp = sparse_type(acoo) + + header, frames = serialize(asp, serializers=["dask"]) + asp2 = deserialize(header, frames) + + a2 = asp2.todense() + + assert (a == a2).all() From 9b5bf448af478a166069aefc9c0c1354a29ae482 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 4 Mar 2020 09:59:20 -0600 Subject: [PATCH 0713/1550] API docs for LocalCluster and SpecCluster (#3548) --- distributed/client.py | 4 ++-- docs/source/api.rst | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 4e84ea278d2..287425a70f1 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -502,7 +502,7 @@ class Client(Node): It is also common to create a Client without specifying the scheduler address , like ``Client()``. In this case the Client creates a - ``LocalCluster`` in the background and connects to that. Any extra + :class:`LocalCluster` in the background and connects to that. Any extra keywords are passed from Client to LocalCluster in this case. See the LocalCluster documentation for more information. @@ -569,7 +569,7 @@ class Client(Node): See Also -------- distributed.scheduler.Scheduler: Internal scheduler - distributed.deploy.local.LocalCluster: + distributed.LocalCluster: """ _instances = weakref.WeakSet() diff --git a/docs/source/api.rst b/docs/source/api.rst index 8d739334b07..9d2f6c7f870 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -148,6 +148,28 @@ Future .. autoclass:: Future :members: +Cluster +------- + +Classes relevant for cluster creation and management. Other libraries +(like `dask-jobqueue`_, `dask-gateway`_, `dask-kubernetes`_, `dask-yarn`_ etc.) +provide additional cluster objects. + +.. _dask-jobqueue: https://jobqueue.dask.org/ +.. _dask-gateway: https://gateway.dask.org/ +.. _dask-kubernetes: https://kubernetes.dask.org/ +.. _dask-yarn: https://yarn.dask.org/en/latest/ + +.. autosummary:: + LocalCluster + SpecCluster + +.. autoclass:: LocalCluster + :members: + +.. 
autoclass:: SpecCluster + :members: + Other ----- From ae74b5ea18fac7e272c9f25b5d9f2775956aa943 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Thu, 5 Mar 2020 15:17:37 -0800 Subject: [PATCH 0714/1550] Fix-up CuPy sparse serialization (#3556) Fix-up CuPy sparse serialization --- distributed/protocol/__init__.py | 4 +++ distributed/protocol/cupy.py | 32 +++++++++++++++--- distributed/protocol/tests/test_cupy.py | 45 ++++++++++++++----------- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index c34f161a1fe..bb919019e04 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -78,6 +78,10 @@ def _register_torch(): @cuda_deserialize.register_lazy("cupy") @dask_serialize.register_lazy("cupy") @dask_deserialize.register_lazy("cupy") +@cuda_serialize.register_lazy("cupyx") +@cuda_deserialize.register_lazy("cupyx") +@dask_serialize.register_lazy("cupyx") +@dask_deserialize.register_lazy("cupyx") def _register_cupy(): from . import cupy diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 9245412de6e..3d074266245 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -1,6 +1,8 @@ """ Efficient serialization GPU arrays. """ +import copyreg + import cupy from .cuda import cuda_deserialize, cuda_serialize @@ -85,15 +87,35 @@ def dask_deserialize_cupy_ndarray(header, frames): try: from cupy.cusparse import MatDescriptor from cupyx.scipy.sparse import spmatrix - - cupy_sparse_types = [MatDescriptor, spmatrix] except ImportError: - cupy_sparse_types = [] + MatDescriptor = None + spmatrix = None + + +if MatDescriptor is not None: + + def reduce_matdescriptor(other): + # Pickling MatDescriptor errors + # xref: https://github.com/cupy/cupy/issues/3061 + return cupy.cusparse.MatDescriptor.create, () + + copyreg.pickle(MatDescriptor, reduce_matdescriptor) + + @cuda_serialize.register(MatDescriptor) + @dask_serialize.register(MatDescriptor) + def serialize_cupy_matdescriptor(x): + header, frames = {}, [] + return header, frames + + @cuda_deserialize.register(MatDescriptor) + @dask_deserialize.register(MatDescriptor) + def deserialize_cupy_matdescriptor(header, frames): + return MatDescriptor.create() -for t in cupy_sparse_types: +if spmatrix is not None: for n, s, d in [ ("cuda", cuda_serialize, cuda_deserialize), ("dask", dask_serialize, dask_deserialize), ]: - register_generic(t, n, s, d) + register_generic(spmatrix, n, s, d) diff --git a/distributed/protocol/tests/test_cupy.py b/distributed/protocol/tests/test_cupy.py index 44d4b80d66d..95cb530c4db 100644 --- a/distributed/protocol/tests/test_cupy.py +++ b/distributed/protocol/tests/test_cupy.py @@ -4,7 +4,6 @@ from distributed.protocol import deserialize, serialize cupy = pytest.importorskip("cupy") -cupy_sparse = pytest.importorskip("cupyx.scipy.sparse") numpy = pytest.importorskip("numpy") @@ -66,31 +65,37 @@ def test_serialize_cupy_from_rmm(size): @pytest.mark.parametrize( - "sparse_type", - [ - cupy_sparse.coo_matrix, - cupy_sparse.csc_matrix, - cupy_sparse.csr_matrix, - cupy_sparse.dia_matrix, - ], + "sparse_name", ["coo_matrix", "csc_matrix", "csr_matrix", "dia_matrix",], ) @pytest.mark.parametrize( "dtype", [numpy.dtype("f4"), numpy.dtype("f8"),], ) -@pytest.mark.parametrize("serializer", ["cuda", "dask",]) -def test_serialize_cupy_sparse(sparse_type, dtype, serializer): - a_host = numpy.array([[0, 1, 0], [2, 0, 3], [0, 4, 0]], dtype=dtype) - a = cupy.asarray(a_host) - - anz = 
a.nonzero() - acoo = cupy_sparse.coo_matrix((a[anz], anz)) - asp = sparse_type(acoo) +@pytest.mark.parametrize("serializer", ["cuda", "dask", "pickle"]) +def test_serialize_cupy_sparse(sparse_name, dtype, serializer): + scipy_sparse = pytest.importorskip("scipy.sparse") + cupy_sparse = pytest.importorskip("cupyx.scipy.sparse") - header, frames = serialize(asp, serializers=[serializer]) - asp2 = deserialize(header, frames) + scipy_sparse_type = getattr(scipy_sparse, sparse_name) + cupy_sparse_type = getattr(cupy_sparse, sparse_name) - a2 = asp2.todense() - a2_host = cupy.asnumpy(a2) + a_host = numpy.array([[0, 1, 0], [2, 0, 3], [0, 4, 0]], dtype=dtype) + asp_host = scipy_sparse_type(a_host) + if sparse_name == "dia_matrix": + # CuPy `dia_matrix` cannot be created from SciPy one + # xref: https://github.com/cupy/cupy/issues/3158 + asp_dev = cupy_sparse_type( + (asp_host.data, asp_host.offsets), + shape=asp_host.shape, + dtype=asp_host.dtype, + ) + else: + asp_dev = cupy_sparse_type(asp_host) + + header, frames = serialize(asp_dev, serializers=[serializer]) + a2sp_dev = deserialize(header, frames) + + a2sp_host = a2sp_dev.get() + a2_host = a2sp_host.todense() assert (a_host == a2_host).all() From e619fc99ad5b635f52f69ce2edafed77cfa7c898 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 6 Mar 2020 10:23:06 -0600 Subject: [PATCH 0715/1550] Update TaskGroup remove logic (#3557) --- distributed/scheduler.py | 2 +- distributed/tests/test_scheduler.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c78c4b1b218..ab026f61d06 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4649,7 +4649,7 @@ def transition(self, key, finish, *args, **kwargs): if ts.state == "forgotten": del self.tasks[ts.key] - if ts.state == "forgotten": + if ts.state == "forgotten" and ts.group.name in self.task_groups: # Remove TaskGroup if all tasks are in the forgotten state tg = ts.group if not any(tg.states.get(s) for s in ALL_TASK_STATES): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 5c4d8cbc23e..a5649dbfc82 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1837,6 +1837,17 @@ async def test_task_unique_groups(c, s, a, b): assert s.task_prefixes["sum"].states["memory"] == 2 +@gen_cluster(client=True) +async def test_task_group_on_fire_and_forget(c, s, a, b): + # Regression test for https://github.com/dask/distributed/issues/3465 + with captured_logger("distributed.scheduler") as logs: + x = await c.scatter(list(range(10))) + fire_and_forget([c.submit(slowadd, i, x[i]) for i in range(len(x))]) + await asyncio.sleep(1) + + assert "Error transitioning" not in logs.getvalue() + + class BrokenComm(Comm): peer_address = None local_address = None From 15591929f3a6b7b390a7a5394e1f53fe6a6c16f4 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 6 Mar 2020 14:16:22 -0600 Subject: [PATCH 0716/1550] bump version to 2.12.0 --- docs/source/changelog.rst | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 498668d3c88..9c7ca9b01f4 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,34 @@ Changelog ========= +2.12.0 - 2020-03-06 +------------------- + +- Update ``TaskGroup`` remove logic (:pr:`3557`) `James Bourbeau`_ +- Fix-up CuPy sparse serialization (:pr:`3556`) `John Kirkham`_ +- API docs for 
``LocalCluster`` and ``SpecCluster`` (:pr:`3548`) `Tom Augspurger`_ +- Serialize sparse arrays (:pr:`3545`) `John Kirkham`_ +- Allow tasks with restrictions to be stolen (:pr:`3069`) `Stan Seibert`_ +- Use UCX default configuration instead of raising (:pr:`3544`) `Peter Andreas Entschev`_ +- Support using other serializers with ``register_generic`` (:pr:`3536`) `John Kirkham`_ +- DOC: update to async await (:pr:`3543`) `Tom Augspurger`_ +- Use ``pytest.raises`` in ``test_ucx_config.py`` (:pr:`3541`) `John Kirkham`_ +- Fix/more ucx config options (:pr:`3539`) `Benjamin Zaitlen`_ +- Update heartbeat ``CommClosedError`` error handling (:pr:`3529`) `James Bourbeau`_ +- Use ``makedirs`` when constructing ``local_directory`` (:pr:`3538`) `John Kirkham`_ +- Mark ``None`` as MessagePack serializable (:pr:`3537`) `John Kirkham`_ +- Mark ``bool`` as MessagePack serializable (:pr:`3535`) `John Kirkham`_ +- Use 'temporary-directory' from ``dask.config`` for Nanny's directory (:pr:`3531`) `John Kirkham`_ +- Add try-except around getting source code in performance report (:pr:`3505`) `Matthew Rocklin`_ +- Fix typo in docstring (:pr:`3528`) `Davis Bennett`_ +- Make work stealing callback time configurable (:pr:`3523`) `Lucas Rademaker`_ +- RMM/UCX Config Flags (:pr:`3515`) `Benjamin Zaitlen`_ +- Revise develop-docs: conda env example (:pr:`3406`) `Darren Weber`_ +- Remove ``import ucp`` from the top of ``ucx.py`` (:pr:`3510`) `Peter Andreas Entschev`_ +- Rename ``logs`` to ``get_logs`` (:pr:`3473`) `Jacob Tomlinson`_ +- Stop keep alives when worker reconnecting to the scheduler (:pr:`3493`) `Jacob Tomlinson`_ + + 2.11.0 - 2020-02-19 ------------------- @@ -1570,3 +1598,7 @@ significantly without many new features. .. _`Cyril Shcherbin`: https://github.com/shcherbin .. _`Søren Fuglede Jørgensen`: https://github.com/fuglede .. _`Igor Gotlibovych`: https://github.com/ig248 +.. _`Stan Seibert`: https://github.com/seibert +.. _`Davis Bennett`: https://github.com/d-v-b +.. _`Lucas Rademaker`: https://github.com/lr4d +.. _`Darren Weber`: https://github.com/dazza-codes From 73f8ae229d558e2c25ef0baea6fb6d127e78191e Mon Sep 17 00:00:00 2001 From: Abdulelah Bin Mahfoodh Date: Mon, 9 Mar 2020 18:24:13 +0300 Subject: [PATCH 0717/1550] Add 'local_directory' option to dask-ssh (#3554) --- distributed/cli/dask_ssh.py | 10 ++++++++++ distributed/deploy/old_ssh.py | 24 +++++++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index 07cbb57bf01..eb09f49cfed 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -77,6 +77,14 @@ "dask-scheduler and dask-worker commands." ), ) +@click.option( + "--local-directory", + default=None, + type=click.Path(exists=True), + help=( + "Directory to use on all cluster nodes to place workers " "and scheduler files." + ), +) @click.option( "--remote-python", default=None, type=str, help="Path to Python on remote nodes." 
) @@ -126,6 +134,7 @@ def main( worker_port, nanny_port, remote_dask_worker, + local_directory, ): try: hostnames = list(hostnames) @@ -157,6 +166,7 @@ def main( worker_port, nanny_port, remote_dask_worker, + local_directory, ) import distributed diff --git a/distributed/deploy/old_ssh.py b/distributed/deploy/old_ssh.py index 86d49c9cf15..b524e2d7c45 100644 --- a/distributed/deploy/old_ssh.py +++ b/distributed/deploy/old_ssh.py @@ -209,12 +209,24 @@ def communicate(): def start_scheduler( - logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None + logdir, + addr, + port, + ssh_username, + ssh_port, + ssh_private_key, + remote_python=None, + local_directory=None, ): cmd = "{python} -m distributed.cli.dask_scheduler --port {port}".format( python=remote_python or sys.executable, port=port, logdir=logdir ) + if local_directory is not None: + cmd += " --local-directory {local_directory}".format( + local_directory=local_directory + ) + # Optionally re-direct stdout and stderr to a logfile if logdir is not None: cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd @@ -270,6 +282,7 @@ def start_worker( nanny_port, remote_python=None, remote_dask_worker="distributed.cli.dask_worker", + local_directory=None, ): cmd = ( @@ -303,6 +316,11 @@ def start_worker( nanny_port=nanny_port, ) + if local_directory is not None: + cmd += " --local-directory {local_directory}".format( + local_directory=local_directory + ) + # Optionally redirect stdout and stderr to a logfile if logdir is not None: cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd @@ -353,6 +371,7 @@ def __init__( worker_port=None, nanny_port=None, remote_dask_worker="distributed.cli.dask_worker", + local_directory=None, ): self.scheduler_addr = scheduler_addr @@ -372,6 +391,7 @@ def __init__( self.worker_port = worker_port self.nanny_port = nanny_port self.remote_dask_worker = remote_dask_worker + self.local_directory = local_directory # Generate a universal timestamp to use for log files import datetime @@ -402,6 +422,7 @@ def __init__( ssh_port, ssh_private_key, remote_python, + local_directory, ) # Start worker nodes @@ -455,6 +476,7 @@ def add_worker(self, address): self.nanny_port, self.remote_python, self.remote_dask_worker, + self.local_directory, ) ) From b809777d250152edb52b846e9c1c12a20f500878 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Mon, 9 Mar 2020 12:00:20 -0700 Subject: [PATCH 0718/1550] Fix typo in Client.shutdown docstring (#3562) --- distributed/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 287425a70f1..679e625470f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1382,7 +1382,7 @@ def shutdown(self): """ Shut down the connected scheduler and workers Note, this may disrupt other clients that may be using the same - scheudler and workers. + scheduler and workers. 
See also -------- From 81e303afc2bae5d3696ad4c29cb189e9f55cfcb8 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 12 Mar 2020 14:10:12 -0500 Subject: [PATCH 0719/1550] Disable fast fail on GitHub Actions Windows CI (#3569) --- .github/workflows/ci-windows.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index 75c4b294e88..78db494a6fb 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -6,6 +6,7 @@ jobs: build: runs-on: windows-latest strategy: + fail-fast: false matrix: python-version: ["3.6", "3.7", "3.8"] From 8e8438324b4323131e941428cd48df8e52b2edd8 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 13 Mar 2020 16:40:42 -0500 Subject: [PATCH 0720/1550] Pin bokeh in CI builds (#3570) --- .github/workflows/ci-windows.yaml | 8 ++++---- continuous_integration/environment.yml | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index 78db494a6fb..3b99a8c8ec0 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -32,14 +32,14 @@ jobs: conda install -c conda-forge tornado=5 fi - - name: List packages in environment - shell: bash -l {0} - run: conda list - - name: Install distributed from source shell: bash -l {0} run: python -m pip install -q --no-deps -e . + - name: List packages in environment + shell: bash -l {0} + run: conda list + - name: Run tests shell: bash -l {0} env: diff --git a/continuous_integration/environment.yml b/continuous_integration/environment.yml index 8f8e425dcab..8218d721e85 100644 --- a/continuous_integration/environment.yml +++ b/continuous_integration/environment.yml @@ -1,10 +1,9 @@ name: testenv channels: - - defaults - conda-forge dependencies: - zstandard - - bokeh + - bokeh=1.4.0 - click - cloudpickle - dask From 0ad75b044731338484cbde7a60986bc6a0258483 Mon Sep 17 00:00:00 2001 From: Krishan Bhasin Date: Fri, 13 Mar 2020 22:15:25 +0000 Subject: [PATCH 0721/1550] Avoid performance_report crashing when a worker dies mid-compute (#3575) --- distributed/dashboard/components/scheduler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index c70e41ca436..ee037a4aabb 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -412,7 +412,10 @@ def update(self): return def name(address): - ws = self.scheduler.workers[address] + try: + ws = self.scheduler.workers[address] + except KeyError: + return address if ws.name is not None: return str(ws.name) else: From 806a7e97285c5534b3e37e912cbc060a8036c56f Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Sat, 14 Mar 2020 11:49:30 -0400 Subject: [PATCH 0722/1550] Pin `numpydoc` to avoid double escaped * (#3530) Similar to dask/dask#5961, recent changes to `numpydoc` lead to function signatures displayed using `autosummary` to have doubly-escaped `*`s. This pins `numpydoc` to version 0.8.0 to avoid the regression until a patch is merged upstream. 
Also small changes to avoid a sphinx timeout looking for an intersphinx inventory, and a fixed a misformatted code-block --- docs/requirements.txt | 3 ++- docs/source/conf.py | 3 ++- docs/source/resources.rst | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 61dd185a5b9..6bcd69b284d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,8 +1,9 @@ +# We pin numpydoc to avoid doubly-escaped *args and **kwargs in rendered docs +numpydoc==0.8.0 tornado toolz cloudpickle dask -numpydoc sphinx dask_sphinx_theme sphinx-click diff --git a/docs/source/conf.py b/docs/source/conf.py index f8ab5a31797..9bda8cb1a14 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -39,6 +39,7 @@ numpydoc_show_class_members = False + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -382,7 +383,7 @@ # and the Numpy documentation. intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "numpy": ("http://docs.scipy.org/doc/numpy", None), + "numpy": ("https://docs.scipy.org/doc/numpy", None), } # Redirects diff --git a/docs/source/resources.rst b/docs/source/resources.rst index f9449dbd8be..7931b980d03 100644 --- a/docs/source/resources.rst +++ b/docs/source/resources.rst @@ -96,6 +96,7 @@ delayed objects. You can pass a dictionary mapping keys of the collection to resource requirements during compute or persist calls. .. code-block:: python + from dask import core x = dd.read_csv(...) From f2f82c6c2e8d36731cb3fb82fb1f80ea0323358e Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 17 Mar 2020 08:51:52 -0700 Subject: [PATCH 0723/1550] Import tlz (#3579) Import from `tlz` for optional `cytoolz` support --- distributed/cfexecutor.py | 2 +- distributed/cli/dask_worker.py | 2 +- distributed/client.py | 6 ++---- distributed/core.py | 2 +- distributed/dashboard/components/__init__.py | 1 - distributed/dashboard/components/scheduler.py | 8 ++------ distributed/dashboard/components/shared.py | 2 +- distributed/dashboard/components/worker.py | 2 +- distributed/dashboard/scheduler.py | 5 +---- .../dashboard/tests/test_scheduler_bokeh.py | 4 ++-- .../dashboard/tests/test_worker_bokeh.py | 2 +- distributed/dashboard/utils.py | 8 ++------ distributed/dashboard/worker.py | 2 +- distributed/deploy/adaptive_core.py | 2 +- distributed/deploy/old_ssh.py | 2 +- distributed/deploy/tests/test_spec_cluster.py | 2 +- distributed/diagnostics/progress.py | 2 +- distributed/diagnostics/progress_stream.py | 2 +- distributed/diagnostics/progressbar.py | 2 +- .../diagnostics/tests/test_task_stream.py | 2 +- distributed/diagnostics/tests/test_widgets.py | 2 +- distributed/profile.py | 2 +- distributed/protocol/compression.py | 3 ++- distributed/protocol/core.py | 6 +----- distributed/protocol/serialize.py | 5 +---- distributed/protocol/tests/test_serialize.py | 2 +- distributed/scheduler.py | 17 ++++++++++++----- distributed/stealing.py | 5 +---- distributed/tests/test_batched.py | 2 +- distributed/tests/test_client.py | 3 ++- distributed/tests/test_client_executor.py | 2 +- distributed/tests/test_failed_workers.py | 2 +- distributed/tests/test_ipython.py | 2 +- distributed/tests/test_nanny.py | 2 +- distributed/tests/test_profile.py | 2 +- distributed/tests/test_pubsub.py | 2 +- distributed/tests/test_scheduler.py | 2 +- distributed/tests/test_steal.py | 2 +- distributed/tests/test_stress.py | 2 +- distributed/tests/test_worker.py | 2 +- distributed/utils.py | 2 +- 
distributed/utils_comm.py | 2 +- distributed/utils_test.py | 2 +- distributed/variable.py | 5 +---- distributed/worker.py | 6 ++---- docs/source/efficiency.rst | 2 +- requirements.txt | 2 +- 47 files changed, 64 insertions(+), 84 deletions(-) diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index 985a407bdb9..545dbbced09 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -1,7 +1,7 @@ import concurrent.futures as cf import weakref -from toolz import merge +from tlz import merge from tornado import gen diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 5188333b75c..29261b52451 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -22,7 +22,7 @@ ) from distributed.utils import deserialize_for_cli, import_term -from toolz import valmap +from tlz import valmap from tornado.ioloop import IOLoop, TimeoutError logger = logging.getLogger("distributed.dask_worker") diff --git a/distributed/client.py b/distributed/client.py index 679e625470f..06c6d245c07 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -30,10 +30,8 @@ from dask.compatibility import apply from dask.utils import ensure_dict, format_bytes, funcname -try: - from cytoolz import first, groupby, merge, valmap, keymap -except ImportError: - from toolz import first, groupby, merge, valmap, keymap +from tlz import first, groupby, merge, valmap, keymap + try: from dask.delayed import single_key except ImportError: diff --git a/distributed/core.py b/distributed/core.py index 5768f0f4d8e..ec1e6c5214c 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -10,7 +10,7 @@ import dask import tblib -from toolz import merge +from tlz import merge from tornado import gen from tornado.ioloop import IOLoop diff --git a/distributed/dashboard/components/__init__.py b/distributed/dashboard/components/__init__.py index bb8269083e9..f6159e83bcf 100644 --- a/distributed/dashboard/components/__init__.py +++ b/distributed/dashboard/components/__init__.py @@ -26,7 +26,6 @@ from bokeh.plotting import figure import dask from tornado import gen -import toolz from distributed.dashboard.utils import without_property_validation, BOKEH_VERSION from distributed import profile diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index ee037a4aabb..c371210c701 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -35,7 +35,8 @@ from bokeh.io import curdoc import dask from dask.utils import format_bytes, key_split -from toolz import pipe +from tlz import pipe +from tlz.curried import map, concat, groupby from tornado import escape try: @@ -63,11 +64,6 @@ from distributed.diagnostics.graph_layout import GraphLayout from distributed.diagnostics.task_stream import TaskStreamPlugin -try: - from cytoolz.curried import map, concat, groupby -except ImportError: - from toolz.curried import map, concat, groupby - if dask.config.get("distributed.dashboard.export-tool"): from distributed.dashboard.export_tool import ExportTool else: diff --git a/distributed/dashboard/components/shared.py b/distributed/dashboard/components/shared.py index 611d281dd5e..24db46385e7 100644 --- a/distributed/dashboard/components/shared.py +++ b/distributed/dashboard/components/shared.py @@ -15,7 +15,7 @@ from bokeh.plotting import figure import dask from tornado import gen -import toolz +import tlz as toolz from distributed.dashboard.components import DashboardComponent 
from distributed.dashboard.utils import ( diff --git a/distributed/dashboard/components/worker.py b/distributed/dashboard/components/worker.py index 440e7279e3b..a11d3047838 100644 --- a/distributed/dashboard/components/worker.py +++ b/distributed/dashboard/components/worker.py @@ -20,7 +20,7 @@ from bokeh.palettes import RdBu from bokeh.themes import Theme from dask.utils import format_bytes -from toolz import merge, partition_all +from tlz import merge, partition_all from distributed.dashboard.components import add_periodic_callback from distributed.dashboard.components.shared import ( diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 836cefbbd6c..acaab24cd17 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -8,10 +8,7 @@ import dask from dask.utils import format_bytes -try: - from cytoolz import merge, merge_with -except ImportError: - from toolz import merge, merge_with +from tlz import merge, merge_with from tornado import escape from tornado.websocket import WebSocketHandler diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 4977ee8fa76..f36bfd897e1 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -7,7 +7,7 @@ import pytest pytest.importorskip("bokeh") -from toolz import first +from tlz import first from tornado import gen from tornado.httpclient import AsyncHTTPClient, HTTPRequest @@ -624,7 +624,7 @@ def test_proxy_to_workers(c, s, a, b): }, ) async def test_lots_of_tasks(c, s, a, b): - import toolz + import tlz as toolz ts = TaskStream(s) ts.update() diff --git a/distributed/dashboard/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py index b33fc3ba185..97729fce14f 100644 --- a/distributed/dashboard/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -6,7 +6,7 @@ pytest.importorskip("bokeh") import sys -from toolz import first +from tlz import first from tornado import gen from tornado.httpclient import AsyncHTTPClient diff --git a/distributed/dashboard/utils.py b/distributed/dashboard/utils.py index b47cb75d6b0..394e016a4da 100644 --- a/distributed/dashboard/utils.py +++ b/distributed/dashboard/utils.py @@ -5,7 +5,8 @@ import bokeh from bokeh.io import curdoc from tornado import web -from toolz import partition +from tlz import partition +from tlz.curried import first try: import numpy as np @@ -13,11 +14,6 @@ np = False -try: - from cytoolz.curried import first -except ImportError: - from toolz.curried import first - BOKEH_VERSION = LooseVersion(bokeh.__version__) dirname = os.path.dirname(__file__) diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index db29480666b..54b3a0a4a51 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -3,7 +3,7 @@ import os from bokeh.themes import Theme -from toolz import merge +from tlz import merge from .components.worker import ( status_doc, diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index dfd82ea33ba..192e244bd08 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -2,7 +2,7 @@ import math from tornado.ioloop import IOLoop -import toolz +import tlz as toolz from ..metrics import time from ..utils import parse_timedelta, PeriodicCallback diff --git a/distributed/deploy/old_ssh.py b/distributed/deploy/old_ssh.py 
index b524e2d7c45..33e69772f9b 100644 --- a/distributed/deploy/old_ssh.py +++ b/distributed/deploy/old_ssh.py @@ -12,7 +12,7 @@ from threading import Thread -from toolz import merge +from tlz import merge from tornado import gen diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 68642cda9d2..90ce9923c69 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -8,7 +8,7 @@ from distributed.metrics import time from distributed.utils_test import loop, cleanup # noqa: F401 from distributed.utils import is_valid_xml -import toolz +import tlz as toolz import pytest diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 1dcab0dc9e9..2aeba986839 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -3,7 +3,7 @@ import logging from timeit import default_timer -from toolz import groupby, valmap +from tlz import groupby, valmap from .plugin import SchedulerPlugin from ..utils import key_split, key_split_group, log_errors, tokey diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index e417ee8e35b..c5e74a30f34 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -1,6 +1,6 @@ import logging -from toolz import valmap, merge +from tlz import valmap, merge from .progress import AllProgress diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index ab7800c2125..11da7a30d3d 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -4,7 +4,7 @@ import sys import weakref -from toolz import valmap +from tlz import valmap from tornado.ioloop import IOLoop from .progress import format_time, Progress, MultiProgress diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index 58f1c4319f6..4639c7a7a0b 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -2,7 +2,7 @@ from time import sleep import pytest -from toolz import frequencies +from tlz import frequencies from distributed import get_task_stream from distributed.utils_test import gen_cluster, div, inc, slowinc diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index 03689c88c1d..c217d17e293 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -74,7 +74,7 @@ def record_display(*args): from operator import add import re -from toolz import valmap +from tlz import valmap from distributed.client import wait from distributed.worker import dumps_task diff --git a/distributed/profile.py b/distributed/profile.py index 1bef6450974..5bf071e20da 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -31,7 +31,7 @@ import threading from time import sleep -import toolz +import tlz as toolz from .metrics import time from .utils import format_time, color_of, parse_timedelta diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 5e81cdbaf1f..adb3c888be6 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -3,11 +3,12 @@ Includes utilities for determining whether or not to compress """ +from functools import partial import logging import random import dask -from toolz 
import identity, partial +from tlz import identity try: import blosc diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 3937c9c2fc8..3bb863f78c2 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -1,13 +1,9 @@ +from functools import reduce import logging import operator import msgpack -try: - from cytoolz import reduce -except ImportError: - from toolz import reduce - from .compression import compressions, maybe_compress, decompress from .serialize import serialize, deserialize, Serialize, Serialized, extract_serialize from .utils import frame_split_size, merge_frames, msgpack_opts diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index c462568cc40..c0fdb98449a 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -4,10 +4,7 @@ import dask from dask.base import normalize_token -try: - from cytoolz import valmap, get_in -except ImportError: - from toolz import valmap, get_in +from tlz import valmap, get_in import msgpack diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index caf1bbe0ad5..10e4c5e797d 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -4,7 +4,7 @@ import msgpack import numpy as np import pytest -from toolz import identity +from tlz import identity from distributed import wait from distributed.protocol import ( diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ab026f61d06..d543808340d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -19,11 +19,18 @@ import psutil import sortedcontainers -try: - from cytoolz import frequencies, merge, pluck, merge_sorted, first, merge_with -except ImportError: - from toolz import frequencies, merge, pluck, merge_sorted, first, merge_with -from toolz import valmap, second, compose, groupby +from tlz import ( + frequencies, + merge, + pluck, + merge_sorted, + first, + merge_with, + valmap, + second, + compose, + groupby, +) from tornado.ioloop import IOLoop import dask diff --git a/distributed/stealing.py b/distributed/stealing.py index 4fbb753e131..fcceba4824a 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -9,10 +9,7 @@ from .diagnostics.plugin import SchedulerPlugin from .utils import log_errors, parse_timedelta, PeriodicCallback -try: - from cytoolz import topk -except ImportError: - from toolz import topk +from tlz import topk LATENCY = 10e-3 log_2 = log(2) diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index 74efba810d3..f2b0be99ab0 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -2,7 +2,7 @@ import random import pytest -from toolz import assoc +from tlz import assoc from distributed.batched import BatchedSend from distributed.core import listen, connect, CommClosedError diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 392aec73be8..5c853916558 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1,5 +1,6 @@ import asyncio from collections import deque +from functools import partial import gc import logging from operator import add @@ -18,7 +19,7 @@ import zipfile import pytest -from toolz import identity, isdistinct, concat, pluck, valmap, partial, first, merge +from tlz import identity, isdistinct, concat, pluck, valmap, first, merge from tornado import gen import dask diff --git 
a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index 1024990216d..e7e3fc24c7d 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -11,7 +11,7 @@ ) import pytest -from toolz import take +from tlz import take from distributed import Client from distributed.utils import CancelledError diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index cf0387c1cd2..99b1b4a42a7 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -3,7 +3,7 @@ from time import sleep import pytest -from toolz import partition_all, first +from tlz import partition_all, first from tornado import gen from dask import delayed diff --git a/distributed/tests/test_ipython.py b/distributed/tests/test_ipython.py index aa4a3e4092e..a6d387589e6 100644 --- a/distributed/tests/test_ipython.py +++ b/distributed/tests/test_ipython.py @@ -1,7 +1,7 @@ from unittest import mock import pytest -from toolz import first +from tlz import first import tornado from distributed import Client diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index c80974d9970..2a19bdf8742 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -9,7 +9,7 @@ import numpy as np import pytest -from toolz import valmap, first +from tlz import valmap, first from tornado import gen from tornado.ioloop import IOLoop diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index a022600d819..9f673e8caaf 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -1,7 +1,7 @@ import pytest import sys import time -from toolz import first +from tlz import first import threading from distributed.compatibility import WINDOWS diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 2e372dea88b..639542df5ca 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -3,7 +3,7 @@ import pytest from tornado import gen -import toolz +import tlz as toolz from distributed import Pub, Sub, wait, get_worker, TimeoutError from distributed.utils_test import gen_cluster diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a5649dbfc82..5459716ca85 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -12,7 +12,7 @@ import dask from dask import delayed -from toolz import merge, concat, valmap, first, frequencies +from tlz import merge, concat, valmap, first, frequencies from tornado import gen import pytest diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 71f408749a1..5b13d9157e8 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -6,7 +6,7 @@ import weakref import pytest -from toolz import sliding_window, concat +from tlz import sliding_window, concat from tornado import gen import dask diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index ab996e2b30d..d5e1e62c574 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -6,7 +6,7 @@ from dask import delayed import pytest -from toolz import concat, sliding_window +from tlz import concat, sliding_window from distributed import Client, wait, Nanny from distributed.config import config diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 
b6da294c749..0bda344fd96 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -16,7 +16,7 @@ from dask.utils import format_bytes from dask.system import CPU_COUNT import pytest -from toolz import pluck, sliding_window, first +from tlz import pluck, sliding_window, first import tornado from tornado import gen diff --git a/distributed/utils.py b/distributed/utils.py index 429a53cddde..eb622f7b837 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -47,7 +47,7 @@ parse_timedelta, ) -import toolz +import tlz as toolz import tornado from tornado import gen from tornado.ioloop import IOLoop diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 3d10ba51038..42404754527 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -8,7 +8,7 @@ from dask.optimization import SubgraphCallable import dask.config from dask.utils import parse_timedelta -from toolz import merge, concat, groupby, drop +from tlz import merge, concat, groupby, drop from .core import rpc from .utils import All, tokey diff --git a/distributed/utils_test.py b/distributed/utils_test.py index e16983b1879..741fc76a8dd 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -32,7 +32,7 @@ import pytest import dask -from toolz import merge, memoize, assoc +from tlz import merge, memoize, assoc from tornado import gen, queues from tornado.ioloop import IOLoop diff --git a/distributed/variable.py b/distributed/variable.py index fc4cc396dab..3c6cc931166 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -3,10 +3,7 @@ import logging import uuid -try: - from cytoolz import merge -except ImportError: - from toolz import merge +from tlz import merge from .client import Future, _get_global_client, Client from .utils import tokey, log_errors, TimeoutError, ignoring diff --git a/distributed/worker.py b/distributed/worker.py index aa71a16640b..247ffc99510 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3,6 +3,7 @@ from collections import defaultdict, deque, namedtuple from collections.abc import MutableMapping from datetime import timedelta +from functools import partial import heapq from inspect import isawaitable import logging @@ -21,10 +22,7 @@ from dask.utils import format_bytes, funcname from dask.system import CPU_COUNT -try: - from cytoolz import pluck, partial, merge, first, keymap -except ImportError: - from toolz import pluck, partial, merge, first, keymap +from tlz import pluck, merge, first, keymap from tornado import gen from tornado.ioloop import IOLoop diff --git a/docs/source/efficiency.rst b/docs/source/efficiency.rst index 94a603ea9a3..ed3ad2428d5 100644 --- a/docs/source/efficiency.rst +++ b/docs/source/efficiency.rst @@ -67,7 +67,7 @@ A common solution is to batch your input into larger chunks. >>> def f_many(chunk): ... 
return [f(x) for x in chunk] - >>> from toolz import partition_all + >>> from tlz import partition_all >>> chunks = partition_all(1000000, seq) # Collect into groups of size 1000 >>> futures = client.map(f_many, chunks) diff --git a/requirements.txt b/requirements.txt index 3f827e250e6..4cb3ba60ae7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ msgpack >= 0.6.0 psutil >= 5.0 sortedcontainers !=2.0.0, !=2.0.1 tblib >= 1.6.0 -toolz >= 0.7.4 +toolz >= 0.8.2 tornado >= 5;python_version<'3.8' tornado >= 6.0.3;python_version>='3.8' zict >= 0.1.3 From 511427b81f599105dc1eb4d2f35fd33aa249b4b8 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 17 Mar 2020 15:47:37 -0500 Subject: [PATCH 0724/1550] Add Python version to version check (#3567) * Add Python to version checks * Use dict for system info --- distributed/tests/test_versions.py | 17 +++++++++++++++++ distributed/versions.py | 26 ++++++++++++++------------ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/distributed/tests/test_versions.py b/distributed/tests/test_versions.py index 25087df795a..ab3547820ca 100644 --- a/distributed/tests/test_versions.py +++ b/distributed/tests/test_versions.py @@ -1,4 +1,5 @@ import re +import uuid import pytest @@ -117,3 +118,19 @@ async def test_version_warning_in_cluster(s, a, b): assert any( "0.0.0" in line.message and a.address in line.message for line in w.logs ) + + +@gen_cluster() +async def test_python_version_mismatch_warning(s, a, b): + # Set random Python version for one worker + random_version = uuid.uuid4().hex + orig = s.workers[a.address].versions["host"]["python"] = random_version + + with pytest.warns(None) as record: + async with Client(s.address, asynchronous=True) as client: + pass + + assert record + assert any("python" in str(r.message) for r in record) + assert any(random_version in str(r.message) for r in record) + assert any(a.address in str(r.message) for r in record) diff --git a/distributed/versions.py b/distributed/versions.py index a7022c830f7..403d79f6aef 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -51,17 +51,17 @@ def get_versions(packages=None): def get_system_info(): (sysname, nodename, release, version, machine, processor) = platform.uname() - host = [ - ("python", "%d.%d.%d.%s.%s" % sys.version_info[:]), - ("python-bits", struct.calcsize("P") * 8), - ("OS", "%s" % sysname), - ("OS-release", "%s" % release), - ("machine", "%s" % machine), - ("processor", "%s" % processor), - ("byteorder", "%s" % sys.byteorder), - ("LC_ALL", "%s" % os.environ.get("LC_ALL", "None")), - ("LANG", "%s" % os.environ.get("LANG", "None")), - ] + host = { + "python": "%d.%d.%d.%s.%s" % sys.version_info[:], + "python-bits": struct.calcsize("P") * 8, + "OS": "%s" % sysname, + "OS-release": "%s" % release, + "machine": "%s" % machine, + "processor": "%s" % processor, + "byteorder": "%s" % sys.byteorder, + "LC_ALL": "%s" % os.environ.get("LC_ALL", "None"), + "LANG": "%s" % os.environ.get("LANG", "None"), + } return host @@ -113,7 +113,6 @@ def error_message(scheduler, workers, client, client_name="client"): # Collect all package versions packages = set() - for node, info in nodes.items(): if info is None or not (isinstance(info, dict)) or "packages" not in info: node_packages[node] = defaultdict(lambda: "UNKNOWN") @@ -122,6 +121,9 @@ def error_message(scheduler, workers, client, client_name="client"): for pkg, version in info["packages"].items(): node_packages[node][pkg] = version packages.add(pkg) + # Collect Python version for each 
node + node_packages[node]["python"] = info["host"]["python"] + packages.add("python") errs = [] for pkg in sorted(packages): From 2acffc3172ec32e173547ee4c39a01b6c94e74a1 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 18 Mar 2020 18:04:05 -0700 Subject: [PATCH 0725/1550] Optionally compress on a frame-by-frame basis (#3586) Previously this converted a list of bytes-like objects into a list. Now we consume a single one and use map when dealing with lists. * Handle compression on a frame-by-frame basis * Set cuda serialization to False rather than None We've changed the convention so that None now means "proceed as usual" rather than "don't do anything please" --- distributed/protocol/core.py | 28 +++++++++----- distributed/protocol/cuda.py | 2 +- distributed/protocol/numpy.py | 2 +- distributed/protocol/serialize.py | 6 ++- distributed/protocol/tests/test_serialize.py | 14 +++++++ distributed/protocol/utils.py | 39 +++++++++----------- 6 files changed, 57 insertions(+), 34 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 3bb863f78c2..0947b3a6292 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -48,17 +48,27 @@ def dumps(msg, serializers=None, on_error="message", context=None): for key, (head, frames) in data.items(): if "lengths" not in head: head["lengths"] = tuple(map(nbytes, frames)) - if "compression" not in head: - frames = frame_split_size(frames) - if frames: - compression, frames = zip(*map(maybe_compress, frames)) - else: - compression = [] - head["compression"] = compression - head["count"] = len(frames) + + # Compress frames that are not yet compressed + out_compression = [] + _out_frames = [] + for frame, compression in zip( + frames, head.get("compression") or [None] * len(frames) + ): + if compression is None: # default behavior + _frames = frame_split_size(frame) + _compression, _frames = zip(*map(maybe_compress, _frames)) + out_compression.extend(_compression) + _out_frames.extend(_frames) + else: # already specified, so pass + out_compression.append(compression) + _out_frames.append(frame) + + head["compression"] = out_compression + head["count"] = len(_out_frames) header["headers"][key] = head header["keys"].append(key) - out_frames.extend(frames) + out_frames.extend(_out_frames) for key, (head, frames) in pre.items(): if "lengths" not in head: diff --git a/distributed/protocol/cuda.py b/distributed/protocol/cuda.py index aa638f70c0d..44ed6a033df 100644 --- a/distributed/protocol/cuda.py +++ b/distributed/protocol/cuda.py @@ -18,7 +18,7 @@ def cuda_dumps(x): header, frames = dumps(x) header["type-serialized"] = pickle.dumps(type(x)) header["serializer"] = "cuda" - header["compression"] = (None,) * len(frames) # no compression for gpu data + header["compression"] = (False,) * len(frames) # no compression for gpu data return header, frames diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index a2c9c2933e6..a0df77c8b37 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -88,7 +88,7 @@ def serialize_numpy_ndarray(x): header["broadcast_to"] = broadcast_to if x.nbytes > 1e5: - frames = frame_split_size([data]) + frames = frame_split_size(data) else: frames = [data] diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index c0fdb98449a..6db7ca70c13 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -166,10 +166,12 @@ def serialize(x, serializers=None, on_error="message", 
context=None): frames = [] lengths = [] + compressions = [] for _header, _frames in headers_frames: frames.extend(_frames) length = len(_frames) lengths.append(length) + compressions.extend(_header.get("compression") or [None] * len(_frames)) headers = [obj[0] for obj in headers_frames] headers = { @@ -178,6 +180,8 @@ def serialize(x, serializers=None, on_error="message", context=None): "frame-lengths": lengths, "type-serialized": type(x).__name__, } + if any(compression is not None for compression in compressions): + headers["compression"] = compressions return headers, frames tb = "" @@ -436,7 +440,7 @@ def replace_inner(x): def serialize_bytelist(x, **kwargs): header, frames = serialize(x, **kwargs) - frames = frame_split_size(frames) + frames = sum(map(frame_split_size, frames), []) if frames: compression, frames = zip(*map(maybe_compress, frames)) else: diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 10e4c5e797d..41e2af51b70 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -374,3 +374,17 @@ async def test_profile_nested_sizeof(): msg = {"data": original} frames = await to_frames(msg) + + +def test_compression_numpy_list(): + class MyObj: + pass + + @dask_serialize.register(MyObj) + def _(x): + header = {"compression": [False]} + frames = [b""] + return header, frames + + header, frames = serialize([MyObj(), MyObj()]) + assert header["compression"] == [False, False] diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index e5b9247e77f..3af203881ff 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -19,9 +19,9 @@ msgpack_opts["encoding"] = "utf-8" -def frame_split_size(frames, n=BIG_BYTES_SHARD_SIZE): +def frame_split_size(frame, n=BIG_BYTES_SHARD_SIZE) -> list: """ - Split a list of frames into a list of frames of maximum size + Split a frame into a list of frames of maximum size This helps us to avoid passing around very large bytestrings. 
@@ -30,26 +30,21 @@ def frame_split_size(frames, n=BIG_BYTES_SHARD_SIZE): >>> frame_split_size([b'12345', b'678'], n=3) # doctest: +SKIP [b'123', b'45', b'678'] """ - if not frames: - return frames - - if max(map(nbytes, frames)) <= n: - return frames - - out = [] - for frame in frames: - if nbytes(frame) > n: - if isinstance(frame, (bytes, bytearray)): - frame = memoryview(frame) - try: - itemsize = frame.itemsize - except AttributeError: - itemsize = 1 - for i in range(0, nbytes(frame) // itemsize, n // itemsize): - out.append(frame[i : i + n // itemsize]) - else: - out.append(frame) - return out + if nbytes(frame) <= n: + return [frame] + + if nbytes(frame) > n: + if isinstance(frame, (bytes, bytearray)): + frame = memoryview(frame) + try: + itemsize = frame.itemsize + except AttributeError: + itemsize = 1 + + return [ + frame[i : i + n // itemsize] + for i in range(0, nbytes(frame) // itemsize, n // itemsize) + ] def merge_frames(header, frames): From 3c1bfa8838e385daf7e86910283969699d169d40 Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Thu, 19 Mar 2020 18:37:05 -0400 Subject: [PATCH 0726/1550] Change Adaptive docs to reference adaptive_target (#3597) --- distributed/deploy/adaptive.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 0d295200018..d16f577168a 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -66,11 +66,11 @@ class Adaptive(AdaptiveCore): Notes ----- - Subclasses can override :meth:`Adaptive.should_scale_up` and + Subclasses can override :meth:`Adaptive.target` and :meth:`Adaptive.workers_to_close` to control when the cluster should be resized. The default implementation checks if there are too many tasks - per worker or too little memory available (see :meth:`Adaptive.needs_cpu` - and :meth:`Adaptive.needs_memory`). + per worker or too little memory available (see + :meth:`Scheduler.adaptive_target`). ''' def __init__( @@ -110,6 +110,22 @@ def observed(self): return self.cluster.observed async def target(self): + """ + Determine target number of workers that should exist. + + Notes + ----- + ``Adaptive.target`` dispatches to Scheduler.adaptive_target(), + but may be overridden in subclasses. + + Returns + ------- + Target number of workers + + See Also + -------- + Scheduler.adaptive_target + """ return await self.scheduler.adaptive_target( target_duration=self.target_duration ) From 700fa17913e3f7ad8906a8468f6e9cb680746b9f Mon Sep 17 00:00:00 2001 From: Gabriel Sailer Date: Fri, 20 Mar 2020 02:15:05 +0100 Subject: [PATCH 0727/1550] Add configuration for Adaptive arguments (#3509) --- distributed/deploy/adaptive.py | 26 +++++++++++++++++------ distributed/deploy/tests/test_adaptive.py | 12 +++++++++++ distributed/distributed.yaml | 8 ++++++- distributed/scheduler.py | 4 +++- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index d16f577168a..1c53155de15 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -1,6 +1,6 @@ from inspect import isawaitable import logging -import math +import dask.config from .adaptive_core import AdaptiveCore from ..utils import log_errors, parse_timedelta @@ -71,22 +71,36 @@ class Adaptive(AdaptiveCore): resized. The default implementation checks if there are too many tasks per worker or too little memory available (see :meth:`Scheduler.adaptive_target`). 
+ The values for interval, min, max, wait_count and target_duration can be + specified in the dask config under the distributed.adaptive key. ''' def __init__( self, cluster=None, - interval="1s", - minimum=0, - maximum=math.inf, - wait_count=3, - target_duration="5s", + interval=None, + minimum=None, + maximum=None, + wait_count=None, + target_duration=None, worker_key=None, **kwargs ): self.cluster = cluster self.worker_key = worker_key self._workers_to_close_kwargs = kwargs + + if interval is None: + interval = dask.config.get("distributed.adaptive.interval") + if minimum is None: + minimum = dask.config.get("distributed.adaptive.minimum") + if maximum is None: + maximum = dask.config.get("distributed.adaptive.maximum") + if wait_count is None: + wait_count = dask.config.get("distributed.adaptive.wait-count") + if target_duration is None: + target_duration = dask.config.get("distributed.adaptive.target-duration") + self.target_duration = parse_timedelta(target_duration) super().__init__( diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 2eddeeceff8..9c68e6ddf53 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -1,3 +1,4 @@ +import math from time import sleep import dask @@ -415,3 +416,14 @@ async def test_adapt_cores_memory(cleanup): ) assert adapt.minimum == 3 assert adapt.maximum == 5 + + +def test_adaptive_config(): + with dask.config.set( + {"distributed.adaptive.minimum": 10, "distributed.adaptive.wait-count": 8} + ): + adapt = Adaptive(interval="5s") + assert adapt.minimum == 10 + assert adapt.maximum == math.inf + assert adapt.interval == 5 + assert adapt.wait_count == 8 diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 05f27604328..311eeaae829 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -74,6 +74,13 @@ distributed: deploy: lost-worker-timeout: 15s # Interval after which to hard-close a lost worker job + adaptive: + interval: 1s # Interval between scaling evaluations + target-duration: 5s # Time an entire graph calculation is desired to take ("1m", "30m") + minimum: 0 # Minimum number of workers + maximum: .inf # Maximum number of workers + wait-count: 3 # Number of times a worker should be suggested for removal before removing it + comm: retry: # some operations (such as gathering data) are subject to re-tries with the below parameters count: 0 # the maximum retry attempts. 0 disables re-trying. 
@@ -141,4 +148,3 @@ ucx: infiniband: null # enable Infiniband cuda_copy: null # enable cuda-copy net-devices: null # define which Infiniband device to use - diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d543808340d..8a61ba31fca 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5206,7 +5206,7 @@ def check_idle(self): if close: self.loop.add_callback(self.close) - def adaptive_target(self, comm=None, target_duration="5s"): + def adaptive_target(self, comm=None, target_duration=None): """ Desired number of workers based on the current workload This looks at the current running tasks and memory use, and returns a @@ -5222,6 +5222,8 @@ def adaptive_target(self, comm=None, target_duration="5s"): -------- distributed.deploy.Adaptive """ + if target_duration is None: + target_duration = dask.config.get("distributed.adaptive.target-duration") target_duration = parse_timedelta(target_duration) # CPU From 0d64f3a3c2f72543420b6f2967e8e789ad265a27 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 20 Mar 2020 03:10:06 +0100 Subject: [PATCH 0728/1550] Synchronize default CUDA stream before UCX send/recv (#3598) * Synchronize default CUDA stream before UCX send/recv * Add more clarity on UCX.write comment Co-Authored-By: Mark Harris * Add more clarity on UCX.read comment Co-Authored-By: Mark Harris Co-authored-by: Mark Harris --- distributed/comm/ucx.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 7295b11bb48..04eecdf4482 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -35,6 +35,15 @@ cuda_array = None +def synchronize_stream(stream=0): + import numba.cuda + + ctx = numba.cuda.current_context() + cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) + stream = numba.cuda.driver.Stream(ctx, cu_stream, None) + stream.synchronize() + + def init_once(): global ucp, cuda_array if ucp is not None: @@ -160,6 +169,14 @@ async def write( np.array([nbytes(f) for f in frames], dtype=np.uint64) ) # Send frames + + # It is necessary to first synchronize the default stream before start sending + # We synchronize the default stream because UCX is not stream-ordered and + # syncing the default stream will wait for other non-blocking CUDA streams. + # Note this is only sufficient if the memory being sent is not currently in use on + # non-blocking CUDA streams. 
+ synchronize_stream(0) + for frame in frames: if nbytes(frame) > 0: await self.ep.send(frame) @@ -196,13 +213,20 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): frame = cuda_array(size) else: frame = np.empty(size, dtype=np.uint8) - await self.ep.recv(frame) frames.append(frame) else: if is_cuda: frames.append(cuda_array(size)) else: frames.append(b"") + + # It is necessary to first populate `frames` with CUDA arrays and synchronize + # the default stream before starting receiving to ensure buffers have been allocated + synchronize_stream(0) + for i, (is_cuda, size) in enumerate(zip(is_cudas.tolist(), sizes.tolist())): + if size > 0: + await self.ep.recv(frames[i]) + msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers ) From e928cc0090015a7e07c18d2d64255d33a849224a Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 20 Mar 2020 10:36:24 -0500 Subject: [PATCH 0729/1550] Fix linting errors (#3604) --- distributed/comm/ucx.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 04eecdf4482..22d7e361e97 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -174,7 +174,7 @@ async def write( # We synchronize the default stream because UCX is not stream-ordered and # syncing the default stream will wait for other non-blocking CUDA streams. # Note this is only sufficient if the memory being sent is not currently in use on - # non-blocking CUDA streams. + # non-blocking CUDA streams. synchronize_stream(0) for frame in frames: @@ -223,7 +223,9 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): # It is necessary to first populate `frames` with CUDA arrays and synchronize # the default stream before starting receiving to ensure buffers have been allocated synchronize_stream(0) - for i, (is_cuda, size) in enumerate(zip(is_cudas.tolist(), sizes.tolist())): + for i, (is_cuda, size) in enumerate( + zip(is_cudas.tolist(), sizes.tolist()) + ): if size > 0: await self.ep.recv(frames[i]) From 6adf0696601fc575978e4f1c3ea6805cf323283a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Fri, 20 Mar 2020 19:32:16 +0100 Subject: [PATCH 0730/1550] Remove dill from CI environments. 
(#3608) --- continuous_integration/environment.yml | 1 - continuous_integration/travis/install.sh | 1 - 2 files changed, 2 deletions(-) diff --git a/continuous_integration/environment.yml b/continuous_integration/environment.yml index 8218d721e85..7458c9f64ac 100644 --- a/continuous_integration/environment.yml +++ b/continuous_integration/environment.yml @@ -7,7 +7,6 @@ dependencies: - click - cloudpickle - dask - - dill - lz4 - ipykernel - ipywidgets diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 68b842aa033..4ee0790f6c5 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -33,7 +33,6 @@ conda install -c conda-forge -q \ click \ coverage \ dask \ - dill \ flake8 \ h5py \ ipykernel \ From 5f1c6bcdea47c2d8bd67cc58cfe6c29ff4fab9ef Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 20 Mar 2020 14:20:38 -0500 Subject: [PATCH 0731/1550] Replace tornado.queues with asyncio.queues (#3607) --- distributed/comm/tests/test_comms.py | 10 +++++----- distributed/utils_test.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 150251f3d59..16036184755 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -9,7 +9,7 @@ import pkg_resources import pytest -from tornado import ioloop, queues +from tornado import ioloop from tornado.concurrent import Future import distributed @@ -78,10 +78,10 @@ def check_tls_extra(info): @pytest.mark.asyncio async def get_comm_pair(listen_addr, listen_args=None, connect_args=None, **kwargs): - q = queues.Queue() + q = asyncio.Queue() - def handle_comm(comm): - q.put(comm) + async def handle_comm(comm): + await q.put(comm) listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) await listener.start() @@ -883,7 +883,7 @@ async def test_inproc_many_listeners(): async def check_listener_deserialize(addr, deserialize, in_value, check_out): - q = queues.Queue() + q = asyncio.Queue() async def handle_comm(comm): msg = await comm.read() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 741fc76a8dd..b521826647a 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -33,7 +33,7 @@ import dask from tlz import merge, memoize, assoc -from tornado import gen, queues +from tornado import gen from tornado.ioloop import IOLoop from . import system @@ -429,7 +429,7 @@ async def readone(comm): try: q = _readone_queues[comm] except KeyError: - q = _readone_queues[comm] = queues.Queue() + q = _readone_queues[comm] = asyncio.Queue() async def background_read(): while True: From 51428dc5f6f4181157e819d628d5f8dd43de3217 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Sat, 21 Mar 2020 01:02:51 +0000 Subject: [PATCH 0732/1550] Pin openssl to 1.1.1d for Travis (#3602) --- continuous_integration/travis/install.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 4ee0790f6c5..09d13962bd4 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -79,6 +79,9 @@ if [[ $CRICK == true ]]; then python -m pip install -q git+https://github.com/jcrist/crick.git fi; +# Pin openssl==1.1.1d (see https://github.com/dask/distributed/issues/3588) +conda install -c conda-forge openssl==1.1.1d + # Install distributed python -m pip install --no-deps -e . 
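The tornado→asyncio queue swap a couple of patches above hinges on one behavioural difference: `asyncio.Queue.put` is a coroutine, so a comm handler that enqueues connections must itself be `async` and must `await` the put. A minimal, self-contained sketch of that pattern — the plain string stands in for a real comm, and the direct `handle_comm` call stands in for a listener invoking it:

```
import asyncio

async def main():
    q = asyncio.Queue()

    async def handle_comm(comm):
        # With tornado.queues this was a plain ``q.put(comm)``; asyncio requires an await
        await q.put(comm)

    await handle_comm("fake-comm")   # a listener would call this on each new connection
    comm = await q.get()             # consumer side: wait for the first connection
    print(comm)

asyncio.run(main())
```
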
From 3f03e1cf851ba1950cff6de8e17e7749cc3569d3 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Sat, 21 Mar 2020 02:55:06 +0000 Subject: [PATCH 0733/1550] Increase number of visible mantissas in dashboard plots (#3585) --- distributed/dashboard/components/nvml.py | 4 ++-- distributed/dashboard/components/scheduler.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/distributed/dashboard/components/nvml.py b/distributed/dashboard/components/nvml.py index b0c56c4ef47..00346e81e66 100644 --- a/distributed/dashboard/components/nvml.py +++ b/distributed/dashboard/components/nvml.py @@ -14,7 +14,7 @@ from tornado import escape from dask.utils import format_bytes from distributed.utils import log_errors -from distributed.dashboard.components.scheduler import BOKEH_THEME +from distributed.dashboard.components.scheduler import BOKEH_THEME, TICKS_1024 from distributed.dashboard.utils import without_property_validation, update @@ -83,7 +83,7 @@ def __init__(self, scheduler, width=600, **kwargs): ) rect.nonselection_glyph = None - memory.axis[0].ticker = BasicTicker(mantissas=[1, 256, 512], base=1024) + memory.axis[0].ticker = BasicTicker(**TICKS_1024) memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") memory.xaxis.major_label_orientation = -math.pi / 12 memory.x_range.start = 0 diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index c371210c701..9519d3629ff 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -78,6 +78,7 @@ ) BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "..", "theme.yaml")) +TICKS_1024 = {"base": 1024, "mantissas": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]} nan = float("nan") inf = float("inf") @@ -233,7 +234,7 @@ def __init__(self, scheduler, **kwargs): ) self.root.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") - self.root.xaxis.ticker = AdaptiveTicker(mantissas=[1, 256, 512], base=1024) + self.root.xaxis.ticker = AdaptiveTicker(**TICKS_1024) self.root.xaxis.major_label_orientation = -math.pi / 12 self.root.xaxis.minor_tick_line_alpha = 0 @@ -296,7 +297,7 @@ def __init__(self, scheduler, **kwargs): ) fig.x_range.start = 0 fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") - fig.xaxis.ticker = AdaptiveTicker(mantissas=[1, 256, 512], base=1024) + fig.xaxis.ticker = AdaptiveTicker(**TICKS_1024) rect.nonselection_glyph = None fig.xaxis.minor_tick_line_alpha = 0 @@ -379,9 +380,7 @@ def __init__(self, scheduler, **kwargs): location=(0, 0), ) color_bar.formatter = NumeralTickFormatter(format="0.0 b") - color_bar.ticker = AdaptiveTicker( - mantissas=[1, 64, 128, 256, 512], base=1024 - ) + color_bar.ticker = AdaptiveTicker(**TICKS_1024) fig.add_layout(color_bar, "right") fig.toolbar.logo = None @@ -464,7 +463,7 @@ def __init__(self, scheduler, **kwargs): source=self.source, x="name", top="nbytes", width=0.9, color="color" ) fig.yaxis[0].formatter = NumeralTickFormatter(format="0.0 b") - fig.yaxis.ticker = AdaptiveTicker(mantissas=[1, 256, 512], base=1024) + fig.yaxis.ticker = AdaptiveTicker(**TICKS_1024) fig.xaxis.major_label_orientation = -math.pi / 12 rect.nonselection_glyph = None @@ -593,7 +592,7 @@ def __init__(self, scheduler, width=600, **kwargs): ) rect.nonselection_glyph = None - nbytes.axis[0].ticker = BasicTicker(mantissas=[1, 256, 512], base=1024) + nbytes.axis[0].ticker = BasicTicker(**TICKS_1024) nbytes.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") 
nbytes.xaxis.major_label_orientation = -math.pi / 12 nbytes.x_range.start = 0 From 2a795df67d3521dbfbf49a650238bab7266a86ee Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 21 Mar 2020 13:09:07 -0700 Subject: [PATCH 0734/1550] Make Listeners awaitable (#3611) --- distributed/comm/core.py | 7 ++++++ distributed/comm/tests/test_comms.py | 34 ++++++++++++---------------- distributed/comm/tests/test_ucx.py | 3 +-- distributed/core.py | 3 +-- distributed/tests/test_batched.py | 3 +-- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index e801242bb40..b4c93644a2c 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -162,6 +162,13 @@ async def __aenter__(self): async def __aexit__(self, *exc): self.stop() + def __await__(self): + async def _(): + await self.start() + return self + + return _().__await__() + class Connector(ABC): @abstractmethod diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 16036184755..2e5602a9ac5 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -83,8 +83,9 @@ async def get_comm_pair(listen_addr, listen_args=None, connect_args=None, **kwar async def handle_comm(comm): await q.put(comm) - listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) - await listener.start() + listener = await listen( + listen_addr, handle_comm, connection_args=listen_args, **kwargs + ) comm = await connect( listener.contact_address, connection_args=connect_args, **kwargs @@ -221,8 +222,7 @@ async def handle_comm(comm): await comm.write(msg) await comm.close() - listener = tcp.TCPListener("localhost", handle_comm) - await listener.start() + listener = await tcp.TCPListener("localhost", handle_comm) host, port = listener.get_host_port() assert host in ("localhost", "127.0.0.1", "::1") assert port > 0 @@ -269,8 +269,7 @@ async def handle_comm(comm): server_ctx = get_server_ssl_context() client_ctx = get_client_ssl_context() - listener = tcp.TLSListener("localhost", handle_comm, ssl_context=server_ctx) - await listener.start() + listener = await tcp.TLSListener("localhost", handle_comm, ssl_context=server_ctx) host, port = listener.get_host_port() assert host in ("localhost", "127.0.0.1", "::1") assert port > 0 @@ -361,8 +360,7 @@ async def handle_comm(comm): await comm.write(msg) await comm.close() - listener = inproc.InProcListener(listener_addr, handle_comm) - await listener.start() + listener = await inproc.InProcListener(listener_addr, handle_comm) assert ( listener.listen_address == listener.contact_address @@ -468,8 +466,7 @@ async def handle_comm(comm): listen_args = listen_args or {"xxx": "bar"} connect_args = connect_args or {"xxx": "foo"} - listener = listen(addr, handle_comm, connection_args=listen_args) - await listener.start() + listener = await listen(addr, handle_comm, connection_args=listen_args) # Check listener properties bound_addr = listener.listen_address @@ -647,8 +644,9 @@ async def handle_comm(comm): await comm.close() # Listener refuses a connector not signed by the CA - listener = listen("tls://", handle_comm, connection_args={"ssl_context": serv_ctx}) - await listener.start() + listener = await listen( + "tls://", handle_comm, connection_args={"ssl_context": serv_ctx} + ) with pytest.raises(EnvironmentError) as excinfo: comm = await connect( @@ -678,10 +676,9 @@ async def handle_comm(comm): await comm.close() # Connector refuses a listener not signed by the CA - 
listener = listen( + listener = await listen( "tls://", handle_comm, connection_args={"ssl_context": bad_serv_ctx} ) - await listener.start() with pytest.raises(EnvironmentError) as excinfo: await connect( @@ -705,8 +702,7 @@ async def check_comm_closed_implicit( async def handle_comm(comm): await comm.close() - listener = listen(addr, handle_comm, connection_args=listen_args) - await listener.start() + listener = await listen(addr, handle_comm, connection_args=listen_args) contact_addr = listener.contact_address comm = await connect(contact_addr, connection_args=connect_args) @@ -785,8 +781,7 @@ async def handle_comm(comm): else: await comm.close() - listener = listen("inproc://", handle_comm) - await listener.start() + listener = await listen("inproc://", handle_comm) contact_addr = listener.contact_address comm = await connect(contact_addr) @@ -854,8 +849,7 @@ async def handle_comm(comm): N = 100 for i in range(N): - listener = listen(addr, handle_comm) - await listener.start() + listener = await listen(addr, handle_comm) listeners.append(listener) assert len(set(l.listen_address for l in listeners)) == N diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index ead799f8158..84da6e4f1aa 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -96,8 +96,7 @@ async def handle_comm(comm): await comm.close() assert comm.closed - listener = ucx.UCXListener(address, handle_comm) - await listener.start() + listener = await ucx.UCXListener(address, handle_comm) host, port = listener.get_host_port() assert host.count(".") == 3 assert port > 0 diff --git a/distributed/core.py b/distributed/core.py index ec1e6c5214c..dd40fa7a4d0 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -312,13 +312,12 @@ async def listen(self, port_or_addr=None, listen_args=None): else: addr = port_or_addr assert isinstance(addr, str) - listener = listen( + listener = await listen( addr, self.handle_comm, deserialize=self.deserialize, connection_args=listen_args, ) - await listener.start() self.listeners.append(listener) async def handle_comm(self, comm, shutting_down=shutting_down): diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index f2b0be99ab0..a288a25bbb9 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -25,8 +25,7 @@ async def handle_comm(self, comm): return async def listen(self): - listener = listen("", self.handle_comm) - await listener.start() + listener = await listen("", self.handle_comm) self.address = listener.contact_address self.stop = listener.stop From 7deb3b059f9ed335e2c405d0fd801c806215f9f9 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 22 Mar 2020 20:46:05 +0100 Subject: [PATCH 0735/1550] Add backoff to comm connect attempts. 
(#3496) Closes #3487 --- distributed/comm/core.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index b4c93644a2c..26ecdd1c54f 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -209,6 +209,10 @@ def _raise(error): ) raise IOError(msg) + backoff = 0.01 + if timeout and timeout / 20 < backoff: + backoff = timeout / 20 + # This starts a thread while True: try: @@ -228,8 +232,10 @@ def _raise(error): except EnvironmentError as e: error = str(e) if time() < deadline: - await asyncio.sleep(0.01) - logger.debug("sleeping on connect") + logger.debug("Could not connect, waiting before retrying") + await asyncio.sleep(backoff) + backoff *= 1.5 + backoff = min(backoff, 1) # wait at most one second else: _raise(error) else: From e1f871a3a807ef4c86f6cb8909c41cf7ec9d429e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 23 Mar 2020 12:44:43 -0700 Subject: [PATCH 0736/1550] Add str/repr methods to as_completed (#3618) --- distributed/client.py | 5 +++++ distributed/tests/test_as_completed.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index 06c6d245c07..24f25a79a09 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4281,6 +4281,11 @@ def count(self): with self.lock: return len(self.futures) + len(self.queue.queue) + def __repr__(self): + return "".format( + len(self.futures), len(self.queue.queue) + ) + def __iter__(self): return self diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index d0249b121d6..c9780d196cd 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import Iterator from operator import add import queue @@ -8,6 +9,7 @@ from tornado import gen from distributed.client import _as_completed, as_completed, _first_completed +from distributed.metrics import time from distributed.utils import CancelledError from distributed.utils_test import gen_cluster, inc, throws from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -232,6 +234,23 @@ def test_as_completed_with_results_no_raise(client): assert dd[z][0] == 2 +@gen_cluster(client=True) +async def test_str(c, s, a, b): + futures = c.map(inc, range(3)) + ac = as_completed(futures) + assert "waiting=3" in str(ac) + assert "waiting=3" in repr(ac) + assert "done=0" in str(ac) + assert "done=0" in repr(ac) + + await ac.__anext__() + + start = time() + while "done=2" not in str(ac): + await asyncio.sleep(0.01) + assert time() < start + 2 + + @gen_cluster(client=True) def test_as_completed_with_results_no_raise_async(c, s, a, b): x = c.submit(throws, 1) From e22f2fbf2d3fbe4a79c218d2b487bb0fc4ebc12f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 23 Mar 2020 13:26:36 -0700 Subject: [PATCH 0737/1550] Support async Listener.stop functions (#3613) --- distributed/comm/core.py | 5 ++++- distributed/core.py | 8 +++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 26ecdd1c54f..6ef26568853 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod, abstractproperty import asyncio +import inspect import logging import weakref @@ -160,7 +161,9 @@ async def __aenter__(self): return self async def __aexit__(self, *exc): - self.stop() + future = self.stop() + 
if inspect.isawaitable(future): + await future def __await__(self): async def _(): diff --git a/distributed/core.py b/distributed/core.py index dd40fa7a4d0..5bff3276e99 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1,7 +1,7 @@ import asyncio from collections import defaultdict, deque from functools import partial -from inspect import isawaitable +import inspect import logging import threading import traceback @@ -405,7 +405,7 @@ async def handle_comm(self, comm, shutting_down=shutting_down): logger.debug("Calling into handler %s", handler.__name__) try: result = handler(comm, **msg) - if isawaitable(result): + if inspect.isawaitable(result): result = asyncio.ensure_future(result) self._ongoing_coroutines.add(result) result = await result @@ -495,7 +495,9 @@ def close(self): for pc in self.periodic_callbacks.values(): pc.stop() for listener in self.listeners: - self.listener.stop() + future = self.listener.stop() + if inspect.isawaitable(future): + yield future for i in range(20): # let comms close naturally for a second if not self._comms: break From 4d4d935f46619ad9d96d64190274201ebe894eec Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 23 Mar 2020 21:46:05 +0100 Subject: [PATCH 0738/1550] Ensure that we don't steal blacklisted fast tasks (#3591) --- distributed/stealing.py | 4 +- distributed/tests/test_steal.py | 69 +++++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index fcceba4824a..38524f1722c 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -84,7 +84,7 @@ def transition( if start == "processing": self.remove_key_from_stealable(ts) if finish == "memory": - for tts in self.stealable_unknown_durations.pop(ts.prefix, ()): + for tts in self.stealable_unknown_durations.pop(ts.prefix.name, ()): if tts not in self.in_flight and tts.state == "processing": self.put_key_in_stealable(tts) else: @@ -132,7 +132,7 @@ def steal_time_ratio(self, ts): nbytes = sum(dep.get_nbytes() for dep in ts.dependencies) transfer_time = nbytes / self.scheduler.bandwidth + LATENCY - split = ts.prefix + split = ts.prefix.name if split in fast_tasks: return None, None ws = ts.processing_on diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 5b13d9157e8..0ed9051cc95 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1,32 +1,29 @@ import itertools -from operator import mul import random import sys -from time import sleep import weakref - -import pytest -from tlz import sliding_window, concat -from tornado import gen +from operator import mul +from time import sleep import dask +import pytest from distributed import Nanny, Worker, wait, worker_client from distributed.config import config from distributed.metrics import time from distributed.scheduler import key_split from distributed.system import MEMORY_LIMIT from distributed.utils_test import ( - slowinc, - slowadd, - inc, + captured_logger, gen_cluster, + inc, + nodebug_setup_module, + nodebug_teardown_module, + slowadd, slowidentity, - captured_logger, + slowinc, ) -from distributed.utils_test import nodebug_setup_module, nodebug_teardown_module - -import pytest - +from tlz import concat, sliding_window +from tornado import gen # Most tests here are timing-dependent setup_module = nodebug_setup_module @@ -145,23 +142,61 @@ def test_steal_related_tasks(e, s, a, b, c): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, timeout=1000) -def 
test_dont_steal_fast_tasks(c, s, *workers): +async def test_dont_steal_fast_tasks_compute_time(c, s, *workers): np = pytest.importorskip("numpy") x = c.submit(np.random.random, 10000000, workers=workers[0].address) def do_nothing(x, y=None): pass - yield wait(c.submit(do_nothing, 1)) + # execute and meassure runtime once + await wait(c.submit(do_nothing, 1)) futures = c.map(do_nothing, range(1000), y=x) - yield wait(futures) + await wait(futures) assert len(s.who_has[x.key]) == 1 assert len(s.has_what[workers[0].address]) == 1001 +@gen_cluster(client=True) +async def test_dont_steal_fast_tasks_blacklist(c, s, a, b): + # create a dependency + x = c.submit(slowinc, 1, workers=[b.address]) + + # If the blacklist of fast tasks is tracked somewhere else, this needs to be + # changed. This test requies *any* key which is blacklisted. + from distributed.stealing import fast_tasks + + blacklisted_key = next(iter(fast_tasks)) + + def fast_blacklisted(x, y=None): + # The task should observe a certain computation time such that we can + # ensure that it is not stolen due to the blacklisting. If it is too + # fast, the standard mechansim shouldn't allow stealing + import time + + time.sleep(0.01) + + futures = c.map( + fast_blacklisted, + range(100), + y=x, + # Submit the task to one worker but allow it to be distributed else, + # i.e. this is not a task restriction + workers=[a.address], + allow_other_workers=True, + key=blacklisted_key, + ) + + await wait(futures) + + # The +1 is the dependency we initially submitted to worker B + assert len(s.has_what[a.address]) == 101 + assert len(s.has_what[b.address]) == 1 + + @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)], timeout=20) def test_new_worker_steals(c, s, a): yield wait(c.submit(slowinc, 1, delay=0.01)) From 706de86a255671eecbff5b03f0b4b8db456b1823 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 24 Mar 2020 09:59:33 -0700 Subject: [PATCH 0739/1550] Check `nbytes` and `types` before reading `data` (#3628) * Check `nbytes` and `types` before reading `data` To avoid reading `data` when it is not needed, try checking `nbytes` and `types` beforehand. If the metadata is already there, continue on without reading `data`. Otherwise fallback to the reading `data`, but do make sure to cache the results of that read to avoid doing it again. * Use `.get(...)` with `self.nbytes` as well * Restart GitHub CI Appears GitHub CI failed to checkout the code and clicking restart in the UI does not work. So pushing a dummy commit to restart it. 
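The change described above boils down to a read-through cache: consult the `nbytes` and `types` bookkeeping first, and only touch `data` (which may be expensive to read, e.g. when values are spilled to disk) on a miss, caching what was learned. A rough, self-contained sketch of that pattern — the dictionaries and the `task_metadata` helper are hypothetical stand-ins for the worker's attributes, and `sys.getsizeof` stands in for dask's `sizeof`:

```
import sys

data = {"x": [1, 2, 3]}   # potentially expensive to read
nbytes = {}               # cached sizes
types = {}                # cached types

def task_metadata(key):
    """Return (size, type) for ``key``, reading ``data`` only on a cache miss."""
    size = nbytes.get(key)
    typ = types.get(key)
    if size is None or typ is None:
        value = data[key]                          # fall back to the value itself...
        size = nbytes[key] = sys.getsizeof(value)  # ...and cache what we learned
        typ = types[key] = type(value)
        del value
    return size, typ

print(task_metadata("x"))   # reads ``data`` once
print(task_metadata("x"))   # served entirely from the metadata caches
```
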
--- distributed/worker.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 247ffc99510..191e4df085f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1836,13 +1836,16 @@ def ensure_communicating(self): def send_task_state_to_scheduler(self, key): if key in self.data or self.actors.get(key): - try: - value = self.data[key] - except KeyError: - value = self.actors[key] - nbytes = self.nbytes[key] or sizeof(value) - typ = self.types.get(key) or type(value) - del value + nbytes = self.nbytes.get(key) + typ = self.types.get(key) + if nbytes is None or typ is None: + try: + value = self.data[key] + except KeyError: + value = self.actors[key] + nbytes = self.nbytes[key] = sizeof(value) + typ = self.types[key] = type(value) + del value try: typ_serialized = dumps_function(typ) except PicklingError: From 0f834e7984ebcdf1333bae41dfd80fa56fd51023 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 24 Mar 2020 18:54:05 +0100 Subject: [PATCH 0740/1550] Remove dead stealing code (#3619) --- distributed/stealing.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index 38524f1722c..0d552d1689f 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -29,8 +29,6 @@ def __init__(self, scheduler): self.stealable = dict() # { task state: (worker, level) } self.key_stealable = dict() - # { prefix: { task states } } - self.stealable_unknown_durations = defaultdict(set) self.cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)] self.cost_multipliers[0] = 1 @@ -83,11 +81,7 @@ def transition( if start == "processing": self.remove_key_from_stealable(ts) - if finish == "memory": - for tts in self.stealable_unknown_durations.pop(ts.prefix.name, ()): - if tts not in self.in_flight and tts.state == "processing": - self.put_key_in_stealable(tts) - else: + if finish != "memory": self.in_flight.pop(ts, None) def put_key_in_stealable(self, ts): @@ -136,20 +130,16 @@ def steal_time_ratio(self, ts): if split in fast_tasks: return None, None ws = ts.processing_on - if ws is None: - self.stealable_unknown_durations[split].add(ts) + compute_time = ws.processing[ts] + if compute_time < 0.005: # 5ms, just give up + return None, None + cost_multiplier = transfer_time / compute_time + if cost_multiplier > 100: return None, None - else: - compute_time = ws.processing[ts] - if compute_time < 0.005: # 5ms, just give up - return None, None - cost_multiplier = transfer_time / compute_time - if cost_multiplier > 100: - return None, None - level = int(round(log(cost_multiplier) / log_2 + 6, 0)) - level = max(1, level) - return cost_multiplier, level + level = int(round(log(cost_multiplier) / log_2 + 6, 0)) + level = max(1, level) + return cost_multiplier, level def move_task_request(self, ts, victim, thief): try: @@ -418,7 +408,6 @@ def restart(self, scheduler): for s in self.stealable_all: s.clear() self.key_stealable.clear() - self.stealable_unknown_durations.clear() def story(self, *keys): keys = set(keys) From dd28d08ca22f7ae874ba10e524ed322a15b4cacd Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 24 Mar 2020 17:55:45 -0500 Subject: [PATCH 0741/1550] Ensure Client connection pool semaphore attaches to the Client's event loop (#3546) * Add Node and ConnectionPool start methods * Make ConnectionPools awaitable --- distributed/client.py | 3 +++ distributed/core.py | 13 +++++++++++-- distributed/nanny.py | 3 
+++ distributed/node.py | 7 ++++++- distributed/scheduler.py | 3 +++ distributed/tests/test_client.py | 5 +++++ distributed/tests/test_core.py | 10 +++++----- distributed/tests/test_scheduler.py | 6 +++--- distributed/tests/test_utils_comm.py | 16 ++++++++-------- distributed/worker.py | 2 ++ 10 files changed, 49 insertions(+), 19 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 24f25a79a09..4065aad17e9 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -923,6 +923,9 @@ def _send_to_scheduler(self, msg): ) async def _start(self, timeout=no_default, **kwargs): + + await super().start() + if timeout == no_default: timeout = self._timeout if timeout is not None: diff --git a/distributed/core.py b/distributed/core.py index 5bff3276e99..1bf3b172b68 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -835,8 +835,6 @@ def __init__( self.connection_args = connection_args self.timeout = timeout self._n_connecting = 0 - # Invariant: semaphore._value == limit - open - _n_connecting - self.semaphore = asyncio.Semaphore(self.limit) self.server = weakref.ref(server) if server else None self._created = weakref.WeakSet() self._instances.add(self) @@ -871,6 +869,17 @@ def __call__(self, addr=None, ip=None, port=None): addr, self, serializers=self.serializers, deserializers=self.deserializers ) + def __await__(self): + async def _(): + await self.start() + return self + + return _().__await__() + + async def start(self): + # Invariant: semaphore._value == limit - open - _n_connecting + self.semaphore = asyncio.Semaphore(self.limit) + async def connect(self, addr, timeout=None): """ Get a Comm to the given address. For internal use. diff --git a/distributed/nanny.py b/distributed/nanny.py index ec5397efb93..baa77e3ce10 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -241,6 +241,9 @@ def local_dir(self): async def start(self): """ Start nanny, start local process, start watching """ + + await super().start() + await self.listen(self._start_address, listen_args=self.listen_args) self.ip = get_address_host(self.address) diff --git a/distributed/node.py b/distributed/node.py index 4e26defeb08..af15b5a409f 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -38,6 +38,9 @@ def __init__( server=self, ) + async def start(self): + await self.rpc.start() + class ServerNode(Node, Server): """ @@ -182,5 +185,7 @@ async def wait_for(future, timeout=None): future = wait_for(future, timeout=timeout) return future.__await__() - async def start(self): # subclasses should implement this + async def start(self): + # subclasses should implement their own start method whichs calls super().start() + await Node.start(self) return self diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8a61ba31fca..cea1f9fd136 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1408,6 +1408,9 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): async def start(self): """ Clear out old state and restart all running coroutines """ + + await super().start() + enable_gc_diagnosis() self.clear_task_state() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5c853916558..d9633876cb2 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5943,3 +5943,8 @@ def test_as_completed_condition_loop(c, s, a, b): seq = c.map(inc, range(5)) ac = as_completed(seq) assert ac.condition._loop == c.loop.asyncio_loop + + +def 
test_client_connectionpool_semaphore_loop(s, a, b): + with Client(s["address"]) as c: + assert c.rpc.semaphore._loop is c.loop.asyncio_loop diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 0a9c48bc870..76f2b285500 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -526,7 +526,7 @@ async def ping(comm, delay=0.1): for server in servers: await server.listen(0) - rpc = ConnectionPool(limit=5) + rpc = await ConnectionPool(limit=5) # Reuse connections await asyncio.gather( @@ -583,7 +583,7 @@ async def do_ping(pool, port): for server in servers: await server.listen(0) - pool = ConnectionPool(limit=limit) + pool = await ConnectionPool(limit=limit) await asyncio.gather(*[do_ping(pool, s.port) for s in servers]) @@ -605,7 +605,7 @@ async def ping(comm, delay=0.01): for server in servers: await server.listen("tls://", listen_args=listen_args) - rpc = ConnectionPool(limit=5, connection_args=connection_args) + rpc = await ConnectionPool(limit=5, connection_args=connection_args) await asyncio.gather(*[rpc(s.address).ping() for s in servers[:5]]) await asyncio.gather(*[rpc(s.address).ping() for s in servers[::2]]) @@ -625,7 +625,7 @@ async def ping(comm, delay=0.01): for server in servers: await server.listen(0) - rpc = ConnectionPool(limit=10) + rpc = await ConnectionPool(limit=10) serv = servers.pop() await asyncio.gather(*[rpc(s.address).ping() for s in servers]) await asyncio.gather(*[rpc(serv.address).ping() for i in range(3)]) @@ -758,7 +758,7 @@ async def test_connection_pool_detects_remote_close(): await server.listen("tcp://") # open a connection, use it and give it back to the pool - p = ConnectionPool(limit=10) + p = await ConnectionPool(limit=10) conn = await p.connect(server.address) await send_recv(conn, op="ping") p.reuse(server.address, conn) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 5459716ca85..24a40dccdab 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1887,7 +1887,7 @@ async def test_gather_failing_cnn_recover(c, s, a, b): orig_rpc = s.rpc x = await c.scatter({"x": 1}, workers=a.address) - s.rpc = FlakyConnectionPool(failing_connections=1) + s.rpc = await FlakyConnectionPool(failing_connections=1) with mock.patch("distributed.utils_comm.retry_count", 1): res = await s.gather(keys=["x"]) assert res["status"] == "OK" @@ -1898,7 +1898,7 @@ async def test_gather_failing_cnn_error(c, s, a, b): orig_rpc = s.rpc x = await c.scatter({"x": 1}, workers=a.address) - s.rpc = FlakyConnectionPool(failing_connections=10) + s.rpc = await FlakyConnectionPool(failing_connections=10) res = await s.gather(keys=["x"]) assert res["status"] == "error" assert list(res["keys"]) == ["x"] @@ -1949,7 +1949,7 @@ def reducer(x, y): z = c.submit(reducer, x, y) - s.rpc = FlakyConnectionPool(failing_connections=4) + s.rpc = await FlakyConnectionPool(failing_connections=4) with captured_logger( logging.getLogger("distributed.scheduler") diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 2d0159a2d3d..7ab793e18e4 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -30,11 +30,11 @@ def test_subs_multiple(): @gen_cluster(client=True) -def test_gather_from_workers_permissive(c, s, a, b): - rpc = ConnectionPool() - x = yield c.scatter({"x": 1}, workers=a.address) +async def test_gather_from_workers_permissive(c, s, a, b): + rpc = await ConnectionPool() + x = await 
c.scatter({"x": 1}, workers=a.address) - data, missing, bad_workers = yield gather_from_workers( + data, missing, bad_workers = await gather_from_workers( {"x": [a.address], "y": [b.address]}, rpc=rpc ) @@ -68,11 +68,11 @@ async def connect(self, *args, **kwargs): @gen_cluster(client=True) -def test_gather_from_workers_permissive_flaky(c, s, a, b): - x = yield c.scatter({"x": 1}, workers=a.address) +async def test_gather_from_workers_permissive_flaky(c, s, a, b): + x = await c.scatter({"x": 1}, workers=a.address) - rpc = BrokenConnectionPool() - data, missing, bad_workers = yield gather_from_workers({"x": [a.address]}, rpc=rpc) + rpc = await BrokenConnectionPool() + data, missing, bad_workers = await gather_from_workers({"x": [a.address]}, rpc=rpc) assert missing == {"x": [a.address]} assert bad_workers == [a.address] diff --git a/distributed/worker.py b/distributed/worker.py index 191e4df085f..ba25c91d979 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1012,6 +1012,8 @@ async def start(self): return assert self.status is None, self.status + await super().start() + enable_gc_diagnosis() thread_state.on_event_loop_thread = True From 13419fb5631df548d9fc977a548b128a8751060f Mon Sep 17 00:00:00 2001 From: jakirkham Date: Wed, 25 Mar 2020 06:47:42 -0700 Subject: [PATCH 0742/1550] WIP: Include frame lengths of CUDA objects in `header["lengths"]` (#3631) * Reuse "cuda" serialization in "dask" serialization Make sure that RMM `DeviceBuffer`s use "cuda" serialization before they are passed through "dask" serialization. * Include `"lengths"` in CUDA `header`s As going through `"dask"` serialization can result in data being split for better compression, ensure the original number of bytes in the frame is stored in `header["lengths"]`. That way on `"dask"` deserialization the frames can be merged back into their original sizes before `"cuda"` serialization is performed. 
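The reason `header["lengths"]` matters, as noted above, is that "dask" serialization may split a logical frame into smaller pieces (for example for compression); recording the original byte counts lets the receiver stitch the pieces back together before handing them to "cuda" deserialization. A toy, host-memory-only illustration of that merge step (not distributed's actual `merge_frames` logic):

```
header = {"lengths": [10]}           # one logical frame of 10 bytes
received = [b"01234", b"56789"]      # it arrived as two pieces after splitting

merged = []
pieces = iter(received)
for length in header["lengths"]:
    buf = b""
    while len(buf) < length:         # consume pieces until the original size is restored
        buf += next(pieces)
    merged.append(buf)

assert merged == [b"0123456789"]
print(merged)
```
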
--- distributed/protocol/cupy.py | 1 + distributed/protocol/numba.py | 1 + distributed/protocol/rmm.py | 5 +++-- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 3d074266245..b3465fee424 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -47,6 +47,7 @@ def cuda_serialize_cupy_ndarray(x): header = x.__cuda_array_interface__.copy() header["strides"] = tuple(x.strides) + header["lengths"] = [x.nbytes] frames = [ cupy.ndarray( shape=(x.nbytes,), dtype=cupy.dtype("u1"), memptr=x.data, strides=(1,) diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py index 1070c080e61..03bf4aa9f16 100644 --- a/distributed/protocol/numba.py +++ b/distributed/protocol/numba.py @@ -21,6 +21,7 @@ def cuda_serialize_numba_ndarray(x): header = x.__cuda_array_interface__.copy() header["strides"] = tuple(x.strides) + header["lengths"] = [x.nbytes] frames = [ numba.cuda.cudadrv.devicearray.DeviceNDArray( shape=(x.nbytes,), strides=(1,), dtype=np.dtype("u1"), gpu_data=x.gpu_data, diff --git a/distributed/protocol/rmm.py b/distributed/protocol/rmm.py index ae2db0d528b..76706d49d89 100644 --- a/distributed/protocol/rmm.py +++ b/distributed/protocol/rmm.py @@ -12,6 +12,7 @@ @cuda_serialize.register(rmm.DeviceBuffer) def cuda_serialize_rmm_device_buffer(x): header = x.__cuda_array_interface__.copy() + header["lengths"] = [x.nbytes] frames = [x] return header, frames @@ -28,8 +29,8 @@ def cuda_deserialize_rmm_device_buffer(header, frames): @dask_serialize.register(rmm.DeviceBuffer) def dask_serialize_rmm_device_buffer(x): - header = x.__cuda_array_interface__.copy() - frames = [numba.cuda.as_cuda_array(x).copy_to_host().data] + header, frames = cuda_serialize_rmm_device_buffer(x) + frames = [numba.cuda.as_cuda_array(f).copy_to_host().data for f in frames] return header, frames @dask_deserialize.register(rmm.DeviceBuffer) From b0c000883eb10d4801b967fe348c4d6281ca3f1d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 25 Mar 2020 06:50:55 -0700 Subject: [PATCH 0743/1550] Add logging message when closing idle dask scheduler (#3632) ``` $ dask-scheduler --idle-timeout "5 seconds" distributed.scheduler - INFO - ----------------------------------------------- distributed.scheduler - INFO - Local Directory: /tmp/scheduler-niju4kje distributed.scheduler - INFO - ----------------------------------------------- distributed.scheduler - INFO - Clear task state distributed.scheduler - INFO - Scheduler at: tcp://192.168.0.11:8786 distributed.scheduler - INFO - dashboard at: :8787 distributed.scheduler - INFO - Scheduler closing after being idle for 5.00 s distributed.scheduler - INFO - Scheduler closing... 
distributed.scheduler - INFO - Scheduler closing all comms distributed.scheduler - INFO - End scheduler at 'tcp://192.168.0.11:8786' ``` --- distributed/scheduler.py | 4 ++++ distributed/tests/test_scheduler.py | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cea1f9fd136..f99c26d9aba 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5207,6 +5207,10 @@ def check_idle(self): close = time() > last_task + self.idle_timeout if close: + logger.info( + "Scheduler closing after being idle for %s", + format_time(self.idle_timeout), + ) self.loop.add_callback(self.close) def adaptive_target(self, comm=None, target_duration=None): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 24a40dccdab..1068169b200 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1542,15 +1542,20 @@ async def test_idle_timeout(c, s, a, b): assert s.status != "closed" - start = time() - while s.status != "closed": - await gen.sleep(0.01) - assert time() < start + 3 + with captured_logger("distributed.scheduler") as logs: + start = time() + while s.status != "closed": + await gen.sleep(0.01) + assert time() < start + 3 - start = time() - while not (a.status == "closed" and b.status == "closed"): - await gen.sleep(0.01) - assert time() < start + 1 + start = time() + while not (a.status == "closed" and b.status == "closed"): + await gen.sleep(0.01) + assert time() < start + 1 + + assert "idle" in logs.getvalue() + assert "500" in logs.getvalue() + assert "ms" in logs.getvalue() @gen_cluster(client=True, config={"distributed.scheduler.bandwidth": "100 GB"}) From fbdb067ed514024aad14dbdc3790bae9344548df Mon Sep 17 00:00:00 2001 From: jakirkham Date: Wed, 25 Mar 2020 07:04:06 -0700 Subject: [PATCH 0744/1550] Drop unused line from `pack_frames_prelude` (#3634) --- distributed/protocol/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 3af203881ff..e58732b881c 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -92,7 +92,6 @@ def merge_frames(header, frames): def pack_frames_prelude(frames): - lengths = [len(f) for f in frames] lengths = [struct.pack("Q", len(frames))] + [ struct.pack("Q", nbytes(frame)) for frame in frames ] From 77d103f10971bea68fa733e53c00d9ea2a6a431a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 25 Mar 2020 08:14:57 -0700 Subject: [PATCH 0745/1550] Add as_completed.clear method (#3617) --- distributed/client.py | 30 ++++++++++++++++---------- distributed/tests/test_as_completed.py | 16 +++++++++++++- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 4065aad17e9..c45e0718b43 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4236,17 +4236,18 @@ async def _track_future(self, future): except CancelledError as exc: result = exc with self.lock: - self.futures[future] -= 1 - if not self.futures[future]: - del self.futures[future] - if self.with_results: - self.queue.put_nowait((future, result)) - else: - self.queue.put_nowait(future) - async with self.condition: - self.condition.notify() - with self.thread_condition: - self.thread_condition.notify() + if future in self.futures: + self.futures[future] -= 1 + if not self.futures[future]: + del self.futures[future] + if self.with_results: + self.queue.put_nowait((future, 
result)) + else: + self.queue.put_nowait(future) + async with self.condition: + self.condition.notify() + with self.thread_condition: + self.thread_condition.notify() def update(self, futures): """ Add multiple futures to the collection. @@ -4380,6 +4381,13 @@ def batches(self): except StopIteration: return + def clear(self): + """ Clear out all submitted futures """ + with self.lock: + self.futures.clear() + while not self.queue.empty(): + self.queue.get() + def AsCompleted(*args, **kwargs): raise Exception("This has moved to as_completed") diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index c9780d196cd..f71c6f7492e 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -8,7 +8,7 @@ import pytest from tornado import gen -from distributed.client import _as_completed, as_completed, _first_completed +from distributed.client import _as_completed, as_completed, _first_completed, wait from distributed.metrics import time from distributed.utils import CancelledError from distributed.utils_test import gen_cluster, inc, throws @@ -273,3 +273,17 @@ def test_as_completed_with_results_no_raise_async(c, s, a, b): assert isinstance(dd[y][0], CancelledError) assert isinstance(dd[x][0][1], RuntimeError) assert dd[z][0] == 2 + + +@gen_cluster(client=True, timeout=None) +async def test_clear(c, s, a, b): + futures = c.map(inc, range(3)) + ac = as_completed(futures) + await wait(futures) + ac.clear() + with pytest.raises(StopAsyncIteration): + await ac.__anext__() + del futures + + while s.tasks: + await asyncio.sleep(0.3) From 2277379f6249ecbb132d6b3872550d7ee7665ef0 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 25 Mar 2020 19:38:34 +0100 Subject: [PATCH 0746/1550] UCX synchronize default stream only on CUDA frames (#3638) * UCX synchronize default stream only on CUDA frames * Improve check for CUDA frames in UCX comms * Further improvements to CUDA frame synchronization in UCX * Fix black formatting --- distributed/comm/ucx.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 22d7e361e97..fc187dcc614 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -158,13 +158,12 @@ async def write( ) # Send meta data - await self.ep.send(np.array([len(frames)], dtype=np.uint64)) - await self.ep.send( - np.array( - [hasattr(f, "__cuda_array_interface__") for f in frames], - dtype=np.bool, - ) + cuda_frames = np.array( + [hasattr(f, "__cuda_array_interface__") for f in frames], + dtype=np.bool, ) + await self.ep.send(np.array([len(frames)], dtype=np.uint64)) + await self.ep.send(cuda_frames) await self.ep.send( np.array([nbytes(f) for f in frames], dtype=np.uint64) ) @@ -175,7 +174,8 @@ async def write( # syncing the default stream will wait for other non-blocking CUDA streams. # Note this is only sufficient if the memory being sent is not currently in use on # non-blocking CUDA streams. 
- synchronize_stream(0) + if cuda_frames.any(): + synchronize_stream(0) for frame in frames: if nbytes(frame) > 0: @@ -222,7 +222,8 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): # It is necessary to first populate `frames` with CUDA arrays and synchronize # the default stream before starting receiving to ensure buffers have been allocated - synchronize_stream(0) + if is_cudas.any(): + synchronize_stream(0) for i, (is_cuda, size) in enumerate( zip(is_cudas.tolist(), sizes.tolist()) ): From 2e64ae9c256069d8e5ca93b1a1b7356a8c29f3c5 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 25 Mar 2020 14:42:12 -0500 Subject: [PATCH 0747/1550] bump version to 2.13.0 --- docs/source/changelog.rst | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 9c7ca9b01f4..12288fc4aba 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,41 @@ Changelog ========= +2.13.0 - 2020-03-25 +------------------- + +- UCX synchronize default stream only on CUDA frames (:pr:`3638`) `Peter Andreas Entschev`_ +- Add ``as_completed.clear`` method (:pr:`3617`) `Matthew Rocklin`_ +- Drop unused line from ``pack_frames_prelude`` (:pr:`3634`) `John Kirkham`_ +- Add logging message when closing idle dask scheduler (:pr:`3632`) `Matthew Rocklin`_ +- Include frame lengths of CUDA objects in ``header["lengths"]`` (:pr:`3631`) `John Kirkham`_ +- Ensure ``Client`` connection pool semaphore attaches to the ``Client`` event loop (:pr:`3546`) `James Bourbeau`_ +- Remove dead stealing code (:pr:`3619`) `Florian Jetter`_ +- Check ``nbytes`` and ``types`` before reading ``data`` (:pr:`3628`) `John Kirkham`_ +- Ensure that we don't steal blacklisted fast tasks (:pr:`3591`) `Florian Jetter`_ +- Support async ``Listener.stop`` functions (:pr:`3613`) `Matthew Rocklin`_ +- Add str/repr methods to ``as_completed`` (:pr:`3618`) `Matthew Rocklin`_ +- Add backoff to comm connect attempts. (:pr:`3496`) `Matthias Urlichs`_ +- Make ``Listeners`` awaitable (:pr:`3611`) `Matthew Rocklin`_ +- Increase number of visible mantissas in dashboard plots (:pr:`3585`) `Scott Sievert`_ +- Pin openssl to 1.1.1d for Travis (:pr:`3602`) `Jacob Tomlinson`_ +- Replace ``tornado.queues`` with ``asyncio.queues`` (:pr:`3607`) `James Bourbeau`_ +- Remove ``dill`` from CI environments (:pr:`3608`) `Loïc Estève`_ +- Fix linting errors (:pr:`3604`) `James Bourbeau`_ +- Synchronize default CUDA stream before UCX send/recv (:pr:`3598`) `Peter Andreas Entschev`_ +- Add configuration for ``Adaptive`` arguments (:pr:`3509`) `Gabriel Sailer`_ +- Change ``Adaptive`` docs to reference ``adaptive_target`` (:pr:`3597`) `Julia Signell`_ +- Optionally compress on a frame-by-frame basis (:pr:`3586`) `Matthew Rocklin`_ +- Add Python version to version check (:pr:`3567`) `James Bourbeau`_ +- Import ``tlz`` (:pr:`3579`) `John Kirkham`_ +- Pin ``numpydoc`` to avoid double escaped ``*`` (:pr:`3530`) `Gil Forsyth`_ +- Avoid ``performance_report`` crashing when a worker dies mid-compute (:pr:`3575`) `Krishan Bhasin`_ +- Pin ``bokeh`` in CI builds (:pr:`3570`) `James Bourbeau`_ +- Disable fast fail on GitHub Actions Windows CI (:pr:`3569`) `James Bourbeau`_ +- Fix typo in ``Client.shutdown`` docstring (:pr:`3562`) `John Kirkham`_ +- Add ``local_directory`` option to ``dask-ssh`` (:pr:`3554`) `Abdulelah Bin Mahfoodh`_ + + 2.12.0 - 2020-03-06 ------------------- @@ -1583,6 +1618,7 @@ significantly without many new features. .. 
_`He Jia`: https://github.com/HerculesJack .. _`Jim Crist-Harif`: https://github.com/jcrist .. _`fjetter`: https://github.com/fjetter +.. _`Florian Jetter`: https://github.com/fjetter .. _`Patrick Sodré`: https://github.com/sodre .. _`Stephan Erb`: https://github.com/StephanErb .. _`Benedikt Reinartz`: https://github.com/filmor @@ -1602,3 +1638,6 @@ significantly without many new features. .. _`Davis Bennett`: https://github.com/d-v-b .. _`Lucas Rademaker`: https://github.com/lr4d .. _`Darren Weber`: https://github.com/dazza-codes +.. _`Matthias Urlichs`: https://github.com/smurfix +.. _`Krishan Bhasin`: https://github.com/KrishanBhasin +.. _`Abdulelah Bin Mahfoodh`: https://github.com/abduhbm From 0bf20a4312b113659900be7e647d28b3bde4015a Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 26 Mar 2020 13:08:35 -0500 Subject: [PATCH 0748/1550] Update bokeh dependency in CI builds (#3637) --- continuous_integration/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/environment.yml b/continuous_integration/environment.yml index 7458c9f64ac..5f09525caae 100644 --- a/continuous_integration/environment.yml +++ b/continuous_integration/environment.yml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: - zstandard - - bokeh=1.4.0 + - bokeh!=2.0.0 - click - cloudpickle - dask From 9a02d7c7c6ecf8010d5286e625a5d842e99df8ad Mon Sep 17 00:00:00 2001 From: Prasun Anand Date: Thu, 26 Mar 2020 23:40:23 +0530 Subject: [PATCH 0749/1550] Add link to contributing.md (#3621) --- CONTRIBUTING.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ab4175a59fe..3859e21af8d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,3 +1,5 @@ -Dask is a community maintained project. We welcome contributions in the form of bug reports, documentation, code, design proposals, and more. +Dask is a community maintained project. We welcome contributions in the form of bug reports, documentation, code, design proposals, and more. -For general information on how to contribute see https://docs.dask.org/en/latest/develop.html. +Please see https://distributed.dask.org/en/latest/develop.html for more information. + +Also for general information on how to contribute see https://docs.dask.org/en/latest/develop.html. 
From f244ab64e634f46ad3d26de7a2f536e0055a5214 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 27 Mar 2020 03:50:41 -0500 Subject: [PATCH 0750/1550] Don't create output Futures in Client when there are mixed Client Futures (#3643) * Don't create Futures if raising mixed Clients error * Add test --- distributed/client.py | 5 ++--- distributed/tests/test_client.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index c45e0718b43..97c2fd60e17 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2452,8 +2452,6 @@ def _graph_to_futures( actors = list(self._expand_key(actors)) keyset = set(keys) - flatkeys = list(map(tokey, keys)) - futures = {key: Future(key, self, inform=False) for key in keyset} values = { k: v @@ -2506,12 +2504,13 @@ def _graph_to_futures( if isinstance(retries, Number) and retries > 0: retries = {k: retries for k in dsk3} + futures = {key: Future(key, self, inform=False) for key in keyset} self._send_to_scheduler( { "op": "update-graph", "tasks": valmap(dumps_task, dsk3), "dependencies": dependencies, - "keys": list(flatkeys), + "keys": list(map(tokey, keys)), "restrictions": restrictions or {}, "loose_restrictions": loose_restrictions, "priority": priority, diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index d9633876cb2..98d0a9bf290 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5449,6 +5449,9 @@ def test_mixing_clients(s, a, b): future = c1.submit(inc, 1) with pytest.raises(ValueError): c2.submit(inc, future) + + assert not c2.futures # Don't create Futures on second Client + yield c1.close() yield c2.close() From 926eb12a03ac750712e13ac60560c64f652661b1 Mon Sep 17 00:00:00 2001 From: Prasun Anand Date: Fri, 27 Mar 2020 20:34:54 +0530 Subject: [PATCH 0751/1550] Remove local-directory keyword (#3620) --- distributed/cli/dask_scheduler.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 0951b8c3d27..2394dd65dea 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -3,9 +3,7 @@ import gc import os import re -import shutil import sys -import tempfile import warnings import click @@ -104,9 +102,6 @@ "This may be a good way to share connection information if your " "cluster is on a shared network file system.", ) -@click.option( - "--local-directory", default="", type=str, help="Directory to place scheduler files" -) @click.option( "--preload", type=str, @@ -136,7 +131,6 @@ def main( dashboard_prefix, use_xheaders, pid_file, - local_directory, tls_ca_file, tls_cert, tls_key, @@ -194,17 +188,6 @@ def del_pid_file(): atexit.register(del_pid_file) - local_directory_created = False - if local_directory: - if not os.path.exists(local_directory): - os.mkdir(local_directory) - local_directory_created = True - else: - local_directory = tempfile.mkdtemp(prefix="scheduler-") - local_directory_created = True - if local_directory not in sys.path: - sys.path.insert(0, local_directory) - if sys.platform.startswith("linux"): import resource # module fails importing on Windows @@ -224,7 +207,6 @@ def del_pid_file(): service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, **kwargs ) - logger.info("Local Directory: %26s", local_directory) logger.info("-" * 47) install_signal_handlers(loop) @@ -237,8 +219,6 @@ async def run(): loop.run_sync(run) finally: scheduler.stop() - if 
local_directory_created: - shutil.rmtree(local_directory) logger.info("End scheduler at %r", scheduler.address) From aa979741787248e4d45abec9affb6240ab61dd1e Mon Sep 17 00:00:00 2001 From: Gabriel Sailer Date: Fri, 27 Mar 2020 16:42:48 +0100 Subject: [PATCH 0752/1550] Add prometheus metric for suspicious tasks (#3550) --- distributed/dashboard/scheduler.py | 10 ++++++++++ distributed/scheduler.py | 7 +++++++ distributed/tests/test_scheduler.py | 14 ++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index acaab24cd17..982bc424826 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -289,6 +289,16 @@ def collect(self): sum, (tp.states for tp in self.server.task_prefixes.values()) ) + suspicious_tasks = CounterMetricFamily( + "dask_scheduler_tasks_suspicious", + "Total number of times a task has been marked suspicious", + labels=["task_prefix_name"], + ) + + for tp in self.server.task_prefixes.values(): + suspicious_tasks.add_metric([tp.name], tp.suspicious) + yield suspicious_tasks + yield CounterMetricFamily( "dask_scheduler_tasks_forgotten", ( diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f99c26d9aba..3de506cfade 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -789,6 +789,11 @@ class TaskPrefix: An exponentially weighted moving average duration of all tasks with this prefix + .. attribute:: suspicious: int + + Numbers of times a task was marked as suspicious with this prefix + + See Also -------- TaskGroup @@ -805,6 +810,7 @@ def __init__(self, name): ) else: self.duration_average = None + self.suspicious = 0 @property def states(self): @@ -2190,6 +2196,7 @@ def remove_worker(self, comm=None, address=None, safe=False, close=True): recommendations[k] = "released" if not safe: ts.suspicious += 1 + ts.prefix.suspicious += 1 if ts.suspicious > self.allowed_failures: del recommendations[k] e = pickle.dumps( diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 1068169b200..46e4d0c885c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -787,6 +787,7 @@ def test_retire_workers_no_suspicious_tasks(c, s, a, b): yield s.retire_workers(workers=[a.address]) assert all(ts.suspicious == 0 for ts in s.tasks.values()) + assert all(tp.suspicious == 0 for tp in s.task_prefixes.values()) @pytest.mark.slow @@ -1810,6 +1811,19 @@ async def test_task_prefix(c, s, a, b): assert s.task_prefixes["sum-aggregate"].states["memory"] == 2 +@gen_cluster( + client=True, Worker=Nanny, config={"distributed.scheduler.allowed-failures": 0} +) +async def test_failing_task_increments_suspicious(client, s, a, b): + future = client.submit(sys.exit, 0) + await wait(future) + + assert s.task_prefixes["exit"].suspicious == 1 + assert sum(tp.suspicious for tp in s.task_prefixes.values()) == sum( + ts.suspicious for ts in s.tasks.values() + ) + + @gen_cluster(client=True) async def test_task_group_non_tuple_key(c, s, a, b): da = pytest.importorskip("dask.array") From b8a0b8e110d37b7986d0693e9f35fdddb85ded17 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 27 Mar 2020 15:43:23 +0000 Subject: [PATCH 0753/1550] Handle exception in faulthandler (#3646) --- conftest.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index b5db36f59d8..07adc4982f6 100644 --- a/conftest.py +++ b/conftest.py @@ -12,7 +12,10 @@ except ImportError: pass 
else: - faulthandler.enable() + try: + faulthandler.enable() + except Exception: + pass def pytest_addoption(parser): From 7802bf3bffe0c870bb27543e4c334820c3b8fae7 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 27 Mar 2020 12:16:30 -0500 Subject: [PATCH 0754/1550] Bump checkout GitHub action to v2 (#3649) --- .github/workflows/ci-docs.yaml | 2 +- .github/workflows/ci-windows.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-docs.yaml b/.github/workflows/ci-docs.yaml index 780e2a251fd..c519427f140 100644 --- a/.github/workflows/ci-docs.yaml +++ b/.github/workflows/ci-docs.yaml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v2 - name: Set up Python 3.7 uses: actions/setup-python@v1 diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index 3b99a8c8ec0..e0c95d0f234 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -12,7 +12,7 @@ jobs: steps: - name: Checkout source - uses: actions/checkout@v1 + uses: actions/checkout@v2 - name: Setup Conda Environment uses: goanpeca/setup-miniconda@v1 From 3fceec696b81f02c8082f253bf340e3f494fc42c Mon Sep 17 00:00:00 2001 From: jakirkham Date: Fri, 27 Mar 2020 13:36:04 -0700 Subject: [PATCH 0755/1550] UCX simplify receiving frames in `comm`s (#3651) * Prefix `for`-loop variables with `each_*` Should make it easier to disambiguate things like `frame` and `frames` as they are now `each_frame` and `frames`. * Allocate frames the same way in 0-length case * Always allocate frames, receive non-trivial ones * Allocate all frames to fill before receiving * Filter out non-trivial frames to transmit --- distributed/comm/ucx.py | 40 ++++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index fc187dcc614..a29441ec4d5 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -156,6 +156,9 @@ async def write( frames = await to_frames( msg, serializers=serializers, on_error=on_error ) + send_frames = [ + each_frame for each_frame in frames if len(each_frame) > 0 + ] # Send meta data cuda_frames = np.array( @@ -167,6 +170,7 @@ async def write( await self.ep.send( np.array([nbytes(f) for f in frames], dtype=np.uint64) ) + # Send frames # It is necessary to first synchronize the default stream before start sending @@ -177,10 +181,9 @@ async def write( if cuda_frames.any(): synchronize_stream(0) - for frame in frames: - if nbytes(frame) > 0: - await self.ep.send(frame) - return sum(map(nbytes, frames)) + for each_frame in send_frames: + await self.ep.send(each_frame) + return sum(map(nbytes, send_frames)) except (ucp.exceptions.UCXBaseException): self.abort() raise CommClosedError("While writing, the connection was closed") @@ -206,30 +209,23 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): raise CommClosedError("While reading, the connection was closed") else: # Recv frames - frames = [] - for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()): - if size > 0: - if is_cuda: - frame = cuda_array(size) - else: - frame = np.empty(size, dtype=np.uint8) - frames.append(frame) - else: - if is_cuda: - frames.append(cuda_array(size)) - else: - frames.append(b"") + frames = [ + cuda_array(each_size) + if is_cuda + else np.empty(each_size, dtype=np.uint8) + for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist()) + ] + recv_frames = [ + 
each_frame for each_frame in frames if len(each_frame) > 0 + ] # It is necessary to first populate `frames` with CUDA arrays and synchronize # the default stream before starting receiving to ensure buffers have been allocated if is_cudas.any(): synchronize_stream(0) - for i, (is_cuda, size) in enumerate( - zip(is_cudas.tolist(), sizes.tolist()) - ): - if size > 0: - await self.ep.recv(frames[i]) + for each_frame in recv_frames: + await self.ep.recv(each_frame) msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers ) From eda27bee472090f5370121d7d35c201d2facd79e Mon Sep 17 00:00:00 2001 From: Gabriel Sailer Date: Fri, 27 Mar 2020 23:10:14 +0100 Subject: [PATCH 0756/1550] Introduce config for default task duration (#3642) --- distributed/distributed.yaml | 1 + distributed/scheduler.py | 6 +++++- distributed/tests/test_scheduler.py | 11 +++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 311eeaae829..17326aebd54 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -24,6 +24,7 @@ distributed: pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] preload-argv: [] + unknown-task-duration: 500ms # Default duration for all tasks with unknown durations ("15m", "2h") default-task-durations: # How long we expect function names to run ("1h", "1s") (helps for long tasks) rechunk-split: 1us shuffle-split: 1us diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3de506cfade..0e998e7296a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3700,7 +3700,7 @@ def get_comm_cost(self, ts, ws): """ return sum(dts.nbytes for dts in ts.dependencies - ws.has_what) / self.bandwidth - def get_task_duration(self, ts, default=0.5): + def get_task_duration(self, ts, default=None): """ Get the estimated computation cost of the given task (not including any communication cost). @@ -3708,6 +3708,10 @@ def get_task_duration(self, ts, default=0.5): duration = ts.prefix.duration_average if duration is None: self.unknown_durations[ts.prefix.name].add(ts) + if default is None: + default = parse_timedelta( + dask.config.get("distributed.scheduler.unknown-task-duration") + ) return default return duration diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 46e4d0c885c..a0e1b11de2b 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2071,3 +2071,14 @@ async def test_worker_name_collision(s, a): s.validate_state() assert set(s.workers) == {a.address} assert s.aliases == {a.name: a.address} + + +@gen_cluster(client=True, config={"distributed.scheduler.unknown-task-duration": "1h"}) +async def test_unknown_task_duration_config(client, s, a, b): + future = client.submit(slowinc, 1) + while not s.tasks: + await asyncio.sleep(0.001) + assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600 + assert len(s.unknown_durations) == 1 + await wait(future) + assert len(s.unknown_durations) == 0 From f765242811199801553822a99e80559465926357 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 28 Mar 2020 10:15:05 -0700 Subject: [PATCH 0757/1550] Avoid diangostics time in performance report (#3654) Previously we would include all of the time taken to generate the performance report in the reported time. Now we record the time before we generate plots and use that instead. 
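As a rough sketch of the idea (not the scheduler's actual code), the change amounts to
capturing the end timestamp before any of the report-generation work starts, so that only
the user's computation is measured. The helper below and its sleeps are placeholders for
gathering profiles and rendering plots:

    import asyncio
    from time import time

    async def build_report(start: float) -> str:
        # Record the end of the measured window *before* doing the slow
        # diagnostics work, so report generation is not counted.
        stop = time()
        await asyncio.sleep(0.5)  # placeholder for profile gathering / plot rendering
        return "computation took %.2fs" % (stop - start)

    async def main() -> None:
        start = time()
        await asyncio.sleep(0.1)  # placeholder for the user's computation
        print(await build_report(start))

    asyncio.run(main())
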
--- distributed/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0e998e7296a..882970838df 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5010,6 +5010,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} async def performance_report(self, comm=None, start=None, code=""): + stop = time() # Profiles compute, scheduler, workers = await asyncio.gather( *[ @@ -5071,7 +5072,7 @@ def profile_to_figure(state): {code} """.format( - time=format_time(time() - start), + time=format_time(stop - start), address=self.address, nworkers=len(self.workers), threads=sum(w.nthreads for w in self.workers.values()), From d7948ce499c6788b45758efc0d31103cf57c2d22 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 28 Mar 2020 10:55:46 -0700 Subject: [PATCH 0758/1550] Clean up performance report test (#3655) Previously all of the text in the assertions was being included in the performance report itself (ever since we started including the surrounding frame). This made these tests pass trivially. Now we wrap the performance report in a function of its own. --- distributed/tests/test_client.py | 38 +++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 98d0a9bf290..8f121c7f27c 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5916,22 +5916,34 @@ async def f(dask_worker): assert b.foo == "bar" -@gen_cluster(client=True) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 2) async def test_performance_report(c, s, a, b): da = pytest.importorskip("dask.array") - x = da.random.random((1000, 1000), chunks=(100, 100)) - - with tmpfile(extension="html") as fn: - async with performance_report(filename=fn): - await c.compute((x + x.T).sum()) - with open(fn) as f: - data = f.read() - - assert "bokeh" in data - assert "random" in data - assert "Dask Performance Report" in data - assert "x = da.random" in data + async def f(): + """ + We wrap this in a function so that the assertions aren't in the + performanace report itself + + Also, we want this comment to appear + """ + x = da.random.random((1000, 1000), chunks=(100, 100)) + with tmpfile(extension="html") as fn: + async with performance_report(filename=fn): + await c.compute((x + x.T).sum()) + + with open(fn) as f: + data = f.read() + return data + + data = await f() + + assert "Also, we want this comment to appear" in data + assert "bokeh" in data + assert "random" in data + assert "Dask Performance Report" in data + assert "x = da.random" in data + assert "Threads: 4" in data @pytest.mark.asyncio From 362896a87927cbca48cccf3ba6ceb495169ecaee Mon Sep 17 00:00:00 2001 From: Rami Chowdhury <460769+necaris@users.noreply.github.com> Date: Mon, 30 Mar 2020 12:24:25 -0400 Subject: [PATCH 0759/1550] Add newlines to ensure code formatting for `retire_workers`(#3661) Without the newline, the code samples don't seem to be parsed as such, and so aren't set apart or formatted properly. 
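For context, reStructuredText (as used by numpydoc) only treats a ``>>>`` sample as a
doctest block when a blank line separates it from the preceding paragraph; otherwise it
is rendered as ordinary text. A minimal sketch of the rule, with a made-up function and
docstring for illustration:

    def example():
        """
        Examples
        --------
        Without a blank line the sample below is parsed as part of this paragraph:
        >>> info = {'workers': {}}

        With a blank line it is set apart and formatted as code:

        >>> info = {'workers': {}}
        """
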
--- distributed/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index 97c2fd60e17..f2b25b74f60 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3489,9 +3489,11 @@ def retire_workers(self, workers=None, close_workers=True, **kwargs): Examples -------- You can get information about active workers using the following: + >>> workers = client.scheduler_info()['workers'] From that list you may want to select some workers to close + >>> client.retire_workers(workers=['tcp://address:port', ...]) See Also From 067fd1cc35ec9d5eae62c50e1bc439111f613662 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 30 Mar 2020 13:51:51 -0500 Subject: [PATCH 0760/1550] Update Python version checking (#3660) --- distributed/tests/test_versions.py | 20 +++++--------------- distributed/versions.py | 5 +---- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/distributed/tests/test_versions.py b/distributed/tests/test_versions.py index ab3547820ca..64b94ba37b1 100644 --- a/distributed/tests/test_versions.py +++ b/distributed/tests/test_versions.py @@ -1,5 +1,5 @@ import re -import uuid +import sys import pytest @@ -120,17 +120,7 @@ async def test_version_warning_in_cluster(s, a, b): ) -@gen_cluster() -async def test_python_version_mismatch_warning(s, a, b): - # Set random Python version for one worker - random_version = uuid.uuid4().hex - orig = s.workers[a.address].versions["host"]["python"] = random_version - - with pytest.warns(None) as record: - async with Client(s.address, asynchronous=True) as client: - pass - - assert record - assert any("python" in str(r.message) for r in record) - assert any(random_version in str(r.message) for r in record) - assert any(a.address in str(r.message) for r in record) +def test_python_version(): + required = get_versions()["packages"] + assert "python" in required + assert required["python"] == ".".join(map(str, sys.version_info)) diff --git a/distributed/versions.py b/distributed/versions.py index 403d79f6aef..d800f65ec63 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -82,7 +82,7 @@ def version_of_package(pkg): def get_package_info(pkgs): """ get package versions for the passed required & optional packages """ - pversions = [] + pversions = [("python", ".".join(map(str, sys.version_info)))] for pkg in pkgs: if isinstance(pkg, (tuple, list)): modname, ver_f = pkg @@ -121,9 +121,6 @@ def error_message(scheduler, workers, client, client_name="client"): for pkg, version in info["packages"].items(): node_packages[node][pkg] = version packages.add(pkg) - # Collect Python version for each node - node_packages[node]["python"] = info["host"]["python"] - packages.add("python") errs = [] for pkg in sorted(packages): From 26e28c348a1101a62631a07c152c6ce0365621d0 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 31 Mar 2020 11:12:36 -0400 Subject: [PATCH 0761/1550] Write "why killed" docs [skip ci] (#3596) * Write "why killed" docs [skip ci] * try specific [skip ci] * Responses --- docs/source/faq.rst | 8 +++ docs/source/index.rst | 1 + docs/source/ipython.rst | 1 + docs/source/killed.rst | 139 ++++++++++++++++++++++++++++++++++++++++ docs/source/worker.rst | 1 + 5 files changed, 150 insertions(+) create mode 100644 docs/source/killed.rst diff --git a/docs/source/faq.rst b/docs/source/faq.rst index 5803f10096c..0c0e3e84b70 100644 --- a/docs/source/faq.rst +++ b/docs/source/faq.rst @@ -82,3 +82,11 @@ subprocess, head to `this section`_ of the supervisor documentation to 
see how to pass the ``$HOME`` and ``$USER`` variables through. .. _this section: http://supervisord.org/subprocess.html#subprocess-environment + + +KilledWorker, CommsClosed, etc. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In the case that workers disappear unexpectedly from your cluster, you may see +a range of error messages. After checking the logs of the workers affected, you +should read the section :doc:`killed`. diff --git a/docs/source/index.rst b/docs/source/index.rst index 3cbdd18792a..249f9eb9faf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -100,6 +100,7 @@ Contents scheduling-state worker work-stealing + killed .. toctree:: :maxdepth: 1 diff --git a/docs/source/ipython.rst b/docs/source/ipython.rst index a70229d508f..df5f44eb1d1 100644 --- a/docs/source/ipython.rst +++ b/docs/source/ipython.rst @@ -37,6 +37,7 @@ This is possible with the Client.become_dask_ method:: .. _Client.become_dask: https://ipyparallel.readthedocs.io/en/latest/api/ipyparallel.html#ipyparallel.Client.become_dask +.. _ipylaunch: Launch IPython within Dask Workers ---------------------------------- diff --git a/docs/source/killed.rst b/docs/source/killed.rst new file mode 100644 index 00000000000..837ccd944b4 --- /dev/null +++ b/docs/source/killed.rst @@ -0,0 +1,139 @@ +.. _killed: + +Why did my worker die? +---------------------- + +A Dask worker can cease functioning for a number of reasons. These fall into the +following categories: + +- the worker chooses to exit +- an unrecoverable exception happens within the worker +- the worker process is shut down by some external action + +Each of these cases will be described in more detail below. The *symptoms* you will +experience when these things happen range from simply work not getting done anymore, +to various exceptions appearing when you interact with your local client, such as +``KilledWorker``, ``TimeoutError`` and ``CommClosedError``. + +Note the special case of ``KilledWorker``: this means that a particular task was +tried on a worker, and it died, and then the same task was sent to another worker, +which also died. After a configurable number of deaths (config key " +``distributed.scheduler.allowed-failures``), Dask decides to blame the +task itself, and returns this exception. Note, that it is possible for a task to be +unfairly blamed - the worker happened to die while the task was active, perhaps +due to another thread - complicating diagnosis. + +In every case, the first place to look for further information is the logs of +the given worker, which may well give a complete description of what happened. These +logs are printed by the worker to its "standard error", which may appear in the text +console from which you launched the worker, or some logging system maintained by +the cluster infrastructure. It is also helpful to watch the diagnostic dashboard to +look for memory spikes, but of course this is only possible while the worker is still +alive. + +In all cases, the scheduler will notice that the worker has gone, either because +of an explicit de-registration, or because the worker no longer produces heartbeats, +and it should be possible to reroute tasks to other workers and have the system +keep running. + +Scenarios +~~~~~~~~~ + +Worker chose to exit +'''''''''''''''''''' + +Workers may exit in normal functioning because they have been asked to, e.g., +they received a keyboard interrupt (^C), or the scheduler scaled down the cluster. 
+In such cases, the work that was being done by the worker will be redirected to +other workers, if there are any left. + +You should expect to see the following message at the end of the worker's log: + +:: + + distributed.dask_worker - INFO - End worker + +In these cases, there is not normally anything which you need to do, since the +behaviour is expected. + +Unrecoverable Exception +''''''''''''''''''''''' + +The worker is a python process, and like any other code, an exception may occur +which causes the process to exit. One typical example of this might be a +version mismatch between the packages of the client and worker, so that +a message sent to the worker errors while being unpacked. There are a number of +packages that need to match, not only ``dask`` and ``distributed``. + +In this case, you should expect to see the full python traceback in the worker's +log. In the event of a version mismatch, this might be complaining about a bad +import or missing attribute. However, other fatal exceptions are also possible, +such as trying to allocate more memory than the system has available, or writing +temporary files without appropriate permissions. + +To assure that you have matching versions, you should run (more recent versions +of distributed may do this automatically) + +.. code-block:: + + client.get_versions(check=True) + +For other errors, you might want to run the computation in your local client, if +possible, or try grabbing just the task that errored and using +:func:`recreate_error_locally `, +as you would for ordinary exceptions happening during task execution. + +Specifically for connectivity problems (e.g., timeout exceptions in the worker +logs), you will need to diagnose your networking infrastructure, which is more +complicated than can be described here. Commonly, it may involve logging into +the machine running the affected worker +(although you can :ref:`ipylaunch`). + +Killed by Nanny +''''''''''''''' + +The Dask "nanny" is a process which watches the worker, and restarts it if +necessary. It also tracks the worker's memory usage, and if it should cross +a given fraction of total memory, then also the worker will be restarted, +interrupting any work in progress. The log will show a message like + +:: + + Worker exceeded X memory budget. Restarting + +Where X is the memory fraction. You can set this critical fraction using +the configuration, see :ref:`memman`. If you have an external system for +watching memory usage provided by your cluster infrastructure (HPC, +kubernetes, etc.), then it may be reasonable to turn off this memory +limit. Indeed, in these cases, restarts might be handled for you too, so +you could do without the nanny at all (``--no-nanny`` CLI option or +configuration equivalent). + +Sudden Exit +''''''''''' + +The worker process may stop working without notice. This can happen due to +something internal to the worker, e.g., a memory violation (common if interfacing +with compiled code), or due to something external, e.g., the ``kill`` command, or +stopping of the container or machine on which the worker is running. + +In the best case, you may have a line in the logs from the OS saying that the +worker was shut down, such as the single word "killed" or something more descriptive. +In these cases, the fault may well be in your code, and you might be able to use the +same debugging tools as in the previous section. 
+ +However, if the action was initiated by some outside framework, then the worker will +have no time to leave a logging message, and the death *may* have nothing to do with +what the worker was doing at the time. For example, if kubernetes decides to evict a +pod, or your ec2 instance goes down for maintenance, the worker is not at fault. +Hopefully, the system provides a reasonable message of what happened in the process +output. +However, if the memory allocation (or other resource) exceeds toleration, then it +*is* the code's fault - although you may be able to fix with better configuration +of Dask's own limits, or simply with a bigger cluster. In any case, your deployment +framework has its own logging system, and you should look there for the reason that +the dask worker was taken down. + +Specifically for memory issues, refer to the memory section of `best practices`_. + +.. _best practices: https://docs.dask.org/en/latest/best-practices.html#avoid-very-large-partitions diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 5ff66b613a6..dc4e56d9ac7 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -142,6 +142,7 @@ thread pool. A task either errs or its result is put into memory. In either case a response is sent back to the scheduler. +.. _memman:: Memory Management ----------------- From dedcb1350e6106a21ddcdad1edaa94b346b3e9d8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 31 Mar 2020 08:15:56 -0700 Subject: [PATCH 0762/1550] Clean up some test warnings (#3662) * remove errant yield in variable tests * logs -> get_logs * add additional awaits --- distributed/deploy/tests/test_spec_cluster.py | 10 +++++----- distributed/tests/test_client.py | 9 ++++----- distributed/tests/test_core.py | 2 +- distributed/tests/test_variable.py | 2 +- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 90ce9923c69..ae24e7400e2 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -278,7 +278,7 @@ async def test_logs(cleanup): cluster.scale(2) await cluster - logs = await cluster.logs() + logs = await cluster.get_logs() assert is_valid_xml("
        " + logs._repr_html_() + "
        ") assert "Scheduler" in logs for worker in cluster.scheduler.workers: @@ -286,17 +286,17 @@ async def test_logs(cleanup): assert "Registered" in str(logs) - logs = await cluster.logs(scheduler=True, workers=False) + logs = await cluster.get_logs(scheduler=True, workers=False) assert list(logs) == ["Scheduler"] - logs = await cluster.logs(scheduler=False, workers=False) + logs = await cluster.get_logs(scheduler=False, workers=False) assert list(logs) == [] - logs = await cluster.logs(scheduler=False, workers=True) + logs = await cluster.get_logs(scheduler=False, workers=True) assert set(logs) == set(cluster.scheduler.workers) w = toolz.first(cluster.scheduler.workers) - logs = await cluster.logs(scheduler=False, workers=[w]) + logs = await cluster.get_logs(scheduler=False, workers=[w]) assert set(logs) == {w} diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8f121c7f27c..1ce1c1dfa6f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5465,12 +5465,11 @@ def test_tuple_keys(c, s, a, b): @gen_cluster(client=True) -def test_multiple_scatter(c, s, a, b): - for i in range(5): - x = c.scatter(1, direct=True) +async def test_multiple_scatter(c, s, a, b): + futures = await asyncio.gather(*[c.scatter(1, direct=True) for _ in range(5)]) - x = yield x - x = yield x + x = await futures[0] + x = await futures[0] @gen_cluster(client=True) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 76f2b285500..b4993bb4ce8 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -781,4 +781,4 @@ async def test_connection_pool_detects_remote_close(): # check connection pool invariants hold even after it detects a closed connection # while creating conn2: p._validate() - p.close() + await p.close() diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 6e3b3bcdad6..64765d808c7 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -112,7 +112,7 @@ def test_timeout_sync(client): assert 0.2 < stop - start < 2.0 with pytest.raises(TimeoutError): - yield v.get(timeout=0.01) + v.get(timeout=0.01) @gen_cluster(client=True) From 10e6018c737d33cc46a14778ca42ce2e0d26d434 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 31 Mar 2020 21:24:34 -0500 Subject: [PATCH 0763/1550] Replace ncores with nthreads in work stealing tests (#3615) --- distributed/tests/test_steal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 0ed9051cc95..1c9fe22e2e8 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -321,7 +321,7 @@ def test_steal_host_restrictions(c, s, wa, wb): assert len(wa.task_state) == ntasks assert len(wb.task_state) == 0 - wc = yield Worker(s.address, ncores=1) + wc = yield Worker(s.address, nthreads=1) start = time() while not wc.task_state or len(wa.task_state) == ntasks: From 2129b740c1e3f524e5ba40a0b6a77b239d4c1f94 Mon Sep 17 00:00:00 2001 From: Lucas Rademaker <44430780+lr4d@users.noreply.github.com> Date: Wed, 1 Apr 2020 08:48:31 +0200 Subject: [PATCH 0764/1550] Add Semaphore extension (#3573) The complexity of the internal structure comes in since we do not support any notion of an ephemeral key, i.e. a value which expires together with the session. In this context this translates best to the Client. 
Therefore, the Semaphore tracks which lease stems from which Semaphore client instance and stores its associated Client ID. If the client is lost/closed, the semaphore will release all it's acquired values eventually. This behavior is quite important for resilience: If a worker is shut down ungracefully, all acquired leases should be released eventually since otherwise we may cause a deadlock. gh3573 and gh2690 --- distributed/__init__.py | 1 + distributed/distributed.yaml | 2 + distributed/scheduler.py | 2 + distributed/semaphore.py | 341 ++++++++++++++++++++++++++++ distributed/tests/test_semaphore.py | 246 ++++++++++++++++++++ distributed/worker.py | 2 +- 6 files changed, 593 insertions(+), 1 deletion(-) create mode 100644 distributed/semaphore.py create mode 100644 distributed/tests/test_semaphore.py diff --git a/distributed/__init__.py b/distributed/__init__.py index be750f9daed..608e23a58e3 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -22,6 +22,7 @@ from .nanny import Nanny from .pubsub import Pub, Sub from .queues import Queue +from .semaphore import Semaphore from .scheduler import Scheduler from .threadpoolexecutor import rejoin from .utils import sync, TimeoutError, CancelledError diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 17326aebd54..f4bfc7d76ab 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -38,6 +38,8 @@ distributed: ca-file: null key: null cert: null + locks: + lease-validation-interval: 10s worker: blocked-handlers: [] diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 882970838df..bc6c0ea0fd5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -76,6 +76,7 @@ from .publish import PublishExtension from .queues import QueueExtension +from .semaphore import SemaphoreExtension from .recreate_exceptions import ReplayExceptionScheduler from .lock import LockExtension from .pubsub import PubSubSchedulerExtension @@ -96,6 +97,7 @@ QueueExtension, VariableExtension, PubSubSchedulerExtension, + SemaphoreExtension, ] ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"} diff --git a/distributed/semaphore.py b/distributed/semaphore.py new file mode 100644 index 00000000000..23f253b7ede --- /dev/null +++ b/distributed/semaphore.py @@ -0,0 +1,341 @@ +import uuid +from collections import defaultdict, deque +from functools import partial +import asyncio +import dask +from asyncio import TimeoutError +from .utils import PeriodicCallback, log_errors, parse_timedelta +from .worker import get_client +from .metrics import time +import warnings +import logging + +logger = logging.getLogger(__name__) + + +class _Watch: + def __init__(self, duration=None): + self.duration = duration + self.started_at = None + + def start(self): + self.started_at = time() + + def leftover(self): + if self.duration is None: + return None + else: + elapsed = time() - self.started_at + return max(0, self.duration - elapsed) + + +class SemaphoreExtension: + """ An extension for the scheduler to manage Semaphores + + This adds the following routes to the scheduler + + * semaphore_acquire + * semaphore_release + * semaphore_create + """ + + def __init__(self, scheduler): + self.scheduler = scheduler + self.leases = defaultdict(deque) + self.events = defaultdict(asyncio.Event) + self.max_leases = dict() + self.leases_per_client = defaultdict(partial(defaultdict, deque)) + self.scheduler.handlers.update( + { + "semaphore_create": self.create, + 
"semaphore_acquire": self.acquire, + "semaphore_release": self.release, + "semaphore_close": self.close, + } + ) + + self.scheduler.extensions["semaphores"] = self + self.pc_validate_leases = PeriodicCallback( + self._validate_leases, + 1000 + * parse_timedelta( + dask.config.get( + "distributed.scheduler.locks.lease-validation-interval" + ), + default="s", + ), + io_loop=self.scheduler.loop, + ) + self.pc_validate_leases.start() + self._validation_running = False + + # `comm` here is required by the handler interface + def create(self, comm=None, name=None, max_leases=None): + # We use `self.max_leases.keys()` as the point of truth to find out if a semaphore with a specific + # `name` has been created. + if name not in self.max_leases: + assert isinstance(max_leases, int), max_leases + self.max_leases[name] = max_leases + else: + if max_leases != self.max_leases[name]: + raise ValueError( + "Inconsistent max leases: %s, expected: %s" + % (max_leases, self.max_leases[name]) + ) + + async def _get_lease(self, client, name, identifier): + result = True + if len(self.leases[name]) < self.max_leases[name]: + # naive: self.leases[resource] += 1 + # not naive: + self.leases[name].append(identifier) + self.leases_per_client[client][name].append(identifier) + else: + result = False + return result + + def _semaphore_exists(self, name): + if name not in self.max_leases: + return False + return True + + async def acquire( + self, comm=None, name=None, client=None, timeout=None, identifier=None + ): + with log_errors(): + if not self._semaphore_exists(name): + raise RuntimeError(f"Semaphore `{name}` not known or already closed.") + + if isinstance(name, list): + name = tuple(name) + w = _Watch(timeout) + w.start() + + while True: + # Reset the event and try to get a release. The event will be set if the state + # is changed and helps to identify when it is worth to retry an acquire + self.events[name].clear() + + # If we hit the timeout, this cancels the _get_lease + future = asyncio.wait_for( + self._get_lease(client, name, identifier), timeout=w.leftover() + ) + + try: + result = await future + except TimeoutError: + result = False + + # If acquiring fails, we wait for the event to be set, i.e. something has + # been released and we can try to acquire again (continue loop) + if not result: + future = asyncio.wait_for( + self.events[name].wait(), timeout=w.leftover() + ) + try: + await future + continue + except TimeoutError: + result = False + return result + + def release(self, comm=None, name=None, client=None, identifier=None): + with log_errors(): + if not self._semaphore_exists(name): + logger.warning( + f"Tried to release semaphore `{name}` but it is not known or already closed." + ) + return + if isinstance(name, list): + name = tuple(name) + if name in self.leases and identifier in self.leases[name]: + self._release_value(name, client, identifier) + else: + raise ValueError( + f"Tried to release semaphore but it was already released: " + f"client={client}, name={name}, identifier={identifier}" + ) + + def _release_value(self, name, client, identifier): + # Everything needs to be atomic here. 
+ self.leases_per_client[client][name].remove(identifier) + self.leases[name].remove(identifier) + self.events[name].set() + + def _release_client(self, client): + semaphore_names = list(self.leases_per_client[client]) + for name in semaphore_names: + ids = list(self.leases_per_client[client][name]) + for _id in list(ids): + self._release_value(name=name, client=client, identifier=_id) + del self.leases_per_client[client] + + def _validate_leases(self): + if not self._validation_running: + self._validation_running = True + known_clients_with_leases = set(self.leases_per_client.keys()) + scheduler_clients = set(self.scheduler.clients.keys()) + for dead_client in known_clients_with_leases - scheduler_clients: + self._release_client(dead_client) + else: + self._validation_running = False + + def close(self, comm=None, name=None): + """Hard close the semaphore without warning clients which still hold a lease.""" + with log_errors(): + if not self._semaphore_exists(name): + return + + del self.max_leases[name] + if name in self.events: + del self.events[name] + if name in self.leases: + del self.leases[name] + + for client, client_leases in self.leases_per_client.items(): + if name in client_leases: + warnings.warn( + f"Closing semaphore `{name}` but client `{client}` still has a lease open.", + RuntimeWarning, + ) + del client_leases[name] + + +class Semaphore: + """ Semaphore + + Parameters + ---------- + max_leases: int (optional) + The maximum amount of leases that may be granted at the same time. This + effectively sets an upper limit to the amount of parallel access to a specific resource. + Defaults to 1. + name: string (optional) + Name of the semaphore to acquire. Choosing the same name allows two + disconnected processes to coordinate. If not given, a random + name will be generated. + client: Client (optional) + Client to use for communication with the scheduler. If not given, the + default global client will be used. + + Examples + -------- + >>> from distributed import Semaphore + >>> sem = Semaphore(max_leases=2, name='my_database') + >>> def access_resource(s, sem): + >>> # This automatically acquires a lease from the semaphore (if available) which will be + >>> # released when leaving the context manager. + >>> with sem: + >>> pass + >>> + >>> futures = client.map(access_resource, range(10), sem=sem) + >>> client.gather(futures) + >>> # Once done, close the semaphore to clean up the state on scheduler side. + >>> sem.close() + + Notes + ----- + If a client attempts to release the semaphore but doesn't have a lease acquired, this will raise an exception. + + + When a semaphore is closed, if, for that closed semaphore, a client attempts to: + + - Acquire a lease: an exception will be raised. + - Release: a warning will be logged. + - Close: nothing will happen. + + + dask executes functions by default assuming they are pure, when using semaphore acquire/releases inside + such a function, it must be noted that there *are* in fact side-effects, thus, the function can no longer be + considered pure. If this is not taken into account, this may lead to unexpected behavior. + + """ + + def __init__(self, max_leases=1, name=None, client=None): + # NOTE: the `id` of the `Semaphore` instance will always be unique, even among different + # instances for the same resource. The actual attribute that identifies a specific resource is `name`, + # which will be the same for all instances of this class which limit the same resource. 
+ self.client = client or get_client() + self.id = uuid.uuid4().hex + self.name = name or "semaphore-" + uuid.uuid4().hex + self.max_leases = max_leases + + if self.client.asynchronous: + self._started = self.client.scheduler.semaphore_create( + name=self.name, max_leases=max_leases + ) + else: + self.client.sync( + self.client.scheduler.semaphore_create, + name=self.name, + max_leases=max_leases, + ) + self._started = asyncio.sleep(0) + + def __await__(self): + async def create_semaphore(): + await self._started + return self + + return create_semaphore().__await__() + + def acquire(self, timeout=None): + """ + Acquire a semaphore. + + If the internal counter is greater than zero, decrement it by one and return True immediately. + If it is zero, wait until a release() is called and return True. + """ + # TODO: This (may?) keep the HTTP request open until timeout runs out (forever if None). + # Can do this in batches of smaller timeouts. + # TODO: what if connection breaks up? + return self.client.sync( + self.client.scheduler.semaphore_acquire, + name=self.name, + timeout=timeout, + client=self.client.id, + identifier=self.id, + ) + + def release(self): + """ + Release a semaphore. + + Increment the internal counter by one. + """ + + """ Release the lock if already acquired """ + # TODO: what if connection breaks up? + return self.client.sync( + self.client.scheduler.semaphore_release, + name=self.name, + client=self.client.id, + identifier=self.id, + ) + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, *args, **kwargs): + self.release() + + async def __aenter__(self): + await self.acquire() + return self + + async def __aexit__(self, *args, **kwargs): + await self.release() + + def __getstate__(self): + # Do not serialize the address since workers may have different + # addresses for the scheduler (e.g. 
if a proxy is between them) + return (self.name, self.max_leases) + + def __setstate__(self, state): + name, max_leases = state + client = get_client() + self.__init__(name=name, client=client, max_leases=max_leases) + + def close(self): + return self.client.sync(self.client.scheduler.semaphore_close, name=self.name) diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py new file mode 100644 index 00000000000..9d94b83515a --- /dev/null +++ b/distributed/tests/test_semaphore.py @@ -0,0 +1,246 @@ +import pickle + +import dask +from dask.distributed import Client + +from distributed import Semaphore +from distributed.metrics import time +from distributed.utils_test import cluster, gen_cluster +from distributed.utils_test import client, loop, cluster_fixture # noqa: F401 +import pytest + + +@gen_cluster(client=True) +async def test_semaphore(c, s, a, b): + semaphore = await Semaphore(max_leases=2, name="resource_we_want_to_limit") + + result = await semaphore.acquire() # allowed_leases: 2 - 1 -> 1 + assert result is True + + second = await semaphore.acquire() # allowed_leases: 1 - 1 -> 0 + assert second is True + start = time() + result = await semaphore.acquire(timeout=0.025) # allowed_leases: 0 -> False + stop = time() + assert stop - start < 0.2 + assert result is False + + +@gen_cluster(client=True) +async def test_serializable(c, s, a, b): + sem = await Semaphore(max_leases=2, name="x") + res = await sem.acquire() + assert len(s.extensions["semaphores"].leases["x"]) == 1 + assert res + sem2 = pickle.loads(pickle.dumps(sem)) + assert sem2.name == sem.name + assert sem2.client.scheduler.address == sem.client.scheduler.address + + # actual leases didn't change + assert len(s.extensions["semaphores"].leases["x"]) == 1 + + res = await sem2.acquire() + assert res + assert len(s.extensions["semaphores"].leases["x"]) == 2 + + # Ensure that both objects access the same semaphore + res = await sem.acquire(timeout=0.025) + + assert not res + res = await sem2.acquire(timeout=0.025) + + assert not res + + +@gen_cluster(client=True) +async def test_release_simple(c, s, a, b): + def f(x, semaphore): + with semaphore: + assert semaphore.name == "x" + return x + 1 + + sem = await Semaphore(max_leases=2, name="x") + futures = c.map(f, list(range(10)), semaphore=sem) + await c.gather(futures) + + +@gen_cluster(client=True) +async def test_acquires_with_timeout(c, s, a, b): + sem = await Semaphore(1, "x") + assert await sem.acquire(timeout=0.025) + assert not await sem.acquire(timeout=0.025) + await sem.release() + assert await sem.acquire(timeout=0.025) + await sem.release() + + +def test_timeout_sync(client): + s = Semaphore(name="x") + # Using the context manager already acquires a lease, so the line below won't be able to acquire another one + with s: + assert s.acquire(timeout=0.025) is False + + +@pytest.mark.slow +@gen_cluster(client=True, timeout=20) +async def test_release_semaphore_after_timeout(c, s, a, b): + with dask.config.set( + {"distributed.scheduler.locks.lease-validation-interval": "50ms"} + ): + sem = await Semaphore(name="x", max_leases=2) + await sem.acquire() # leases: 2 - 1 = 1 + semY = await Semaphore(name="y") + + async with Client(s.address, asynchronous=True, name="ClientB") as clientB: + semB = await Semaphore(name="x", max_leases=2, client=clientB) + semYB = await Semaphore(name="y", client=clientB) + + assert await semB.acquire() # leases: 1 - 1 = 0 + assert await semYB.acquire() + + assert not (await sem.acquire(timeout=0.01)) + assert not 
(await semB.acquire(timeout=0.01)) + assert not (await semYB.acquire(timeout=0.01)) + + # `ClientB` goes out of scope, leases should be released + # At this point, we should be able to acquire x and y once + assert await sem.acquire() + assert await semY.acquire() + + assert not (await semY.acquire(timeout=0.01)) + assert not (await sem.acquire(timeout=0.01)) + + assert clientB.id not in s.extensions["semaphores"].leases_per_client + + +@gen_cluster() +async def test_async_ctx(s, a, b): + sem = await Semaphore(name="x") + async with sem: + assert not await sem.acquire(timeout=0.025) + assert await sem.acquire() + + +@pytest.mark.slow +def test_worker_dies(): + with cluster(disconnect_timeout=10) as (scheduler, workers): + with Client(scheduler["address"]) as client: + sem = Semaphore(name="x", max_leases=1) + + def f(x, sem, kill_address): + with sem: + from distributed.worker import get_worker + + worker = get_worker() + if worker.address == kill_address: + import os + + os.kill(os.getpid(), 15) + return x + + futures = client.map( + f, range(100), sem=sem, kill_address=workers[0]["address"] + ) + results = client.gather(futures) + + assert sorted(results) == list(range(100)) + + +@gen_cluster(client=True) +async def test_access_semaphore_by_name(c, s, a, b): + def f(x, release=True): + sem = Semaphore(name="x") + if not sem.acquire(timeout=0.1): + return False + if release: + sem.release() + + return True + + sem = await Semaphore(name="x") + futures = c.map(f, list(range(10))) + assert all(await c.gather(futures)) + + # Clean-up the state, otherwise we would get the same result when calling `f` with the same arguments + del futures + + assert len(s.extensions["semaphores"].leases["x"]) == 0 + assert await sem.acquire() + assert len(s.extensions["semaphores"].leases["x"]) == 1 + futures = c.map(f, list(range(10))) + assert not any(await c.gather(futures)) + await sem.release() + + del futures + + futures = c.map(f, list(range(10)), release=False) + result = await c.gather(futures) + assert result.count(True) == 1 + assert result.count(False) == 9 + + +@gen_cluster(client=True) +async def test_close_async(c, s, a, b): + sem = await Semaphore(name="test") + + assert await sem.acquire() + with pytest.warns( + RuntimeWarning, match="Closing semaphore .* but client .* still has a lease" + ): + await sem.close() + + with pytest.raises( + RuntimeError, match="Semaphore `test` not known or already closed." 
+ ): + await sem.acquire() + + semaphore_object = s.extensions["semaphores"] + assert not semaphore_object.max_leases + assert not semaphore_object.leases + assert not semaphore_object.events + assert not any(semaphore_object.leases_per_client.values()) + + +def test_close_sync(client): + sem = Semaphore() + sem.close() + + with pytest.raises(RuntimeError, match="Semaphore .* not known or already closed."): + sem.acquire() + + +@gen_cluster(client=True) +async def test_release_once_too_many(c, s, a, b): + sem = await Semaphore(name="x") + assert await sem.acquire() + await sem.release() + + with pytest.raises( + ValueError, match="Tried to release semaphore but it was already released" + ): + await sem.release() + + assert await sem.acquire() + await sem.release() + + +@gen_cluster(client=True) +async def test_release_once_too_many_resilience(c, s, a, b): + def f(x, sem): + sem.acquire() + sem.release() + with pytest.raises( + ValueError, match="Tried to release semaphore but it was already released" + ): + sem.release() + return x + + sem = await Semaphore(max_leases=3, name="x") + + inpt = list(range(20)) + futures = c.map(f, inpt, sem=sem) + assert sorted(await c.gather(futures)) == inpt + + assert not s.extensions["semaphores"].leases["x"] + await sem.acquire() + assert len(s.extensions["semaphores"].leases["x"]) == 1 diff --git a/distributed/worker.py b/distributed/worker.py index ba25c91d979..c6ae63f0ef1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1160,7 +1160,7 @@ async def close_gracefully(self): await self.scheduler.retire_workers(workers=[self.address], remove=False) await self.close(safe=True, nanny=not self.lifetime_restart) - async def terminate(self, comm, report=True, **kwargs): + async def terminate(self, comm=None, report=True, **kwargs): await self.close(report=report, **kwargs) return "OK" From 66fe0acf041add720d5cc9d9d1ed99d9e90e7e2e Mon Sep 17 00:00:00 2001 From: Prasun Anand Date: Wed, 1 Apr 2020 20:56:23 +0530 Subject: [PATCH 0765/1550] Add Resouces option to get_task_stream and call output_file (#3653) --- distributed/client.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index f2b25b74f60..d42c29c2314 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3367,9 +3367,10 @@ async def _profile( filename = "dask-profile.html" if filename: - from bokeh.plotting import save + from bokeh.plotting import output_file, save - save(figure, title="Dask Profile", filename=filename) + output_file(filename=filename, title="Dask Profile") + save(figure, filename=filename) return (state, figure) else: @@ -3852,7 +3853,13 @@ def collections_to_dsk(collections, *args, **kwargs): return collections_to_dsk(collections, *args, **kwargs) def get_task_stream( - self, start=None, stop=None, count=None, plot=False, filename="task-stream.html" + self, + start=None, + stop=None, + count=None, + plot=False, + filename="task-stream.html", + bokeh_resources=None, ): """ Get task stream data from scheduler @@ -3881,6 +3888,8 @@ def get_task_stream( If plot == 'save' then save the figure to a file filename: str (optional) The filename to save to if you set ``plot='save'`` + bokeh_resources: bokeh.resources.Resources (optional) + Specifies if the resource component is INLINE or CDN Examples -------- @@ -3920,10 +3929,17 @@ def get_task_stream( count=count, plot=plot, filename=filename, + bokeh_resources=bokeh_resources, ) async def _get_task_stream( - self, 
start=None, stop=None, count=None, plot=False, filename="task-stream.html" + self, + start=None, + stop=None, + count=None, + plot=False, + filename="task-stream.html", + bokeh_resources=None, ): msgs = await self.scheduler.get_task_stream(start=start, stop=stop, count=count) if plot: @@ -3935,9 +3951,10 @@ async def _get_task_stream( source, figure = task_stream_figure(sizing_mode="stretch_both") source.data.update(rects) if plot == "save": - from bokeh.plotting import save + from bokeh.plotting import save, output_file - save(figure, title="Dask Task Stream", filename=filename) + output_file(filename=filename, title="Dask Task Stream") + save(figure, filename=filename, resources=bokeh_resources) return (msgs, figure) else: return msgs From 09f86837e9b908b43991f6b4159f0083ffeb799b Mon Sep 17 00:00:00 2001 From: jakirkham Date: Thu, 2 Apr 2020 07:36:35 -0700 Subject: [PATCH 0766/1550] Get CUDA context to finalize Numba `DeviceNDArray` (#3666) * Ensure CUDA context on `DeviceNDArray` cleanup As the CUDA context does not seem to always be established when cleaning up Numba `DeviceNDArray`s, make sure to get the context right before cleanup to ensure the context is available. This is done when allocating new CUDA frames to when using UCX. Make sure this is handled both for Numba and old versions of RMM that also use Numba under-the-hood. Newer versions of RMM don't need this. Also do this when performing a host-to-device transfer where a Numba `DeviceNDArray` is the first object created on device. * Drop `PatchedCudaArrayInterface` and usage thereof This was used to ensure that Numba had acquired a CUDA context before finalizing any Numba `DeviceNDArray` objects that might be used to back a CuPy `ndarray`. As we now make sure of this when creating Numba `DeviceNDArray` by using `weakref.finalize`, this should be correctly handled for all CUDA objects backed by Numba `DeviceNDArray`s (not just CuPy `ndarray`s). So there should be no need to do this for CuPy `ndarray`s separately. Should simplify the CuPy serialization and cleanup paths a bit. --- distributed/comm/ucx.py | 17 +++++++++++++++-- distributed/protocol/cupy.py | 27 --------------------------- distributed/protocol/numba.py | 4 ++++ 3 files changed, 19 insertions(+), 29 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index a29441ec4d5..4e6ca8116c8 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -6,6 +6,7 @@ .. 
_UCX: https://github.com/openucx/ucx """ import logging +import weakref import dask import numpy as np @@ -65,12 +66,24 @@ def init_once(): if hasattr(rmm, "DeviceBuffer"): cuda_array = lambda n: rmm.DeviceBuffer(size=n) else: # pre-0.11.0 - cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) + import numba.cuda + + def rmm_cuda_array(n): + a = rmm.device_array(n, dtype=np.uint8) + weakref.finalize(a, numba.cuda.current_context) + return a + + cuda_array = rmm_cuda_array except ImportError: try: import numba.cuda - cuda_array = lambda n: numba.cuda.device_array((n,), dtype=np.uint8) + def numba_cuda_array(n): + a = numba.cuda.device_array((n,), dtype=np.uint8) + weakref.finalize(a, numba.cuda.current_context) + return a + + cuda_array = numba_cuda_array except ImportError: def cuda_array(n): diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index b3465fee424..0a2c53be4a5 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -14,31 +14,6 @@ from .numba import dask_deserialize_numba_array as dask_deserialize_cuda_buffer -class PatchedCudaArrayInterface: - """This class does one thing: - 1) Makes sure that the cuda context is active - when deallocating the base cuda array. - Notice, this is only needed when the array to deserialize - isn't a native cupy array. - """ - - def __init__(self, ary): - self.__cuda_array_interface__ = ary.__cuda_array_interface__ - # Save a ref to ary so it won't go out of scope - self.base = ary - - def __del__(self): - # Making sure that the cuda context is active - # when deallocating the base cuda array - try: - import numba.cuda - - numba.cuda.current_context() - except ImportError: - pass - del self.base - - @cuda_serialize.register(cupy.ndarray) def cuda_serialize_cupy_ndarray(x): # Making sure `x` is behaving @@ -60,8 +35,6 @@ def cuda_serialize_cupy_ndarray(x): @cuda_deserialize.register(cupy.ndarray) def cuda_deserialize_cupy_ndarray(header, frames): (frame,) = frames - if not isinstance(frame, cupy.ndarray): - frame = PatchedCudaArrayInterface(frame) arr = cupy.ndarray( shape=header["shape"], dtype=header["typestr"], diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py index 03bf4aa9f16..20eec8e11b6 100644 --- a/distributed/protocol/numba.py +++ b/distributed/protocol/numba.py @@ -1,3 +1,5 @@ +import weakref + import numba.cuda import numpy as np @@ -59,6 +61,8 @@ def dask_deserialize_numba_array(header, frames): frames = [dask_deserialize_rmm_device_buffer(header, frames)] else: frames = [numba.cuda.to_device(np.asarray(memoryview(f))) for f in frames] + for f in frames: + weakref.finalize(f, numba.cuda.current_context) arr = cuda_deserialize_numba_ndarray(header, frames) return arr From c05899e01189d4bbabeb53d46c37f5d7cdde5eb8 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 2 Apr 2020 20:28:06 +0200 Subject: [PATCH 0767/1550] More documentation for Semaphore (#3664) --- distributed/distributed.yaml | 2 +- distributed/semaphore.py | 21 +++++++++++++++++++++ docs/source/api.rst | 2 ++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index f4bfc7d76ab..ca31b17776e 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -39,7 +39,7 @@ distributed: key: null cert: null locks: - lease-validation-interval: 10s + lease-validation-interval: 10s # The time to wait until an acquired semaphore is released if the Client goes out of scope worker: blocked-handlers: [] diff --git 
a/distributed/semaphore.py b/distributed/semaphore.py index 23f253b7ede..6f5553af0e8 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -204,6 +204,26 @@ def close(self, comm=None, name=None): class Semaphore: """ Semaphore + This `semaphore `_ + will track leases on the scheduler which can be acquired and + released by an instance of this class. If the maximum amount of leases are + already acquired, it is not possible to acquire more and the caller waits + until another lease has been released. + + The lifetime of a lease is coupled to the ``Client`` it was acquired with. + Once the Client goes out of scope, the leases associated to it are freed. + This behavior can be controlled with the + ``distributed.scheduler.locks.lease-validation-interval`` configuration + option. + + A noticeable difference to the Semaphore of the python standard library is + that this implementation does not allow to release more often than it was + acquired. If this happens, a warning is emitted but the internal state is + not modified. + + This implementation is still in an experimental state and subtle changes in + behavior may occur without any change in the major version of this library. + Parameters ---------- max_leases: int (optional) @@ -222,6 +242,7 @@ class Semaphore: -------- >>> from distributed import Semaphore >>> sem = Semaphore(max_leases=2, name='my_database') + >>> >>> def access_resource(s, sem): >>> # This automatically acquires a lease from the semaphore (if available) which will be >>> # released when leaving the context manager. diff --git a/docs/source/api.rst b/docs/source/api.rst index 9d2f6c7f870..da9a76eed9b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -194,6 +194,8 @@ Other .. autoclass:: Lock :members: +.. autoclass:: Semaphore + :members: .. autoclass:: Queue :members: .. autoclass:: Variable From 46314d88ebf7e1d2fe19bd97b6cda72c5d087419 Mon Sep 17 00:00:00 2001 From: "Jonathan J. Helmus" Date: Thu, 2 Apr 2020 17:30:44 -0500 Subject: [PATCH 0768/1550] Remove openssl 1.1.1d pin for Travis (#3668) OpenSSL 1.1.1f has been released with a fix for the bug introduced in 1.1.1e. Pinning to 1.1.1d is no longer necessary. --- continuous_integration/travis/install.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 09d13962bd4..4ee0790f6c5 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -79,9 +79,6 @@ if [[ $CRICK == true ]]; then python -m pip install -q git+https://github.com/jcrist/crick.git fi; -# Pin openssl==1.1.1d (see https://github.com/dask/distributed/issues/3588) -conda install -c conda-forge openssl==1.1.1d - # Install distributed python -m pip install --no-deps -e . From 4f11509b844c3569f164704983bb0affd009dd4c Mon Sep 17 00:00:00 2001 From: jakirkham Date: Thu, 2 Apr 2020 16:36:15 -0700 Subject: [PATCH 0769/1550] Enable more UCX tests (#3667) * Re-enable cuDF serialization tests * Skip cuDF strings test that segfaults test * Run CuPy test as this appears to work * Contain `check_deserialize` import Sometimes this doesn't work on some machines. We are still working out how to get this to work more reliably. For now just skip the test if this `import` doesn't work for any reason. Also contain the `import` to this test as other tests don't need it and we don't want the tests to fail to run entirely. 
* Rerurn `black` * Allow `test_ucx_deserialize` to error Instead of trying to catch and skip certain errors, go ahead and just add a comment to this test about how it can error on some systems. Since this no longer blocks the full test suite from running, it is less of an issue. Plus seeing it may give us motivation to fix either the system or the test in the future ;) * Rerun black --- distributed/comm/tests/test_ucx.py | 34 +++++++++++------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 84da6e4f1aa..f61ea22128a 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -12,8 +12,6 @@ from dask.dataframe.utils import assert_eq from distributed.utils_test import gen_test, loop, inc, cleanup, popen # noqa: 401 -from .test_comms import check_deserialize - try: HOST = ucp.get_address() @@ -156,6 +154,11 @@ async def test_ping_pong_data(): @gen_test() def test_ucx_deserialize(): + # Note we see this error on some systems with this test: + # `socket.gaierror: [Errno -5] No address associated with hostname` + # This may be due to a system configuration issue. + from .test_comms import check_deserialize + yield check_deserialize("tcp://") @@ -169,22 +172,15 @@ def test_ucx_deserialize(): lambda cudf: cudf.DataFrame([1]).head(0), lambda cudf: cudf.DataFrame([1.0]).head(0), lambda cudf: cudf.DataFrame({"a": []}), - pytest.param( - lambda cudf: cudf.DataFrame({"a": ["a"]}).head(0), - marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), - ), - pytest.param( - lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), - marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), - ), - pytest.param( - lambda cudf: cudf.DataFrame({"a": [1]}).head(0), - marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), - ), + lambda cudf: cudf.DataFrame({"a": ["a"]}).head(0), + lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), + lambda cudf: cudf.DataFrame({"a": [1]}).head(0), lambda cudf: cudf.DataFrame({"a": [1, 2, None], "b": [1.0, 2.0, None]}), pytest.param( lambda cudf: cudf.DataFrame({"a": ["Check", "str"], "b": ["Sup", "port"]}), - marks=pytest.mark.xfail(reason="0 length objects don't deseralize cleanly"), + marks=pytest.mark.skip( + reason="This test segfaults for some reason. So skip running it entirely." + ), ), ], ) @@ -231,13 +227,7 @@ async def test_ping_pong_cupy(shape): @pytest.mark.slow @pytest.mark.asyncio @pytest.mark.parametrize( - "n", - [ - int(1e9), - pytest.param( - int(2.5e9), marks=[pytest.mark.xfail(reason="integer type in ucx-py")] - ), - ], + "n", [int(1e9), int(2.5e9),], ) async def test_large_cupy(n, cleanup): cupy = pytest.importorskip("cupy") From a5d1961a579ba934370fe166d84885f948851305 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 3 Apr 2020 15:57:13 -0500 Subject: [PATCH 0770/1550] bump version to 2.14.0 --- docs/source/changelog.rst | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 12288fc4aba..c1bcab71eb5 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,33 @@ Changelog ========= +2.14.0 - 2020-04-03 +------------------- + +- Enable more UCX tests (:pr:`3667`) `jakirkham`_ +- Remove openssl 1.1.1d pin for Travis (:pr:`3668`) `Jonathan J. 
Helmus`_ +- More documentation for ``Semaphore`` (:pr:`3664`) `Florian Jetter`_ +- Get CUDA context to finalize Numba ``DeviceNDArray`` (:pr:`3666`) `jakirkham`_ +- Add Resouces option to ``get_task_stream`` and call ``output_file`` (:pr:`3653`) `Prasun Anand`_ +- Add ``Semaphore`` extension (:pr:`3573`) `Lucas Rademaker`_ +- Replace ``ncores`` with ``nthreads`` in work stealing tests (:pr:`3615`) `James Bourbeau`_ +- Clean up some test warnings (:pr:`3662`) `Matthew Rocklin`_ +- Write "why killed" docs (:pr:`3596`) `Martin Durant`_ +- Update Python version checking (:pr:`3660`) `James Bourbeau`_ +- Add newlines to ensure code formatting for ``retire_workers`` (:pr:`3661`) `Rami Chowdhury`_ +- Clean up performance report test (:pr:`3655`) `Matthew Rocklin`_ +- Avoid diagnostics time in performance report (:pr:`3654`) `Matthew Rocklin`_ +- Introduce config for default task duration (:pr:`3642`) `Gabriel Sailer`_ +- UCX simplify receiving frames in ``comm`` (:pr:`3651`) `jakirkham`_ +- Bump checkout GitHub action to v2 (:pr:`3649`) `James Bourbeau`_ +- Handle exception in ``faulthandler`` (:pr:`3646`) `Jacob Tomlinson`_ +- Add prometheus metric for suspicious tasks (:pr:`3550`) `Gabriel Sailer`_ +- Remove ``local-directory`` keyword (:pr:`3620`) `Prasun Anand`_ +- Don't create output Futures in Client when there are mixed Client Futures (:pr:`3643`) `James Bourbeau`_ +- Add link to ``contributing.md`` (:pr:`3621`) `Prasun Anand`_ +- Update bokeh dependency in CI builds (:pr:`3637`) `James Bourbeau`_ + + 2.13.0 - 2020-03-25 ------------------- @@ -1641,3 +1668,7 @@ significantly without many new features. .. _`Matthias Urlichs`: https://github.com/smurfix .. _`Krishan Bhasin`: https://github.com/KrishanBhasin .. _`Abdulelah Bin Mahfoodh`: https://github.com/abduhbm +.. _`jakirkham`: https://github.com/jakirkham +.. _`Prasun Anand`: https://github.com/prasunanand +.. _`Jonathan J. Helmus`: https://github.com/jjhelmus +.. 
_`Rami Chowdhury`: https://github.com/necaris From 88d05134a97c8f5d26639c31bccd1af8e20a3abf Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Sat, 4 Apr 2020 12:35:51 -0500 Subject: [PATCH 0771/1550] Update Scheduler.rebalance return value when data is missing (#3670) --- distributed/client.py | 26 +++++++++++++++----------- distributed/scheduler.py | 3 ++- distributed/tests/test_client.py | 21 +++++++++++++++------ 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index d42c29c2314..fc6098d5ee1 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -587,7 +587,7 @@ def __init__( deserializers=None, extensions=DEFAULT_EXTENSIONS, direct_to_workers=None, - **kwargs + **kwargs, ): if timeout == no_default: timeout = dask.config.get("distributed.comm.timeouts.connect") @@ -960,7 +960,7 @@ async def _start(self, timeout=no_default, **kwargs): self.cluster = await LocalCluster( loop=self.loop, asynchronous=self._asynchronous, - **self._startup_kwargs + **self._startup_kwargs, ) except (OSError, socket.error) as e: if e.errno != errno.EADDRINUSE: @@ -970,7 +970,7 @@ async def _start(self, timeout=no_default, **kwargs): scheduler_port=0, loop=self.loop, asynchronous=True, - **self._startup_kwargs + **self._startup_kwargs, ) # Wait for all workers to be ready @@ -1422,7 +1422,7 @@ def submit( actor=False, actors=False, pure=None, - **kwargs + **kwargs, ): """ Submit a function application to the scheduler @@ -1542,7 +1542,7 @@ def map( actor=False, actors=False, pure=None, - **kwargs + **kwargs, ): """ Map a function on a sequence of arguments @@ -2538,7 +2538,7 @@ def get( priority=0, fifo_timeout="60s", actors=None, - **kwargs + **kwargs, ): """ Compute dask graph @@ -2669,7 +2669,7 @@ def compute( fifo_timeout="60s", actors=None, traverse=True, - **kwargs + **kwargs, ): """ Compute dask collections on cluster @@ -2817,7 +2817,7 @@ def persist( priority=0, fifo_timeout="60s", actors=None, - **kwargs + **kwargs, ): """ Persist dask collections on cluster @@ -3013,6 +3013,10 @@ async def _rebalance(self, futures=None, workers=None): await _wait(futures) keys = list({tokey(f.key) for f in self.futures_of(futures)}) result = await self.scheduler.rebalance(keys=keys, workers=workers) + if result["status"] == "missing-data": + raise ValueError( + f"During rebalance {len(result['keys'])} keys were found to be missing" + ) assert result["status"] == "OK" def rebalance(self, futures=None, workers=None, **kwargs): @@ -3023,7 +3027,7 @@ def rebalance(self, futures=None, workers=None, **kwargs): depending on keyword arguments. This operation is generally not well tested against normal operation of - the scheduler. It it not recommended to use it while waiting on + the scheduler. It is not recommended to use it while waiting on computations. 
Parameters @@ -3085,7 +3089,7 @@ def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs) n=n, workers=workers, branching_factor=branching_factor, - **kwargs + **kwargs, ) def nthreads(self, workers=None, **kwargs): @@ -3505,7 +3509,7 @@ def retire_workers(self, workers=None, close_workers=True, **kwargs): self.scheduler.retire_workers, workers=workers, close_workers=close_workers, - **kwargs + **kwargs, ) def set_metadata(self, key, value): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bc6c0ea0fd5..235cb01931e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -30,6 +30,7 @@ second, compose, groupby, + concat, ) from tornado.ioloop import IOLoop @@ -3103,7 +3104,7 @@ async def rebalance(self, comm=None, keys=None, workers=None): if not all(r["status"] == "OK" for r in result): return { "status": "missing-data", - "keys": sum([r["keys"] for r in result if "keys" in r], []), + "keys": tuple(concat(r["keys"].keys() for r in result)), } for sender, recipient, ts in msgs: diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 1ce1c1dfa6f..68507d889f0 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2876,6 +2876,15 @@ def test_rebalance_unprepared(c, s, a, b): s.validate_state() +@gen_cluster(client=True) +async def test_rebalance_raises_missing_data(c, s, a, b): + with pytest.raises(ValueError, match=f"keys were found to be missing"): + futures = await c.scatter(range(100)) + keys = [f.key for f in futures] + del futures + await c.rebalance(keys) + + @gen_cluster(client=True) def test_receive_lost_key(c, s, a, b): x = c.submit(inc, 1, workers=[a.address]) @@ -4864,8 +4873,8 @@ def test_bytes_keys(c, s, a, b): @gen_cluster(client=True) def test_unicode_ascii_keys(c, s, a, b): - uni_type = type(u"") - key = u"inc-123" + uni_type = type("") + key = "inc-123" future = c.submit(inc, 1, key=key) result = yield future assert type(future.key) is uni_type @@ -4876,8 +4885,8 @@ def test_unicode_ascii_keys(c, s, a, b): @gen_cluster(client=True) def test_unicode_keys(c, s, a, b): - uni_type = type(u"") - key = u"inc-123\u03bc" + uni_type = type("") + key = "inc-123\u03bc" future = c.submit(inc, 1, key=key) result = yield future assert type(future.key) is uni_type @@ -4889,8 +4898,8 @@ def test_unicode_keys(c, s, a, b): result2 = yield future2 assert result2 == 3 - future3 = yield c.scatter({u"data-123": 123}) - result3 = yield future3[u"data-123"] + future3 = yield c.scatter({"data-123": 123}) + result3 = yield future3["data-123"] assert result3 == 123 From beee00a3fa96c26075a024241e55eb70fc1eed29 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Sun, 5 Apr 2020 12:09:20 -0500 Subject: [PATCH 0772/1550] Add zoom tools to profile plots (#3672) --- distributed/profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/profile.py b/distributed/profile.py index 5bf071e20da..1bf81ad6ff0 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -383,7 +383,7 @@ def plot_figure(data, **kwargs): source = ColumnDataSource(data=data) - fig = figure(tools="tap", **kwargs) + fig = figure(tools="tap,box_zoom,xwheel_zoom,reset", **kwargs) r = fig.quad( "left", "right", From cc57f10e636f83d8bbbb2af83bf32a785dfb87d4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 6 Apr 2020 10:17:12 -0700 Subject: [PATCH 0773/1550] Expose Security object as public API (#3675) --- distributed/__init__.py | 1 + distributed/cli/dask_scheduler.py | 3 
+-- distributed/cli/dask_worker.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index 608e23a58e3..2ad25d05093 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -22,6 +22,7 @@ from .nanny import Nanny from .pubsub import Pub, Sub from .queues import Queue +from .security import Security from .semaphore import Semaphore from .scheduler import Scheduler from .threadpoolexecutor import rejoin diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 2394dd65dea..78d6623608f 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -11,9 +11,8 @@ from tornado.ioloop import IOLoop -from distributed import Scheduler +from distributed import Scheduler, Security from distributed.preloading import validate_preload_argv -from distributed.security import Security from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.utils import deserialize_for_cli from distributed.proctitle import ( diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 29261b52451..ff6d09b4c9a 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -11,8 +11,7 @@ import dask from dask.utils import ignoring from dask.system import CPU_COUNT -from distributed import Nanny -from distributed.security import Security +from distributed import Nanny, Security from distributed.cli.utils import check_python_3, install_signal_handlers from distributed.comm import get_address_host_port from distributed.preloading import validate_preload_argv From f2b13c935cd92af80d3d8f410c8f02053626b422 Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Mon, 6 Apr 2020 18:08:57 -0500 Subject: [PATCH 0774/1550] Use relative URL in scheduler dashboard (#3676) As done for other examples of `OpenURL` in the same file. This is important for cases where the dashboard is exposed at a non-root prefix (e.g. via the `--dashboard-prefix` option to `dask-scheduler`) --- distributed/dashboard/components/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 9519d3629ff..c376e2098e4 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1090,7 +1090,7 @@ def task_stream_figure(clear_interval="20s", **kwargs): """, ) - tap = TapTool(callback=OpenURL(url="/profile?key=@name")) + tap = TapTool(callback=OpenURL(url="./profile?key=@name")) root.add_tools( hover, From bd84d9d2ac5896e27743a1e0b07b1919cec8e0d3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 6 Apr 2020 19:01:35 -0700 Subject: [PATCH 0775/1550] Run preload at import, start, and teardown (#3673) Previously we only ran preload scripts after the server had started. This made it difficult to modify the server before certain actions had taken place. For example, a test in this PR registers a new Comm backend for the server to use. This would not have been possible before. Now we run different preload functions when 1. we first instantiate the server 2. we first start the server 3. 
we close the server (we used to run teardown at `atexit` time --- distributed/preloading.py | 53 +++++++++++++++++++++++-------- distributed/scheduler.py | 11 ++++--- distributed/tests/test_preload.py | 19 +++++++++++ distributed/worker.py | 29 ++++++++++------- 4 files changed, 83 insertions(+), 29 deletions(-) diff --git a/distributed/preloading.py b/distributed/preloading.py index 9b276b4337f..2e0469419b9 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -1,4 +1,4 @@ -import atexit +import inspect import logging import os import shutil @@ -111,38 +111,65 @@ def _import_module(name, file_dir=None): } -def preload_modules(names, parameter=None, file_dir=None, argv=None): - """ Imports modules, handles `dask_setup` and `dask_teardown`. +def on_creation(names, file_dir: str = None) -> dict: + """ Imports each of the preload modules Parameters ---------- names: list of strings Module names or file paths - parameter: object - Parameter passed to `dask_setup` and `dask_teardown` - argv: [string] - List of string arguments passed to click-configurable `dask_setup`. file_dir: string Path of a directory where files should be copied """ if isinstance(names, str): names = [names] - for name in names: - interface = _import_module(name, file_dir=file_dir) + return {name: _import_module(name, file_dir=file_dir) for name in names} + +async def on_start(modules: dict, dask_server=None, argv=None): + """ Run when the server finishes its start method + + Parameters + ---------- + modules: Dict[str, module] + The imported modules, from on_creation + dask_server: dask.distributed.Server + The Worker or Scheduler + argv: [string] + List of string arguments passed to click-configurable `dask_setup`. + file_dir: string + Path of a directory where files should be copied + """ + for name, interface in modules.items(): dask_setup = interface.get("dask_setup", None) - dask_teardown = interface.get("dask_teardown", None) if dask_setup: if isinstance(dask_setup, click.Command): context = dask_setup.make_context( "dask_setup", list(argv), allow_extra_args=False ) - dask_setup.callback(parameter, *context.args, **context.params) + dask_setup.callback(dask_server, *context.args, **context.params) else: - dask_setup(parameter) + future = dask_setup(dask_server) + if inspect.isawaitable(future): + await future logger.info("Run preload setup function: %s", name) + +async def on_teardown(modules: dict, dask_server=None): + """ Run when the server starts its close method + + Parameters + ---------- + modules: Dict[str, module] + The imported modules, from on_creation + dask_server: dask.distributed.Server + The Worker or Scheduler + """ + for name, interface in modules.items(): + dask_teardown = interface.get("dask_teardown", None) if dask_teardown: - atexit.register(interface["dask_teardown"], parameter) + future = dask_teardown(dask_server) + if inspect.isawaitable(future): + await future diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 235cb01931e..525170c5b14 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3,7 +3,7 @@ from collections.abc import Mapping, Set from datetime import timedelta from functools import partial -from inspect import isawaitable +import inspect import itertools import json import logging @@ -49,7 +49,7 @@ from . import profile from .metrics import time from .node import ServerNode -from .preloading import preload_modules +from . 
import preloading from .proctitle import setproctitle from .security import Security from .utils import ( @@ -1106,6 +1106,7 @@ def __init__( preload_argv = dask.config.get("distributed.scheduler.preload-argv") self.preload = preload self.preload_argv = preload_argv + self._preload_modules = preloading.on_creation(self.preload) self.security = security or Security() assert isinstance(self.security, Security) @@ -1463,7 +1464,7 @@ def del_scheduler_file(): weakref.finalize(self, del_scheduler_file) - preload_modules(self.preload, parameter=self, argv=self.preload_argv) + await preloading.on_start(self._preload_modules, self, argv=self.preload_argv) await asyncio.gather(*[plugin.start(self) for plugin in self.plugins]) @@ -1487,6 +1488,8 @@ async def close(self, comm=None, fast=False, close_workers=False): logger.info("Scheduler closing...") setproctitle("dask-scheduler [closing]") + await preloading.on_teardown(self._preload_modules, self) + if close_workers: await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) for worker in self.workers: @@ -3584,7 +3587,7 @@ async def feed( if teardown: teardown = pickle.loads(teardown) state = setup(self) if setup else None - if isawaitable(state): + if inspect.isawaitable(state): state = await state try: while self.status == "running": diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index d3171ed6842..888e7c42ea2 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -95,3 +95,22 @@ def check_worker(): finally: sys.path.remove(tmpdir) shutil.rmtree(tmpdir) + + +@pytest.mark.asyncio +async def test_preload_import_time(cleanup): + text = """ +from distributed.comm.registry import backends +from distributed.comm.tcp import TCPBackend + +backends["foo"] = TCPBackend() +""".strip() + try: + async with Scheduler(port=0, preload=text, protocol="foo") as s: + async with Nanny(s.address, preload=text, protocol="foo") as n: + async with Client(s.address, asynchronous=True) as c: + await c.wait_for_workers(1) + finally: + from distributed.comm.registry import backends + + del backends["foo"] diff --git a/distributed/worker.py b/distributed/worker.py index c6ae63f0ef1..ff781202393 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -34,7 +34,7 @@ from .diskutils import WorkSpace from .metrics import time from .node import ServerNode -from .preloading import preload_modules +from . 
import preloading from .proctitle import setproctitle from .protocol import pickle, to_serialize, deserialize_bytes, serialize_bytelist from .pubsub import PubSubWorkerExtension @@ -470,12 +470,7 @@ def __init__( self.total_resources = resources or {} self.available_resources = (resources or {}).copy() self.death_timeout = parse_timedelta(death_timeout) - self.preload = preload - if self.preload is None: - self.preload = dask.config.get("distributed.worker.preload") - self.preload_argv = preload_argv - if self.preload_argv is None: - self.preload_argv = dask.config.get("distributed.worker.preload-argv") + self.memory_monitor_interval = parse_timedelta( memory_monitor_interval, default="ms" ) @@ -504,6 +499,16 @@ def __init__( self._workdir = self._workspace.new_work_dir(prefix="worker-") self.local_directory = self._workdir.dir_path + self.preload = preload + if self.preload is None: + self.preload = dask.config.get("distributed.worker.preload") + self.preload_argv = preload_argv + if self.preload_argv is None: + self.preload_argv = dask.config.get("distributed.worker.preload-argv") + self._preload_modules = preloading.on_creation( + self.preload, file_dir=self.local_directory + ) + self.security = security or Security() assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("worker") @@ -1023,12 +1028,10 @@ async def start(self): if self.name is None: self.name = self.address - preload_modules( - self.preload, - parameter=self, - file_dir=self.local_directory, - argv=self.preload_argv, + await preloading.on_start( + self._preload_modules, self, argv=self.preload_argv, ) + # Services listen on all addresses # Note Nanny is not a "real" service, just some metadata # passed in service_ports... @@ -1085,6 +1088,8 @@ async def close( logger.info("Closed worker has not yet started: %s", self.status) self.status = "closing" + await preloading.on_teardown(self._preload_modules, self) + if nanny and self.nanny: with self.rpc(self.nanny) as r: await r.close_gracefully() From e395c86b7bab29c5388f888d989b3ba6d069484b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 6 Apr 2020 19:31:51 -0700 Subject: [PATCH 0776/1550] Pass through connection/listen_args as splatted keywords (#3674) Previously we would accept a specific keyword connection_args or listen_args for security information like `ssl_context`. Now we pass in these keywords directly, and let the listener handle them without intermediary. 
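As a rough sketch of the new calling convention (mirroring the updated tests in the hunks below, and assuming `Security.temporary()` is available to generate throwaway TLS credentials, which needs the optional `cryptography` dependency):

    import asyncio

    from distributed.comm import connect, listen
    from distributed.security import Security

    async def main():
        sec = Security.temporary()  # assumed helper for self-signed TLS material

        async def handle_comm(comm):
            # Greet whoever connects, then hang up
            await comm.write("hello")
            await comm.close()

        # Security options (ssl_context, require_encryption, ...) are now
        # splatted directly as keyword arguments instead of being wrapped
        # in a single connection_args=/listen_args= dict
        async with listen(
            "tls://", handle_comm, **sec.get_listen_args("scheduler")
        ) as listener:
            comm = await connect(
                listener.contact_address, **sec.get_connection_args("worker")
            )
            assert await comm.read() == "hello"
            await comm.close()

    asyncio.run(main())

Note that `rpc` and `ConnectionPool` still accept a `connection_args=` dict, as shown in the hunks below, and splat it internally when calling `connect`.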
--- distributed/cli/tests/test_dask_scheduler.py | 6 +- distributed/client.py | 4 +- distributed/comm/core.py | 12 ++- distributed/comm/tests/test_comms.py | 54 +++++-------- distributed/comm/tests/test_ucx.py | 8 +- distributed/core.py | 15 ++-- distributed/deploy/tests/test_local.py | 4 +- distributed/diagnostics/progressbar.py | 3 +- distributed/nanny.py | 5 +- distributed/scheduler.py | 3 +- distributed/tests/test_core.py | 12 ++- distributed/tests/test_security.py | 36 ++++----- distributed/tests/test_tls_functional.py | 5 ++ distributed/utils_test.py | 81 +++++++++----------- distributed/worker.py | 11 ++- 15 files changed, 109 insertions(+), 150 deletions(-) diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index cb6cc306b6c..2206c173925 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -32,7 +32,9 @@ def test_defaults(loop): @gen.coroutine def f(): # Default behaviour is to listen on all addresses - yield [assert_can_connect_from_everywhere_4_6(8786, 5.0)] # main port + yield [ + assert_can_connect_from_everywhere_4_6(8786, timeout=5.0) + ] # main port with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: c.sync(f) @@ -50,7 +52,7 @@ def test_hostport(loop): def f(): yield [ # The scheduler's main port can't be contacted from the outside - assert_can_connect_locally_4(8978, 5.0) + assert_can_connect_locally_4(8978, timeout=5.0) ] with Client("127.0.0.1:8978", loop=loop) as c: diff --git a/distributed/client.py b/distributed/client.py index fc6098d5ee1..73cccc8a18a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1044,9 +1044,7 @@ async def _ensure_connected(self, timeout=None): try: comm = await connect( - self.scheduler.address, - timeout=timeout, - connection_args=self.connection_args, + self.scheduler.address, timeout=timeout, **self.connection_args, ) comm.name = "Client->Scheduler" if timeout is not None: diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 6ef26568853..bfae9e8dcc0 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -184,7 +184,7 @@ def connect(self, address, deserialize=True): """ -async def connect(addr, timeout=None, deserialize=True, connection_args=None): +async def connect(addr, timeout=None, deserialize=True, **connection_args): """ Connect to the given address (a URI such as ``tcp://127.0.0.1:1234``) and yield a ``Comm`` object. If the connection attempt fails, it is @@ -221,7 +221,7 @@ def _raise(error): try: while deadline - time() > 0: future = connector.connect( - loc, deserialize=deserialize, **(connection_args or {}) + loc, deserialize=deserialize, **connection_args ) with ignoring(TimeoutError): comm = await asyncio.wait_for( @@ -247,7 +247,7 @@ def _raise(error): return comm -def listen(addr, handle_comm, deserialize=True, connection_args=None): +def listen(addr, handle_comm, deserialize=True, **kwargs): """ Create a listener object with the given parameters. 
When its ``start()`` method is called, the listener will listen on the given address @@ -259,7 +259,7 @@ def listen(addr, handle_comm, deserialize=True, connection_args=None): try: scheme, loc = parse_address(addr, strict=True) except ValueError: - if connection_args and connection_args.get("ssl_context"): + if kwargs.get("ssl_context"): addr = "tls://" + addr else: addr = "tcp://" + addr @@ -267,6 +267,4 @@ def listen(addr, handle_comm, deserialize=True, connection_args=None): backend = registry.get_backend(scheme) - return backend.get_listener( - loc, handle_comm, deserialize, **(connection_args or {}) - ) + return backend.get_listener(loc, handle_comm, deserialize, **kwargs) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 2e5602a9ac5..035a95513fb 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -77,19 +77,15 @@ def check_tls_extra(info): @pytest.mark.asyncio -async def get_comm_pair(listen_addr, listen_args=None, connect_args=None, **kwargs): +async def get_comm_pair(listen_addr, listen_args={}, connect_args={}, **kwargs): q = asyncio.Queue() async def handle_comm(comm): await q.put(comm) - listener = await listen( - listen_addr, handle_comm, connection_args=listen_args, **kwargs - ) + listener = await listen(listen_addr, handle_comm, **listen_args, **kwargs) - comm = await connect( - listener.contact_address, connection_args=connect_args, **kwargs - ) + comm = await connect(listener.contact_address, **connect_args, **kwargs) serv_comm = await q.get() return (comm, serv_comm) @@ -332,9 +328,7 @@ async def sleep_for_60ms(): sleep_future = sleep_for_60ms() with pytest.raises(IOError): await connect( - "tls://localhost:28400", - 0.052, - connection_args={"ssl_context": get_client_ssl_context()}, + "tls://localhost:28400", 0.052, ssl_context=get_client_ssl_context(), ) max_thread_count = await sleep_future assert max_thread_count <= 2 + original_thread_count @@ -441,8 +435,8 @@ async def check_client_server( addr, check_listen_addr=None, check_contact_addr=None, - listen_args=None, - connect_args=None, + listen_args={}, + connect_args={}, ): """ Abstract client / server test. @@ -466,7 +460,7 @@ async def handle_comm(comm): listen_args = listen_args or {"xxx": "bar"} connect_args = connect_args or {"xxx": "foo"} - listener = await listen(addr, handle_comm, connection_args=listen_args) + listener = await listen(addr, handle_comm, **listen_args) # Check listener properties bound_addr = listener.listen_address @@ -490,7 +484,7 @@ async def handle_comm(comm): l = [] async def client_communicate(key, delay=0): - comm = await connect(listener.contact_address, connection_args=connect_args) + comm = await connect(listener.contact_address, **connect_args) assert comm.peer_address == listener.contact_address await comm.write({"op": "ping", "data": key}) @@ -644,15 +638,11 @@ async def handle_comm(comm): await comm.close() # Listener refuses a connector not signed by the CA - listener = await listen( - "tls://", handle_comm, connection_args={"ssl_context": serv_ctx} - ) + listener = await listen("tls://", handle_comm, ssl_context=serv_ctx) with pytest.raises(EnvironmentError) as excinfo: comm = await connect( - listener.contact_address, - timeout=0.5, - connection_args={"ssl_context": bad_cli_ctx}, + listener.contact_address, timeout=0.5, ssl_context=bad_cli_ctx, ) await comm.write({"x": "foo"}) # TODO: why is this necessary in Tornado 6 ? 
@@ -670,21 +660,15 @@ async def handle_comm(comm): raise # Sanity check - comm = await connect( - listener.contact_address, timeout=2, connection_args={"ssl_context": cli_ctx} - ) + comm = await connect(listener.contact_address, timeout=2, ssl_context=cli_ctx,) await comm.close() # Connector refuses a listener not signed by the CA - listener = await listen( - "tls://", handle_comm, connection_args={"ssl_context": bad_serv_ctx} - ) + listener = await listen("tls://", handle_comm, ssl_context=bad_serv_ctx) with pytest.raises(EnvironmentError) as excinfo: await connect( - listener.contact_address, - timeout=2, - connection_args={"ssl_context": cli_ctx}, + listener.contact_address, timeout=2, ssl_context=cli_ctx, ) # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028 if sys.version_info >= (3,): @@ -696,20 +680,18 @@ async def handle_comm(comm): # -async def check_comm_closed_implicit( - addr, delay=None, listen_args=None, connect_args=None -): +async def check_comm_closed_implicit(addr, delay=None, listen_args={}, connect_args={}): async def handle_comm(comm): await comm.close() - listener = await listen(addr, handle_comm, connection_args=listen_args) + listener = await listen(addr, handle_comm, **listen_args) contact_addr = listener.contact_address - comm = await connect(contact_addr, connection_args=connect_args) + comm = await connect(contact_addr, **connect_args) with pytest.raises(CommClosedError): await comm.write({}) - comm = await connect(contact_addr, connection_args=connect_args) + comm = await connect(contact_addr, **connect_args) with pytest.raises(CommClosedError): await comm.read() @@ -729,7 +711,7 @@ async def test_inproc_comm_closed_implicit(): await check_comm_closed_implicit(inproc.new_address()) -async def check_comm_closed_explicit(addr, listen_args=None, connect_args=None): +async def check_comm_closed_explicit(addr, listen_args={}, connect_args={}): a, b = await get_comm_pair(addr, listen_args=listen_args, connect_args=connect_args) a_read = a.read() b_read = b.read() diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index f61ea22128a..9ac97deeb7e 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -26,18 +26,16 @@ def test_registered(): async def get_comm_pair( - listen_addr="ucx://" + HOST, listen_args=None, connect_args=None, **kwargs + listen_addr="ucx://" + HOST, listen_args={}, connect_args={}, **kwargs ): q = asyncio.queues.Queue() async def handle_comm(comm): await q.put(comm) - listener = listen(listen_addr, handle_comm, connection_args=listen_args, **kwargs) + listener = listen(listen_addr, handle_comm, **listen_args, **kwargs) async with listener: - comm = await connect( - listener.contact_address, connection_args=connect_args, **kwargs - ) + comm = await connect(listener.contact_address, **connect_args, **kwargs) serv_comm = await q.get() return (comm, serv_comm) diff --git a/distributed/core.py b/distributed/core.py index 1bf3b172b68..dd5e18d0007 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -302,7 +302,7 @@ def port(self): def identity(self, comm=None): return {"type": type(self).__name__, "id": self.id} - async def listen(self, port_or_addr=None, listen_args=None): + async def listen(self, port_or_addr=None, **kwargs): if port_or_addr is None: port_or_addr = self.default_port if isinstance(port_or_addr, int): @@ -313,10 +313,7 @@ async def listen(self, port_or_addr=None, listen_args=None): addr = port_or_addr assert 
isinstance(addr, str) listener = await listen( - addr, - self.handle_comm, - deserialize=self.deserialize, - connection_args=listen_args, + addr, self.handle_comm, deserialize=self.deserialize, **kwargs, ) self.listeners.append(listener) @@ -606,7 +603,7 @@ def __init__( self.deserialize = deserialize self.serializers = serializers self.deserializers = deserializers if deserializers is not None else serializers - self.connection_args = connection_args + self.connection_args = connection_args or {} self._created = weakref.WeakSet() rpc.active.add(self) @@ -644,7 +641,7 @@ async def live_comm(self): self.address, self.timeout, deserialize=self.deserialize, - connection_args=self.connection_args, + **self.connection_args, ) comm.name = "rpc" self.comms[comm] = False # mark as taken @@ -832,7 +829,7 @@ def __init__( self.deserialize = deserialize self.serializers = serializers self.deserializers = deserializers if deserializers is not None else serializers - self.connection_args = connection_args + self.connection_args = connection_args or {} self.timeout = timeout self._n_connecting = 0 self.server = weakref.ref(server) if server else None @@ -905,7 +902,7 @@ async def connect(self, addr, timeout=None): addr, timeout=timeout or self.timeout, deserialize=self.deserialize, - connection_args=self.connection_args, + **self.connection_args, ) comm.name = "ConnectionPool" comm._pool = weakref.ref(self) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 8ca780a4eb2..403beb3aa41 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -734,9 +734,9 @@ def test_local_tls(loop, temporary): loop, assert_can_connect_from_everywhere_4, c.scheduler.port, - connection_args=c.security.get_connection_args("client"), protocol="tls", timeout=3, + **c.security.get_connection_args("client"), ) # If we connect to a TLS localculster without ssl information we should fail @@ -744,8 +744,8 @@ def test_local_tls(loop, temporary): loop, assert_cannot_connect, addr="tcp://127.0.0.1:%d" % c.scheduler.port, - connection_args=c.security.get_connection_args("client"), exception_class=RuntimeError, + **c.security.get_connection_args("client"), ) diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 11da7a30d3d..4ef7254f52e 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -63,8 +63,7 @@ def function(scheduler, p): return result self.comm = await connect( - self.scheduler, - connection_args=self.client().connection_args if self.client else None, + self.scheduler, **(self.client().connection_args if self.client else {}), ) logger.debug("Progressbar Connected to scheduler") diff --git a/distributed/nanny.py b/distributed/nanny.py index baa77e3ce10..d3a4d2dc82b 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -95,7 +95,6 @@ def __init__( self.security = security or Security() assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("worker") - self.listen_args = self.security.get_listen_args("worker") if scheduler_file: cfg = json_load_robust(scheduler_file) @@ -244,7 +243,9 @@ async def start(self): await super().start() - await self.listen(self._start_address, listen_args=self.listen_args) + await self.listen( + self._start_address, **self.security.get_listen_args("worker") + ) self.ip = get_address_host(self.address) logger.info(" Start Nanny at: %r", self.address) diff --git 
a/distributed/scheduler.py b/distributed/scheduler.py index 525170c5b14..55692b875ab 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1111,7 +1111,6 @@ def __init__( self.security = security or Security() assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("scheduler") - self.listen_args = self.security.get_listen_args("scheduler") if dashboard_address is not None: try: @@ -1431,7 +1430,7 @@ async def start(self): if self.status != "running": for addr in self._start_address: - await self.listen(addr, listen_args=self.listen_args) + await self.listen(addr, **self.security.get_listen_args("scheduler")) self.ip = get_address_host(self.listen_address) listen_ip = self.ip diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index b4993bb4ce8..49033a6a11e 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -230,12 +230,10 @@ async def listen_on(cls, *args, **kwargs): sec = tls_security() async with listen_on( - Server, "tls://", listen_args=sec.get_listen_args("scheduler") + Server, "tls://", **sec.get_listen_args("scheduler") ) as server: assert server.address.startswith("tls://") - await assert_can_connect( - server.address, connection_args=sec.get_connection_args("client") - ) + await assert_can_connect(server.address, **sec.get_connection_args("client")) # InProc @@ -253,9 +251,9 @@ async def listen_on(cls, *args, **kwargs): await assert_cannot_connect(inproc_addr2) -async def check_rpc(listen_addr, rpc_addr=None, listen_args=None, connection_args=None): +async def check_rpc(listen_addr, rpc_addr=None, listen_args={}, connection_args={}): server = Server({"ping": pingpong}) - await server.listen(listen_addr, listen_args=listen_args) + await server.listen(listen_addr, **listen_args) if rpc_addr is None: rpc_addr = server.address @@ -603,7 +601,7 @@ async def ping(comm, delay=0.01): servers = [Server({"ping": ping}) for i in range(10)] for server in servers: - await server.listen("tls://", listen_args=listen_args) + await server.listen("tls://", **listen_args) rpc = await ConnectionPool(limit=5, connection_args=connection_args) diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 002e63d2855..8665ebead33 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -281,10 +281,10 @@ async def handle_comm(comm): forced_cipher_sec = Security() async with listen( - "tls://", handle_comm, connection_args=sec.get_listen_args("scheduler") + "tls://", handle_comm, **sec.get_listen_args("scheduler") ) as listener: comm = await connect( - listener.contact_address, connection_args=sec.get_connection_args("worker") + listener.contact_address, **sec.get_connection_args("worker") ) msg = await comm.read() assert msg == "hello" @@ -293,14 +293,12 @@ async def handle_comm(comm): # No SSL context for client with pytest.raises(TypeError): await connect( - listener.contact_address, - connection_args=sec.get_connection_args("client"), + listener.contact_address, **sec.get_connection_args("client"), ) # Check forced cipher comm = await connect( - listener.contact_address, - connection_args=forced_cipher_sec.get_connection_args("worker"), + listener.contact_address, **forced_cipher_sec.get_connection_args("worker"), ) cipher, _, _ = comm.extra_info["cipher"] assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS @@ -331,20 +329,18 @@ async def handle_comm(comm): for listen_addr in ["inproc://", "tls://"]: async with listen( - 
listen_addr, handle_comm, connection_args=sec.get_listen_args("scheduler") + listen_addr, handle_comm, **sec.get_listen_args("scheduler") ) as listener: comm = await connect( - listener.contact_address, - connection_args=sec2.get_connection_args("worker"), + listener.contact_address, **sec2.get_connection_args("worker"), ) comm.abort() async with listen( - listen_addr, handle_comm, connection_args=sec2.get_listen_args("scheduler") + listen_addr, handle_comm, **sec2.get_listen_args("scheduler") ) as listener: comm = await connect( - listener.contact_address, - connection_args=sec2.get_connection_args("worker"), + listener.contact_address, **sec2.get_connection_args("worker"), ) comm.abort() @@ -356,25 +352,21 @@ def check_encryption_error(): for listen_addr in ["tcp://"]: async with listen( - listen_addr, handle_comm, connection_args=sec.get_listen_args("scheduler") + listen_addr, handle_comm, **sec.get_listen_args("scheduler") ) as listener: comm = await connect( - listener.contact_address, - connection_args=sec.get_connection_args("worker"), + listener.contact_address, **sec.get_connection_args("worker"), ) comm.abort() with pytest.raises(RuntimeError): await connect( - listener.contact_address, - connection_args=sec2.get_connection_args("worker"), + listener.contact_address, **sec2.get_connection_args("worker"), ) with pytest.raises(RuntimeError): listen( - listen_addr, - handle_comm, - connection_args=sec2.get_listen_args("scheduler"), + listen_addr, handle_comm, **sec2.get_listen_args("scheduler"), ) @@ -408,10 +400,10 @@ async def handle_comm(comm): sec = Security.temporary() async with listen( - "tls://", handle_comm, connection_args=sec.get_listen_args("scheduler") + "tls://", handle_comm, **sec.get_listen_args("scheduler") ) as listener: comm = await connect( - listener.contact_address, connection_args=sec.get_connection_args("worker") + listener.contact_address, **sec.get_connection_args("worker") ) msg = await comm.read() assert msg == "hello" diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 6d0e64b54e5..3a2bebf790d 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -11,6 +11,11 @@ from distributed.utils_test import gen_tls_cluster, inc, double, slowinc, slowadd +@gen_tls_cluster(client=True) +def test_basic(c, s, a, b): + pass + + @gen_tls_cluster(client=True) def test_Queue(c, s, a, b): assert s.address.startswith("tls://") diff --git a/distributed/utils_test.py b/distributed/utils_test.py index b521826647a..983eaac48f5 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -798,7 +798,7 @@ async def start_cluster( security=security, port=0, host=scheduler_addr, - **scheduler_kwargs + **scheduler_kwargs, ) workers = [ Worker( @@ -809,7 +809,7 @@ async def start_cluster( loop=loop, validate=True, host=ncore[0], - **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs) + **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs), ) for i, ncore in enumerate(nthreads) ] @@ -917,7 +917,7 @@ async def coro(): loop=loop, security=security, asynchronous=True, - **client_kwargs + **client_kwargs, ) args = [c] + args try: @@ -1108,115 +1108,106 @@ def requires_ipv6(test_func): requires_ipv6 = pytest.mark.skip("ipv6 required") -async def assert_can_connect(addr, timeout=None, connection_args=None): +async def assert_can_connect(addr, timeout=0.5, **kwargs): """ Check that it is possible to connect to the distributed *addr* within the 
given *timeout*. """ - if timeout is None: - timeout = 0.5 - comm = await connect(addr, timeout=timeout, connection_args=connection_args) + comm = await connect(addr, timeout=timeout, **kwargs) comm.abort() async def assert_cannot_connect( - addr, timeout=None, connection_args=None, exception_class=EnvironmentError + addr, timeout=0.5, exception_class=EnvironmentError, **kwargs ): """ Check that it is impossible to connect to the distributed *addr* within the given *timeout*. """ - if timeout is None: - timeout = 0.5 with pytest.raises(exception_class): - comm = await connect(addr, timeout=timeout, connection_args=connection_args) + comm = await connect(addr, timeout=timeout, **kwargs) comm.abort() -async def assert_can_connect_from_everywhere_4_6( - port, timeout=None, connection_args=None, protocol="tcp" -): +async def assert_can_connect_from_everywhere_4_6(port, protocol="tcp", **kwargs): """ Check that the local *port* is reachable from all IPv4 and IPv6 addresses. """ - args = (timeout, connection_args) futures = [ - assert_can_connect("%s://127.0.0.1:%d" % (protocol, port), *args), - assert_can_connect("%s://%s:%d" % (protocol, get_ip(), port), *args), + assert_can_connect("%s://127.0.0.1:%d" % (protocol, port), **kwargs), + assert_can_connect("%s://%s:%d" % (protocol, get_ip(), port), **kwargs), ] if has_ipv6(): futures += [ - assert_can_connect("%s://[::1]:%d" % (protocol, port), *args), - assert_can_connect("%s://[%s]:%d" % (protocol, get_ipv6(), port), *args), + assert_can_connect("%s://[::1]:%d" % (protocol, port), **kwargs), + assert_can_connect("%s://[%s]:%d" % (protocol, get_ipv6(), port), **kwargs), ] await asyncio.gather(*futures) async def assert_can_connect_from_everywhere_4( - port, timeout=None, connection_args=None, protocol="tcp" + port, protocol="tcp", **kwargs, ): """ Check that the local *port* is reachable from all IPv4 addresses. """ - args = (timeout, connection_args) futures = [ - assert_can_connect("%s://127.0.0.1:%d" % (protocol, port), *args), - assert_can_connect("%s://%s:%d" % (protocol, get_ip(), port), *args), + assert_can_connect("%s://127.0.0.1:%d" % (protocol, port), **kwargs), + assert_can_connect("%s://%s:%d" % (protocol, get_ip(), port), **kwargs), ] if has_ipv6(): futures += [ - assert_cannot_connect("%s://[::1]:%d" % (protocol, port), *args), - assert_cannot_connect("%s://[%s]:%d" % (protocol, get_ipv6(), port), *args), + assert_cannot_connect("%s://[::1]:%d" % (protocol, port), **kwargs), + assert_cannot_connect( + "%s://[%s]:%d" % (protocol, get_ipv6(), port), **kwargs + ), ] await asyncio.gather(*futures) -async def assert_can_connect_locally_4(port, timeout=None, connection_args=None): +async def assert_can_connect_locally_4(port, **kwargs): """ Check that the local *port* is only reachable from local IPv4 addresses. """ - args = (timeout, connection_args) - futures = [assert_can_connect("tcp://127.0.0.1:%d" % port, *args)] + futures = [assert_can_connect("tcp://127.0.0.1:%d" % port, **kwargs)] if get_ip() != "127.0.0.1": # No outside IPv4 connectivity? 
- futures += [assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), *args)] + futures += [assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), **kwargs)] if has_ipv6(): futures += [ - assert_cannot_connect("tcp://[::1]:%d" % port, *args), - assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args), + assert_cannot_connect("tcp://[::1]:%d" % port, **kwargs), + assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), **kwargs), ] await asyncio.gather(*futures) -async def assert_can_connect_from_everywhere_6( - port, timeout=None, connection_args=None -): +async def assert_can_connect_from_everywhere_6(port, **kwargs): """ Check that the local *port* is reachable from all IPv6 addresses. """ assert has_ipv6() - args = (timeout, connection_args) futures = [ - assert_cannot_connect("tcp://127.0.0.1:%d" % port, *args), - assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), *args), - assert_can_connect("tcp://[::1]:%d" % port, *args), - assert_can_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args), + assert_cannot_connect("tcp://127.0.0.1:%d" % port, **kwargs), + assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), **kwargs), + assert_can_connect("tcp://[::1]:%d" % port, **kwargs), + assert_can_connect("tcp://[%s]:%d" % (get_ipv6(), port), **kwargs), ] await asyncio.gather(*futures) -async def assert_can_connect_locally_6(port, timeout=None, connection_args=None): +async def assert_can_connect_locally_6(port, **kwargs): """ Check that the local *port* is only reachable from local IPv6 addresses. """ assert has_ipv6() - args = (timeout, connection_args) futures = [ - assert_cannot_connect("tcp://127.0.0.1:%d" % port, *args), - assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), *args), - assert_can_connect("tcp://[::1]:%d" % port, *args), + assert_cannot_connect("tcp://127.0.0.1:%d" % port, **kwargs), + assert_cannot_connect("tcp://%s:%d" % (get_ip(), port), **kwargs), + assert_can_connect("tcp://[::1]:%d" % port, **kwargs), ] if get_ipv6() != "::1": # No outside IPv6 connectivity? 
- futures += [assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), *args)] + futures += [ + assert_cannot_connect("tcp://[%s]:%d" % (get_ipv6(), port), **kwargs) + ] await asyncio.gather(*futures) diff --git a/distributed/worker.py b/distributed/worker.py index ff781202393..0c6c8b46ba5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -512,7 +512,6 @@ def __init__( self.security = security or Security() assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("worker") - self.listen_args = self.security.get_listen_args("worker") self.memory_limit = parse_memory_limit(memory_limit, self.nthreads) @@ -816,9 +815,7 @@ async def _register_with_scheduler(self): while True: try: _start = time() - comm = await connect( - self.scheduler.address, connection_args=self.connection_args - ) + comm = await connect(self.scheduler.address, **self.connection_args) comm.name = "Worker->Scheduler" comm._server = weakref.ref(self) await comm.write( @@ -1022,7 +1019,9 @@ async def start(self): enable_gc_diagnosis() thread_state.on_event_loop_thread = True - await self.listen(self._start_address, listen_args=self.listen_args) + await self.listen( + self._start_address, **self.security.get_listen_args("worker") + ) self.ip = get_address_host(self.address) if self.name is None: @@ -1185,7 +1184,7 @@ def send_to_worker(self, address, msg): async def batched_send_connect(): comm = await connect( - address, connection_args=self.connection_args # TODO, serialization + address, **self.connection_args # TODO, serialization ) comm.name = "Worker->Worker" await comm.write({"op": "connection_stream"}) From 4055fd7f94c49e15583753cb76befbabbbc37f5c Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 7 Apr 2020 11:15:17 -0400 Subject: [PATCH 0777/1550] don't make task graphs too big (#3671) Set a maximum size for which task graphs are displayed in the dashboard --- distributed/dashboard/components/scheduler.py | 6 ++++ .../dashboard/tests/test_scheduler_bokeh.py | 30 +++++++++++++++++++ distributed/distributed.yaml | 1 + 3 files changed, 37 insertions(+) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index c376e2098e4..6b7a77d9479 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -34,6 +34,7 @@ from bokeh.transform import factor_cmap, linear_cmap from bokeh.io import curdoc import dask +from dask import config from dask.utils import format_bytes, key_split from tlz import pipe from tlz.curried import map, concat, groupby @@ -1171,6 +1172,7 @@ def __init__(self, scheduler, **kwargs): tap = TapTool(callback=OpenURL(url="info/task/@key.html"), renderers=[rect]) rect.nonselection_glyph = None self.root.add_tools(hover, tap) + self.max_items = config.get("distributed.dashboard.graph-max-items", 5000) @without_property_validation def update(self): @@ -1206,6 +1208,10 @@ def add_new_nodes_edges(self, new, new_edges, update=False): y = self.layout.y tasks = self.scheduler.tasks + if len(tasks) > self.max_items: + # graph to big - no update, reset for next time + self.invisible_count = len(tasks) + return for key in new: try: task = tasks[key] diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index f36bfd897e1..65b5fa25d50 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -36,6 +36,7 @@ ProfileServer, 
MemoryByKey, ) +from distributed.utils_test import async_wait_for from distributed.dashboard import scheduler @@ -501,6 +502,35 @@ def test_TaskGraph_clear(c, s, a, b): assert time() < start + 5 +@gen_cluster( + client=True, config={"distributed.dashboard.graph-max-items": 2,}, +) +def test_TaskGraph_limit(c, s, a, b): + gp = TaskGraph(s) + + def func(x): + return x + + f1 = c.submit(func, 1) + yield wait(f1) + gp.update() + assert len(gp.node_source.data["x"]) == 1 + f2 = c.submit(func, 2) + yield wait(f2) + gp.update() + assert len(gp.node_source.data["x"]) == 2 + f3 = c.submit(func, 3) + yield wait(f3) + gp.update() + assert len(gp.node_source.data["x"]) == 2 + del f1 + del f2 + del f3 + _ = c.submit(func, 1) + + async_wait_for(lambda: len(gp.node_source.data["x"]) == 1, timeout=1) + + @gen_cluster(client=True, timeout=30) def test_TaskGraph_complex(c, s, a, b): da = pytest.importorskip("dask.array") diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index ca31b17776e..97ca6be3945 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -129,6 +129,7 @@ distributed: dashboard: link: "{scheme}://{host}:{port}/status" export-tool: False + graph-max-items: 5000 # maximum number of tasks to try to plot in graph view ################## # Administrative # From ffeaa97500d1ac0ea0bd0d38c95ca3a6f8c7e1dd Mon Sep 17 00:00:00 2001 From: Lucas Rademaker <44430780+lr4d@users.noreply.github.com> Date: Wed, 8 Apr 2020 11:19:30 +0200 Subject: [PATCH 0778/1550] Refactor semaphore internals: make `_get_lease` synchronous (#3679) Co-authored-by: lr4d --- distributed/semaphore.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 6f5553af0e8..2d506c8ed0e 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -83,11 +83,9 @@ def create(self, comm=None, name=None, max_leases=None): % (max_leases, self.max_leases[name]) ) - async def _get_lease(self, client, name, identifier): + def _get_lease(self, client, name, identifier): result = True if len(self.leases[name]) < self.max_leases[name]: - # naive: self.leases[resource] += 1 - # not naive: self.leases[name].append(identifier) self.leases_per_client[client][name].append(identifier) else: @@ -116,15 +114,7 @@ async def acquire( # is changed and helps to identify when it is worth to retry an acquire self.events[name].clear() - # If we hit the timeout, this cancels the _get_lease - future = asyncio.wait_for( - self._get_lease(client, name, identifier), timeout=w.leftover() - ) - - try: - result = await future - except TimeoutError: - result = False + result = self._get_lease(client, name, identifier) # If acquiring fails, we wait for the event to be set, i.e. 
something has # been released and we can try to acquire again (continue loop) From 2aa3ee7be1b11e4326d679d9d3c7bf906ea0623d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 8 Apr 2020 12:08:30 -0700 Subject: [PATCH 0779/1550] Support preload modules in Nanny (#3678) Adds support for running preload scripts on the Nanny --- distributed/cli/dask_worker.py | 10 ++++++++ distributed/cli/tests/test_dask_scheduler.py | 25 ++++++++++++++++++++ distributed/distributed.yaml | 12 ++++++---- distributed/nanny.py | 24 +++++++++++++++++++ distributed/tests/test_preload.py | 9 ++++++- distributed/worker.py | 2 +- 6 files changed, 76 insertions(+), 6 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index ff6d09b4c9a..efc330a4a24 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -216,6 +216,14 @@ @click.argument( "preload_argv", nargs=-1, type=click.UNPROCESSED, callback=validate_preload_argv ) +@click.option( + "--preload-nanny", + type=str, + multiple=True, + is_eager=True, + help="Module that should be loaded by each nanny " + 'like "foo.bar" or "/path/to/foo.py"', +) @click.version_option() def main( scheduler, @@ -240,6 +248,7 @@ def main( tls_key, dashboard_address, worker_class, + preload_nanny, **kwargs ): g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 @@ -349,6 +358,7 @@ def del_pid_file(): worker_class = import_term(worker_class) if nanny: kwargs["worker_class"] = worker_class + kwargs["preload_nanny"] = preload_nanny if nanny: kwargs.update({"worker_port": worker_port, "listen_address": listen_address}) diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 2206c173925..62c79f8c0b8 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -394,3 +394,28 @@ def test_idle_timeout(loop): ) stop = time() assert 1 < stop - start < 10 + + +def test_multiple_workers(loop): + text = """ +def dask_setup(worker): + worker.foo = 'setup' +""" + with popen(["dask-scheduler", "--no-dashboard"]) as s: + with popen( + [ + "dask-worker", + "localhost:8786", + "--no-dashboard", + "--preload", + text, + "--preload-nanny", + text, + ] + ) as a: + with Client("127.0.0.1:8786", loop=loop) as c: + c.wait_for_workers(1) + [foo] = c.run(lambda dask_worker: dask_worker.foo).values() + assert foo == "setup" + [foo] = c.run(lambda dask_worker: dask_worker.foo, nanny=True).values() + assert foo == "setup" diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 97ca6be3945..bd1dd85c57e 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -22,8 +22,8 @@ distributed: work-stealing-interval: 100ms # Callback time for work stealing worker-ttl: null # like '60s'. Time to live for workers. 
They must heartbeat faster than this pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings - preload: [] - preload-argv: [] + preload: [] # Run custom modules with Scheduler + preload-argv: [] # See https://docs.dask.org/en/latest/setup/custom-startup.html unknown-task-duration: 500ms # Default duration for all tasks with unknown durations ("15m", "2h") default-task-durations: # How long we expect function names to run ("1h", "1s") (helps for long tasks) rechunk-split: 1us @@ -48,8 +48,8 @@ distributed: connections: # Maximum concurrent connections for data outgoing: 50 # This helps to control network saturation incoming: 10 - preload: [] - preload-argv: [] + preload: [] # Run custom modules with Worker + preload-argv: [] # See https://docs.dask.org/en/latest/setup/custom-startup.html daemon: True validate: False # Check worker state at every step for debugging lifetime: @@ -71,6 +71,10 @@ distributed: pause: 0.80 # fraction at which we pause worker threads terminate: 0.95 # fraction at which we terminate the worker + nanny: + preload: [] # Run custom modules with Nanny + preload-argv: [] # See https://docs.dask.org/en/latest/setup/custom-startup.html + client: heartbeat: 5s # time between client heartbeats diff --git a/distributed/nanny.py b/distributed/nanny.py index d3a4d2dc82b..3f7c20f98f9 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -19,6 +19,7 @@ from .core import RPCClosed, CommClosedError, coerce_to_address from .metrics import time from .node import ServerNode +from . import preloading from .process import AsyncProcess from .proctitle import enable_proctitle_on_children from .security import Security @@ -78,6 +79,8 @@ def __init__( death_timeout=None, preload=None, preload_argv=None, + preload_nanny=None, + preload_nanny_argv=None, security=None, contact_address=None, listen_address=None, @@ -121,12 +124,21 @@ def __init__( self.validate = validate self.resources = resources self.death_timeout = parse_timedelta(death_timeout) + self.preload = preload if self.preload is None: self.preload = dask.config.get("distributed.worker.preload") self.preload_argv = preload_argv if self.preload_argv is None: self.preload_argv = dask.config.get("distributed.worker.preload-argv") + + self.preload_nanny = preload_nanny + if self.preload_nanny is None: + self.preload_nanny = dask.config.get("distributed.nanny.preload") + self.preload_nanny_argv = preload_nanny_argv + if self.preload_nanny_argv is None: + self.preload_nanny_argv = dask.config.get("distributed.nanny.preload-argv") + self.Worker = Worker if worker_class is None else worker_class self.env = env or {} self.config = config or {} @@ -157,6 +169,10 @@ def __init__( self.local_directory = local_directory + self._preload_modules = preloading.on_creation( + self.preload_nanny, file_dir=self.local_directory + ) + self.services = services self.name = name self.quiet = quiet @@ -248,6 +264,10 @@ async def start(self): ) self.ip = get_address_host(self.address) + await preloading.on_start( + self._preload_modules, self, argv=self.preload_nanny_argv, + ) + logger.info(" Start Nanny at: %r", self.address) response = await self.instantiate() if response == "running": @@ -445,6 +465,9 @@ async def close(self, comm=None, timeout=5, report=None): self.status = "closing" logger.info("Closing Nanny at %r", self.address) + + await preloading.on_teardown(self._preload_modules, self) + self.stop() try: if self.process is not None: @@ -519,6 +542,7 @@ async def start(self): self.running = asyncio.Event() 
self.stopped = asyncio.Event() self.status = "starting" + try: await self.process.start() except OSError: diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index 888e7c42ea2..4f60ca586f9 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -62,13 +62,20 @@ async def test_worker_preload_config(cleanup): text = """ def dask_setup(worker): worker.foo = 'setup' + +def dask_teardown(worker): + worker.foo = 'teardown' """ - with dask.config.set({"distributed.worker.preload": text}): + with dask.config.set( + {"distributed.worker.preload": text, "distributed.nanny.preload": text,} + ): async with Scheduler(port=0) as s: async with Nanny(s.address) as w: + assert w.foo == "setup" async with Client(s.address, asynchronous=True) as c: d = await c.run(lambda dask_worker: dask_worker.foo) assert d == {w.worker_address: "setup"} + assert w.foo == "teardown" def test_worker_preload_module(loop): diff --git a/distributed/worker.py b/distributed/worker.py index 0c6c8b46ba5..f2f832f6d83 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -452,7 +452,7 @@ def __init__( # Target interface on which we contact the scheduler by default # TODO: it is unfortunate that we special-case inproc here if not host and not interface and not scheduler_addr.startswith("inproc://"): - host = get_ip(get_address_host(scheduler_addr)) + host = get_ip(get_address_host(scheduler_addr.split("://")[-1])) self._start_address = address_from_user_args( host=host, From 8116a266109cab5bcdb8d950f72235e9e050a7d5 Mon Sep 17 00:00:00 2001 From: Abdulelah Bin Mahfoodh Date: Thu, 9 Apr 2020 19:45:42 +0300 Subject: [PATCH 0780/1550] Fix dask-ssh after removing local-directory from dask_scheduler cli (#3684) * Fix dask-ssh after removing local-directory keyword from dask_scheduler * black changes * Add a test to dask-ssh with local directory parameter --- distributed/cli/dask_ssh.py | 4 +--- distributed/deploy/old_ssh.py | 15 +-------------- distributed/deploy/tests/test_ssh.py | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index eb09f49cfed..7674632807c 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -81,9 +81,7 @@ "--local-directory", default=None, type=click.Path(exists=True), - help=( - "Directory to use on all cluster nodes to place workers " "and scheduler files." - ), + help=("Directory to use on all cluster nodes to place workers files."), ) @click.option( "--remote-python", default=None, type=str, help="Path to Python on remote nodes." 
diff --git a/distributed/deploy/old_ssh.py b/distributed/deploy/old_ssh.py index 33e69772f9b..648d7b80905 100644 --- a/distributed/deploy/old_ssh.py +++ b/distributed/deploy/old_ssh.py @@ -209,24 +209,12 @@ def communicate(): def start_scheduler( - logdir, - addr, - port, - ssh_username, - ssh_port, - ssh_private_key, - remote_python=None, - local_directory=None, + logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None, ): cmd = "{python} -m distributed.cli.dask_scheduler --port {port}".format( python=remote_python or sys.executable, port=port, logdir=logdir ) - if local_directory is not None: - cmd += " --local-directory {local_directory}".format( - local_directory=local_directory - ) - # Optionally re-direct stdout and stderr to a logfile if logdir is not None: cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd @@ -422,7 +410,6 @@ def __init__( ssh_port, ssh_private_key, remote_python, - local_directory, ) # Start worker nodes diff --git a/distributed/deploy/tests/test_ssh.py b/distributed/deploy/tests/test_ssh.py index af6bf1566f2..11885dd8612 100644 --- a/distributed/deploy/tests/test_ssh.py +++ b/distributed/deploy/tests/test_ssh.py @@ -63,6 +63,24 @@ def test_defer_to_old(loop): assert isinstance(c, OldSSHCluster) +@pytest.mark.avoid_travis +def test_old_ssh_wih_local_dir(loop): + with pytest.warns(Warning): + from distributed.deploy.old_ssh import SSHCluster as OldSSHCluster + + with OldSSHCluster( + scheduler_addr="127.0.0.1", + scheduler_port=7437, + worker_addrs=["127.0.0.1", "127.0.0.1"], + local_directory="/tmp", + ) as c: + assert len(c.workers) == 2 + with Client(c) as client: + result = client.submit(lambda x: x + 1, 10) + result = result.result() + assert result == 11 + + @pytest.mark.asyncio async def test_config_inherited_by_subprocess(loop): def f(x): From c77618c0e62b2f7be70cca5da747d2fcf1ce4f98 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 9 Apr 2020 13:35:06 -0700 Subject: [PATCH 0781/1550] Replace Bokeh Server with Tornado HTTPServer (#3658) This creates a vanilla Tornado HTTPServer on the Worker and Schedulers, and then optionally attaches a Bokeh application to that server. This lets us always have an HTTP Server running, even if we aren't running the bokeh dashboard. This is helpful for example for prometheus and health checks, and to allow for other uses of the HTTP Server in the future. I hope that it also makes it easier to configure HTTPS globally across the project. The info, json, proxy, health, and prometheus routes have been moved out of the dashboard directory into a new distributed/http directory. 
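
For illustration only (not part of this patch), a third-party route module under this scheme would look like the health and proxy modules added below: a module-level routes list of (url_pattern, handler_class, kwargs) tuples whose handlers receive the serving node as dask_server. The module path myproject.http.ntasks, the NTasksHandler class, and the /ntasks endpoint are invented names used only to show that shape; nothing here beyond the pattern comes from the actual codebase.

# Hypothetical custom route module (sketch, not part of this patch).
# It mirrors the structure of distributed/http/health.py: a handler class
# plus a module-level routes list that the new Tornado application could
# pick up via the distributed.scheduler.http.routes config list.
from tornado import web


class NTasksHandler(web.RequestHandler):
    def initialize(self, dask_server=None, extra=None):
        # The serving node (Scheduler or Worker) is injected as dask_server,
        # matching the handlers moved into distributed/http in this patch.
        self.server = dask_server
        self.extra = extra or {}

    def get(self):
        # Tornado serializes a dict body as JSON.
        self.write({"n_tasks": len(self.server.tasks)})


routes = [
    ("/ntasks", NTasksHandler, {}),
]

If such a module were importable on the scheduler, appending "myproject.http.ntasks" to the distributed.scheduler.http.routes list in distributed.yaml should expose /ntasks on the scheduler's HTTP server next to /health and /metrics.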
--- distributed/cli/dask_scheduler.py | 7 +- distributed/cli/dask_worker.py | 4 +- distributed/cli/tests/test_dask_scheduler.py | 12 +- distributed/client.py | 2 +- distributed/comm/__init__.py | 1 + distributed/dashboard/__init__.py | 2 - distributed/dashboard/components/scheduler.py | 4 +- distributed/dashboard/components/worker.py | 7 +- distributed/dashboard/core.py | 114 ++--- distributed/dashboard/scheduler.py | 433 +----------------- .../dashboard/tests/test_scheduler_bokeh.py | 37 +- .../dashboard/tests/test_worker_bokeh.py | 38 +- distributed/dashboard/utils.py | 20 - distributed/dashboard/worker.py | 188 +------- distributed/deploy/cluster.py | 2 +- distributed/deploy/local.py | 27 +- distributed/deploy/spec.py | 8 +- distributed/deploy/tests/test_local.py | 37 +- distributed/deploy/tests/test_spec_cluster.py | 8 +- distributed/distributed.yaml | 19 + distributed/http/__init__.py | 1 + distributed/http/health.py | 12 + distributed/{dashboard => http}/proxy.py | 13 +- distributed/http/routing.py | 22 + distributed/http/scheduler/__init__.py | 0 distributed/http/scheduler/info.py | 203 ++++++++ distributed/http/scheduler/json.py | 72 +++ distributed/http/scheduler/prometheus.py | 99 ++++ .../scheduler/tests/test_scheduler_http.py} | 94 ++-- .../{dashboard => http}/static/css/base.css | 0 .../static/css/individual-cluster-map.css | 0 .../{dashboard => http}/static/css/status.css | 0 .../{dashboard => http}/static/css/system.css | 0 .../static/images/dask-logo.svg | 0 .../static/images/fa-bars.svg | 0 .../static/images/favicon.ico | Bin .../static/individual-cluster-map.html | 0 .../static/js/anime.min.js | 0 .../static/js/individual-cluster-map.js | 0 .../static/js/reconnecting-websocket.min.js | 0 distributed/http/statics.py | 10 + .../{dashboard => http}/templates/base.html | 0 .../templates/call-stack.html | 0 .../templates/json-index.html | 0 .../{dashboard => http}/templates/logs.html | 0 .../{dashboard => http}/templates/main.html | 0 .../{dashboard => http}/templates/simple.html | 0 .../{dashboard => http}/templates/status.html | 0 .../{dashboard => http}/templates/system.html | 0 .../{dashboard => http}/templates/task.html | 0 .../templates/worker-table.html | 0 .../{dashboard => http}/templates/worker.html | 0 .../templates/workers.html | 0 distributed/http/tests/__init__.py | 0 distributed/http/tests/test_core.py | 11 + distributed/http/tests/test_routing.py | 38 ++ distributed/http/utils.py | 51 +++ distributed/http/worker/__init__.py | 0 distributed/http/worker/prometheus.py | 98 ++++ .../worker/tests/test_worker_http.py} | 11 +- distributed/node.py | 57 ++- distributed/scheduler.py | 41 +- distributed/tests/test_client.py | 7 +- distributed/tests/test_scheduler.py | 15 +- distributed/tests/test_worker.py | 133 +++--- distributed/utils.py | 42 ++ distributed/utils_test.py | 4 +- distributed/worker.py | 22 +- setup.py | 2 +- 69 files changed, 1055 insertions(+), 973 deletions(-) create mode 100644 distributed/http/__init__.py create mode 100644 distributed/http/health.py rename distributed/{dashboard => http}/proxy.py (93%) create mode 100644 distributed/http/routing.py create mode 100644 distributed/http/scheduler/__init__.py create mode 100644 distributed/http/scheduler/info.py create mode 100644 distributed/http/scheduler/json.py create mode 100644 distributed/http/scheduler/prometheus.py rename distributed/{dashboard/tests/test_scheduler_bokeh_html.py => http/scheduler/tests/test_scheduler_http.py} (72%) rename distributed/{dashboard => http}/static/css/base.css 
(100%) rename distributed/{dashboard => http}/static/css/individual-cluster-map.css (100%) rename distributed/{dashboard => http}/static/css/status.css (100%) rename distributed/{dashboard => http}/static/css/system.css (100%) rename distributed/{dashboard => http}/static/images/dask-logo.svg (100%) rename distributed/{dashboard => http}/static/images/fa-bars.svg (100%) rename distributed/{dashboard => http}/static/images/favicon.ico (100%) rename distributed/{dashboard => http}/static/individual-cluster-map.html (100%) rename distributed/{dashboard => http}/static/js/anime.min.js (100%) rename distributed/{dashboard => http}/static/js/individual-cluster-map.js (100%) rename distributed/{dashboard => http}/static/js/reconnecting-websocket.min.js (100%) create mode 100644 distributed/http/statics.py rename distributed/{dashboard => http}/templates/base.html (100%) rename distributed/{dashboard => http}/templates/call-stack.html (100%) rename distributed/{dashboard => http}/templates/json-index.html (100%) rename distributed/{dashboard => http}/templates/logs.html (100%) rename distributed/{dashboard => http}/templates/main.html (100%) rename distributed/{dashboard => http}/templates/simple.html (100%) rename distributed/{dashboard => http}/templates/status.html (100%) rename distributed/{dashboard => http}/templates/system.html (100%) rename distributed/{dashboard => http}/templates/task.html (100%) rename distributed/{dashboard => http}/templates/worker-table.html (100%) rename distributed/{dashboard => http}/templates/worker.html (100%) rename distributed/{dashboard => http}/templates/workers.html (100%) create mode 100644 distributed/http/tests/__init__.py create mode 100644 distributed/http/tests/test_core.py create mode 100644 distributed/http/tests/test_routing.py create mode 100644 distributed/http/utils.py create mode 100644 distributed/http/worker/__init__.py create mode 100644 distributed/http/worker/prometheus.py rename distributed/{dashboard/tests/test_worker_bokeh_html.py => http/worker/tests/test_worker_http.py} (73%) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 78d6623608f..1eeb1e2715f 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -83,7 +83,7 @@ ) @click.option("--show/--no-show", default=False, help="Show web UI [default: --show]") @click.option( - "--dashboard-prefix", type=str, default=None, help="Prefix for the dashboard app" + "--dashboard-prefix", type=str, default="", help="Prefix for the dashboard app" ) @click.option( "--use-xheaders", @@ -202,8 +202,9 @@ def del_pid_file(): security=sec, host=host, port=port, - dashboard_address=dashboard_address if dashboard else None, - service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, + dashboard=dashboard, + dashboard_address=dashboard_address, + http_prefix=dashboard_prefix, **kwargs ) logger.info("-" * 47) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index efc330a4a24..9d73f7af5b2 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -397,8 +397,8 @@ def del_pid_file(): contact_address=contact_address, host=host, port=port, - dashboard_address=dashboard_address if dashboard else None, - service_kwargs={"dashboard": {"prefix": dashboard_prefix}}, + dashboard=dashboard, + dashboard_address=dashboard_address, name=name if nprocs == 1 or name is None or name == "" else str(name) + "-" + str(i), diff --git a/distributed/cli/tests/test_dask_scheduler.py 
b/distributed/cli/tests/test_dask_scheduler.py index 62c79f8c0b8..3e867b1f377 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -39,8 +39,9 @@ def f(): with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: c.sync(f) - with pytest.raises(Exception): - requests.get("http://127.0.0.1:8787/status/") + response = requests.get("http://127.0.0.1:8787/status/") + assert response.status_code == 404 + with pytest.raises(Exception): response = requests.get("http://127.0.0.1:9786/info.json") @@ -64,11 +65,8 @@ def test_no_dashboard(loop): pytest.importorskip("bokeh") with popen(["dask-scheduler", "--no-dashboard"]) as proc: with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: - for i in range(3): - line = proc.stderr.readline() - assert b"dashboard" not in line.lower() - with pytest.raises(Exception): - requests.get("http://127.0.0.1:8787/status/") + response = requests.get("http://127.0.0.1:8787/status/") + assert response.status_code == 404 def test_dashboard(loop): diff --git a/distributed/client.py b/distributed/client.py index 73cccc8a18a..6545e938511 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -843,7 +843,7 @@ def _repr_html_(self): if info and "dashboard" in info["services"]: text += ( - "
      • Dashboard: %(web)s\n" + "
      • Dashboard: %(web)s
      • \n" % {"web": self.dashboard_link} ) diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index 3537b301573..2ff679ada3d 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -10,6 +10,7 @@ get_local_address_for, ) from .core import connect, listen, Comm, CommClosedError +from .utils import get_tcp_server_address def _register_transports(): diff --git a/distributed/dashboard/__init__.py b/distributed/dashboard/__init__.py index 675963b1463..e69de29bb2d 100644 --- a/distributed/dashboard/__init__.py +++ b/distributed/dashboard/__init__.py @@ -1,2 +0,0 @@ -from .scheduler import BokehScheduler -from .worker import BokehWorker diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 6b7a77d9479..26e60c55bce 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -75,7 +75,9 @@ from jinja2 import Environment, FileSystemLoader env = Environment( - loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "..", "templates")) + loader=FileSystemLoader( + os.path.join(os.path.dirname(__file__), "..", "..", "http", "templates") + ) ) BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "..", "theme.yaml")) diff --git a/distributed/dashboard/components/worker.py b/distributed/dashboard/components/worker.py index a11d3047838..a6feb3911e1 100644 --- a/distributed/dashboard/components/worker.py +++ b/distributed/dashboard/components/worker.py @@ -37,13 +37,12 @@ logger = logging.getLogger(__name__) -with open(os.path.join(os.path.dirname(__file__), "..", "templates", "base.html")) as f: - template_source = f.read() - from jinja2 import Environment, FileSystemLoader env = Environment( - loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "..", "templates")) + loader=FileSystemLoader( + os.path.join(os.path.dirname(__file__), "..", "..", "http", "templates") + ) ) BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "..", "theme.yaml")) diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index 9b919917a67..6843b0659b3 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -1,11 +1,14 @@ from distutils.version import LooseVersion -import os +import functools import warnings import bokeh -from bokeh.server.server import Server -from tornado import web -from urllib.parse import urljoin +from bokeh.server.server import BokehTornado +from bokeh.server.util import create_hosts_whitelist +from bokeh.application.handlers.function import FunctionHandler +from bokeh.application import Application +import dask +import toolz if LooseVersion(bokeh.__version__) < LooseVersion("0.13.0"): @@ -16,87 +19,28 @@ raise ImportError("Dask needs bokeh >= 0.13.0") -class BokehServer: - server_kwargs = {} +def BokehApplication(applications, server, prefix="/", template_variables={}): + prefix = prefix or "" + prefix = "/" + prefix.strip("/") + if not prefix.endswith("/"): + prefix = prefix + "/" - def listen(self, addr): - if self.server: - return - if isinstance(addr, tuple): - ip, port = addr - else: - port = addr - ip = None - for i in range(5): - try: - server_kwargs = dict( - port=port, - address=ip, - check_unused_sessions_milliseconds=500, - allow_websocket_origin=["*"], - use_index=False, - extra_patterns=[ - ( - r"/", - web.RedirectHandler, - {"url": urljoin(self.prefix.rstrip("/") + "/", r"status")}, - ) - ], - ) - server_kwargs.update(self.server_kwargs) - self.server = 
Server(self.apps, **server_kwargs) - self.server.start() + extra = toolz.merge({"prefix": prefix}, template_variables) - handlers = [ - ( - self.prefix + r"/statics/(.*)", - web.StaticFileHandler, - {"path": os.path.join(os.path.dirname(__file__), "static")}, - ) - ] - - self.server._tornado.add_handlers(r".*", handlers) - - return - except (SystemExit, EnvironmentError) as exc: - if port != 0: - if "already in use" in str( - exc - ) or "Only one usage of" in str( # Unix/Mac - exc - ): # Windows - msg = ( - "Port %d is already in use. " - "\nPerhaps you already have a cluster running?" - "\nHosting the diagnostics dashboard on a random port instead." - % port - ) - else: - msg = ( - "Failed to start diagnostics server on port %d. " % port - + str(exc) - ) - warnings.warn("\n" + msg) - port = 0 - if i == 4: - raise - - @property - def port(self): - return ( - self.server.port - or list(self.server._http._sockets.values())[0].getsockname()[1] - ) - - def stop(self): - for context in self.server._tornado._applications.values(): - context.run_unload_hook() - - self.server._tornado._stats_job.stop() - self.server._tornado._cleanup_job.stop() - if self.server._tornado._ping_job is not None: - self.server._tornado._ping_job.stop() + apps = { + prefix + k.lstrip("/"): functools.partial(v, server, extra) + for k, v in applications.items() + } + apps = {k: Application(FunctionHandler(v)) for k, v in apps.items()} + kwargs = dask.config.get("distributed.scheduler.dashboard.bokeh-application").copy() + extra_websocket_origins = create_hosts_whitelist( + kwargs.pop("allow_websocket_origin"), server.http_server.port + ) - # https://github.com/bokeh/bokeh/issues/5494 - if LooseVersion(bokeh.__version__) >= "0.12.4": - self.server.stop() + application = BokehTornado( + apps, + use_index=False, + extra_websocket_origins=extra_websocket_origins, + **kwargs, + ) + return application diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 982bc424826..825195ecefa 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -1,23 +1,14 @@ -from datetime import datetime -from functools import partial -import os -import os.path -import json -import logging +from urllib.parse import urljoin -import dask -from dask.utils import format_bytes - -from tlz import merge, merge_with - -from tornado import escape -from tornado.websocket import WebSocketHandler +from tornado.ioloop import IOLoop +from tornado import web try: import numpy as np except ImportError: np = False +from .core import BokehApplication from .components.worker import counters_doc from .components.scheduler import ( systemmonitor_doc, @@ -42,419 +33,31 @@ individual_bandwidth_workers_doc, individual_memory_by_key_doc, ) -from .core import BokehServer from .worker import counters_doc -from .proxy import GlobalProxyHandler -from .utils import RequestHandler, redirect -from ..diagnostics.websocket import WebsocketPlugin -from ..metrics import time -from ..utils import log_errors, format_time -from ..scheduler import ALL_TASK_STATES - - -ns = { - func.__name__: func - for func in [format_bytes, format_time, datetime.fromtimestamp, time] -} - -rel_path_statics = {"rel_path_statics": "../../"} -logger = logging.getLogger(__name__) - template_variables = { "pages": ["status", "workers", "tasks", "system", "profile", "graph", "info"] } -class Workers(RequestHandler): - def get(self): - with log_errors(): - self.render( - "workers.html", - title="Workers", - scheduler=self.server, - 
**merge(self.server.__dict__, ns, self.extra, rel_path_statics), - ) - - -class Worker(RequestHandler): - def get(self, worker): - worker = escape.url_unescape(worker) - if worker not in self.server.workers: - self.send_error(404) - return - with log_errors(): - self.render( - "worker.html", - title="Worker: " + worker, - scheduler=self.server, - Worker=worker, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), - ) - - -class Task(RequestHandler): - def get(self, task): - task = escape.url_unescape(task) - if task not in self.server.tasks: - self.send_error(404) - return - with log_errors(): - self.render( - "task.html", - title="Task: " + task, - Task=task, - scheduler=self.server, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), - ) - - -class Logs(RequestHandler): - def get(self): - with log_errors(): - logs = self.server.get_logs() - self.render( - "logs.html", - title="Logs", - logs=logs, - **merge(self.extra, rel_path_statics), - ) - - -class WorkerLogs(RequestHandler): - async def get(self, worker): - with log_errors(): - worker = escape.url_unescape(worker) - logs = await self.server.get_worker_logs(workers=[worker]) - logs = logs[worker] - self.render( - "logs.html", - title="Logs: " + worker, - logs=logs, - **merge(self.extra, rel_path_statics), - ) - - -class WorkerCallStacks(RequestHandler): - async def get(self, worker): - with log_errors(): - worker = escape.url_unescape(worker) - keys = self.server.processing[worker] - call_stack = await self.server.get_call_stack(keys=keys) - self.render( - "call-stack.html", - title="Call Stacks: " + worker, - call_stack=call_stack, - **merge(self.extra, rel_path_statics), - ) - - -class TaskCallStack(RequestHandler): - async def get(self, key): - with log_errors(): - key = escape.url_unescape(key) - call_stack = await self.server.get_call_stack(keys=[key]) - if not call_stack: - self.write( - "
        <p>Task not actively running. " - "It may be finished or not yet started</p>
        " - ) - else: - self.render( - "call-stack.html", - title="Call Stack: " + key, - call_stack=call_stack, - **merge(self.extra, rel_path_statics), - ) - - -class CountsJSON(RequestHandler): - def get(self): - scheduler = self.server - erred = 0 - nbytes = 0 - nthreads = 0 - memory = 0 - processing = 0 - released = 0 - waiting = 0 - waiting_data = 0 - desired_workers = scheduler.adaptive_target() - - for ts in scheduler.tasks.values(): - if ts.exception_blame is not None: - erred += 1 - elif ts.state == "released": - released += 1 - if ts.waiting_on: - waiting += 1 - if ts.waiters: - waiting_data += 1 - for ws in scheduler.workers.values(): - nthreads += ws.nthreads - memory += len(ws.has_what) - nbytes += ws.nbytes - processing += len(ws.processing) - - response = { - "bytes": nbytes, - "clients": len(scheduler.clients), - "cores": nthreads, - "erred": erred, - "hosts": len(scheduler.host_info), - "idle": len(scheduler.idle), - "memory": memory, - "processing": processing, - "released": released, - "saturated": len(scheduler.saturated), - "tasks": len(scheduler.tasks), - "unrunnable": len(scheduler.unrunnable), - "waiting": waiting, - "waiting_data": waiting_data, - "workers": len(scheduler.workers), - "desired_workers": desired_workers, - } - self.write(response) - - -class IdentityJSON(RequestHandler): - def get(self): - self.write(self.server.identity()) - - -class IndexJSON(RequestHandler): - def get(self): - with log_errors(): - r = [url for url, _ in routes if url.endswith(".json")] - self.render( - "json-index.html", routes=r, title="Index of JSON routes", **self.extra - ) - - -class IndividualPlots(RequestHandler): - def get(self): - bokeh_server = self.server.services["dashboard"] - individual_bokeh = { - uri.strip("/").replace("-", " ").title(): uri - for uri in bokeh_server.apps - if uri.lstrip("/").startswith("individual-") and not uri.endswith(".json") - } - individual_static = { - uri.strip("/").replace(".html", "").replace("-", " ").title(): "/statics/" - + uri - for uri in os.listdir(os.path.join(os.path.dirname(__file__), "static")) - if uri.lstrip("/").startswith("individual-") and uri.endswith(".html") - } - result = {**individual_bokeh, **individual_static} - self.write(result) - - -class _PrometheusCollector: - def __init__(self, server): - self.server = server - - def collect(self): - from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily - - yield GaugeMetricFamily( - "dask_scheduler_clients", - "Number of clients connected.", - value=len(self.server.clients), - ) - - yield GaugeMetricFamily( - "dask_scheduler_desired_workers", - "Number of workers scheduler needs for task graph.", - value=self.server.adaptive_target(), - ) +def connect(application, http_server, scheduler, prefix=""): + bokeh_app = BokehApplication( + applications, scheduler, prefix=prefix, template_variables=template_variables + ) + application.add_application(bokeh_app) + bokeh_app.initialize(IOLoop.current()) - worker_states = GaugeMetricFamily( - "dask_scheduler_workers", - "Number of workers known by scheduler.", - labels=["state"], - ) - worker_states.add_metric(["connected"], len(self.server.workers)) - worker_states.add_metric(["saturated"], len(self.server.saturated)) - worker_states.add_metric(["idle"], len(self.server.idle)) - yield worker_states - - tasks = GaugeMetricFamily( - "dask_scheduler_tasks", - "Number of tasks known by scheduler.", - labels=["state"], - ) - - task_counter = merge_with( - sum, (tp.states for tp in self.server.task_prefixes.values()) - ) 
- - suspicious_tasks = CounterMetricFamily( - "dask_scheduler_tasks_suspicious", - "Total number of times a task has been marked suspicious", - labels=["task_prefix_name"], - ) - - for tp in self.server.task_prefixes.values(): - suspicious_tasks.add_metric([tp.name], tp.suspicious) - yield suspicious_tasks - - yield CounterMetricFamily( - "dask_scheduler_tasks_forgotten", + bokeh_app.add_handlers( + r".*", + [ ( - "Total number of processed tasks no longer in memory and already " - "removed from the scheduler job queue. Note task groups on the " - "scheduler which have all tasks in the forgotten state are not included." - ), - value=task_counter.get("forgotten", 0.0), - ) - - for state in ALL_TASK_STATES: - tasks.add_metric([state], task_counter.get(state, 0.0)) - yield tasks - - -class PrometheusHandler(RequestHandler): - _collector = None - - def __init__(self, *args, **kwargs): - import prometheus_client - - super(PrometheusHandler, self).__init__(*args, **kwargs) - - if PrometheusHandler._collector: - # Especially during testing, multiple schedulers are started - # sequentially in the same python process - PrometheusHandler._collector.server = self.server - return - - PrometheusHandler._collector = _PrometheusCollector(self.server) - prometheus_client.REGISTRY.register(PrometheusHandler._collector) - - def get(self): - import prometheus_client - - self.write(prometheus_client.generate_latest()) - self.set_header("Content-Type", "text/plain; version=0.0.4") - - -class HealthHandler(RequestHandler): - def get(self): - self.write("ok") - self.set_header("Content-Type", "text/plain") - - -class EventstreamHandler(WebSocketHandler): - def initialize(self, server=None, extra=None): - self.server = server - self.extra = extra or {} - self.plugin = WebsocketPlugin(self, server) - self.server.add_plugin(self.plugin) - - def send(self, name, data): - data["name"] = name - for k in list(data): - # Drop bytes objects for now - if isinstance(data[k], bytes): - del data[k] - self.write_message(data) - - def open(self): - for worker in self.server.workers: - self.plugin.add_worker(self.server, worker) - - def on_message(self, message): - message = json.loads(message) - if message["name"] == "ping": - self.send("pong", {"timestamp": str(datetime.now())}) - - def on_close(self): - self.server.remove_plugin(self.plugin) - - -routes = [ - (r"info", redirect("info/main/workers.html")), - (r"info/main/workers.html", Workers), - (r"info/worker/(.*).html", Worker), - (r"info/task/(.*).html", Task), - (r"info/main/logs.html", Logs), - (r"info/call-stacks/(.*).html", WorkerCallStacks), - (r"info/call-stack/(.*).html", TaskCallStack), - (r"info/logs/(.*).html", WorkerLogs), - (r"json/counts.json", CountsJSON), - (r"json/identity.json", IdentityJSON), - (r"json/index.html", IndexJSON), - (r"individual-plots.json", IndividualPlots), - (r"metrics", PrometheusHandler), - (r"health", HealthHandler), - (r"eventstream", EventstreamHandler), - (r"proxy/(\d+)/(.*?)/(.*)", GlobalProxyHandler), -] - - -def get_handlers(server): - return [(url, cls, {"server": server}) for url, cls in routes] - - -class BokehScheduler(BokehServer): - def __init__(self, scheduler, io_loop=None, prefix="", **kwargs): - self.scheduler = scheduler - prefix = prefix or "" - prefix = prefix.rstrip("/") - if prefix and not prefix.startswith("/"): - prefix = "/" + prefix - self.prefix = prefix - - self.server_kwargs = kwargs - - # TLS configuration - http_server_kwargs = kwargs.setdefault("http_server_kwargs", {}) - tls_key = 
dask.config.get("distributed.scheduler.dashboard.tls.key") - tls_cert = dask.config.get("distributed.scheduler.dashboard.tls.cert") - tls_ca_file = dask.config.get("distributed.scheduler.dashboard.tls.ca-file") - if tls_cert and "ssl_options" not in http_server_kwargs: - import ssl - - ctx = ssl.create_default_context( - cafile=tls_ca_file, purpose=ssl.Purpose.SERVER_AUTH + r"/", + web.RedirectHandler, + {"url": urljoin((prefix or "").strip("/") + "/", r"status")}, ) - ctx.load_cert_chain(tls_cert, keyfile=tls_key) - # Unlike the client/scheduler/worker TLS handling, we don't care - # about authenticating the user's webclient, TLS here is just for - # encryption. Disable these checks. - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - http_server_kwargs["ssl_options"] = ctx - - self.server_kwargs["prefix"] = prefix or None - - self.apps = applications - self.apps = {k: partial(v, scheduler, self.extra) for k, v in self.apps.items()} - - self.loop = io_loop or scheduler.loop - self.server = None - - @property - def extra(self): - return merge({"prefix": self.prefix}, template_variables) - - @property - def my_server(self): - return self.scheduler - - def listen(self, *args, **kwargs): - super(BokehScheduler, self).listen(*args, **kwargs) - - handlers = [ - ( - self.prefix + "/" + url, - cls, - {"server": self.my_server, "extra": self.extra}, - ) - for url, cls in routes - ] - - self.server._tornado.add_handlers(r".*", handlers) + ], + ) applications = { diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 65b5fa25d50..49bdfe448bf 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -17,9 +17,8 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec, slowinc, div, get_cert -from distributed.dashboard.worker import BokehWorker from distributed.dashboard.components.worker import Counters -from distributed.dashboard.scheduler import applications, BokehScheduler +from distributed.dashboard.scheduler import applications from distributed.dashboard.components.scheduler import ( SystemMonitor, Occupancy, @@ -46,12 +45,9 @@ @pytest.mark.skipif( sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" ) -@gen_cluster( - client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} -) +@gen_cluster(client=True, scheduler_kwargs={"dashboard": True}) def test_simple(c, s, a, b): - assert isinstance(s.services["dashboard"], BokehScheduler) - port = s.services["dashboard"].port + port = s.http_server.port future = c.submit(sleep, 1) yield gen.sleep(0.1) @@ -70,7 +66,7 @@ def test_simple(c, s, a, b): assert response -@gen_cluster(client=True, worker_kwargs=dict(services={"dashboard": BokehWorker})) +@gen_cluster(client=True, worker_kwargs={"dashboard": True}) def test_basic(c, s, a, b): for component in [TaskStream, SystemMonitor, Occupancy, StealingTimeSeries]: ss = component(s) @@ -592,21 +588,19 @@ def test_profile_server(c, s, a, b): @gen_cluster( - client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} + client=True, scheduler_kwargs={"dashboard": True}, ) def test_root_redirect(c, s, a, b): http_client = AsyncHTTPClient() - response = yield http_client.fetch( - "http://localhost:%d/" % s.services["dashboard"].port - ) + response = yield http_client.fetch("http://localhost:%d/" % s.http_server.port) assert 
response.code == 200 assert "/status" in response.effective_url @gen_cluster( client=True, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, - worker_kwargs={"services": {"dashboard": BokehWorker}}, + scheduler_kwargs={"dashboard": True}, + worker_kwargs={"dashboard": True}, timeout=180, ) def test_proxy_to_workers(c, s, a, b): @@ -617,7 +611,7 @@ def test_proxy_to_workers(c, s, a, b): except ImportError: proxy_exists = False - dashboard_port = s.services["dashboard"].port + dashboard_port = s.http_server.port http_client = AsyncHTTPClient() response = yield http_client.fetch("http://localhost:%d/" % dashboard_port) assert response.code == 200 @@ -625,7 +619,7 @@ def test_proxy_to_workers(c, s, a, b): for w in [a, b]: host = w.ip - port = w.service_ports["dashboard"] + port = w.http_server.port proxy_url = "http://localhost:%d/proxy/%s/%s/status" % ( dashboard_port, port, @@ -647,7 +641,7 @@ def test_proxy_to_workers(c, s, a, b): @gen_cluster( client=True, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, + scheduler_kwargs={"dashboard": True}, config={ "distributed.scheduler.dashboard.tasks.task-stream-length": 10, "distributed.scheduler.dashboard.status.task-stream-length": 10, @@ -675,7 +669,7 @@ async def test_lots_of_tasks(c, s, a, b): @gen_cluster( client=True, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, + scheduler_kwargs={"dashboard": True}, config={ "distributed.scheduler.dashboard.tls.key": get_cert("tls-key.pem"), "distributed.scheduler.dashboard.tls.cert": get_cert("tls-cert.pem"), @@ -683,8 +677,7 @@ async def test_lots_of_tasks(c, s, a, b): }, ) def test_https_support(c, s, a, b): - assert isinstance(s.services["dashboard"], BokehScheduler) - port = s.services["dashboard"].port + port = s.http_server.port assert ( format_dashboard_link("localhost", port) == "https://localhost:%d/status" % port @@ -717,9 +710,7 @@ def test_https_support(c, s, a, b): assert not re.search("href=./", body) # no absolute links -@gen_cluster( - client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} -) +@gen_cluster(client=True, scheduler_kwargs={"dashboard": True}) async def test_memory_by_key(c, s, a, b): mbk = MemoryByKey(s) diff --git a/distributed/dashboard/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py index 97729fce14f..873cc1c1f3e 100644 --- a/distributed/dashboard/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -5,7 +5,6 @@ import pytest pytest.importorskip("bokeh") -import sys from tlz import first from tornado import gen from tornado.httpclient import AsyncHTTPClient @@ -13,8 +12,6 @@ from distributed.client import wait from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, dec -from distributed.dashboard.scheduler import BokehScheduler -from distributed.dashboard.worker import BokehWorker from distributed.dashboard.components.worker import ( StateTable, CrossFilter, @@ -28,13 +25,11 @@ @gen_cluster( client=True, - worker_kwargs={"services": {("dashboard", 0): BokehWorker}}, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, + worker_kwargs={"dashboard": True}, + scheduler_kwargs={"dashboard": True}, ) def test_routes(c, s, a, b): - assert isinstance(a.services["dashboard"], BokehWorker) - assert isinstance(b.services["dashboard"], BokehWorker) - port = a.services["dashboard"].port + port = a.http_server.port future = c.submit(sleep, 1) yield gen.sleep(0.1) @@ -47,37 +42,33 @@ def 
test_routes(c, s, a, b): assert not re.search("href=./", body) # no absolute links response = yield http_client.fetch( - "http://localhost:%d/info/main/workers.html" % s.services["dashboard"].port + "http://localhost:%d/info/main/workers.html" % s.http_server.port ) assert str(port) in response.body.decode() -@pytest.mark.skipif( - sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" -) -@gen_cluster(client=True, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}) +@gen_cluster(client=True, worker_kwargs={"dashboard": True}) def test_simple(c, s, a, b): - assert s.workers[a.address].services == {"dashboard": a.services["dashboard"].port} - assert s.workers[b.address].services == {"dashboard": b.services["dashboard"].port} + assert s.workers[a.address].services == {"dashboard": a.http_server.port} + assert s.workers[b.address].services == {"dashboard": b.http_server.port} future = c.submit(sleep, 1) yield gen.sleep(0.1) http_client = AsyncHTTPClient() - for suffix in ["main", "crossfilter", "system"]: + for suffix in ["crossfilter", "system"]: response = yield http_client.fetch( - "http://localhost:%d/%s" % (a.services["dashboard"].port, suffix) + "http://localhost:%d/%s" % (a.http_server.port, suffix) ) assert "bokeh" in response.body.decode().lower() @gen_cluster( - client=True, worker_kwargs={"services": {("dashboard", 0): (BokehWorker, {})}} + client=True, worker_kwargs={"dashboard": True}, ) def test_services_kwargs(c, s, a, b): - assert s.workers[a.address].services == {"dashboard": a.services["dashboard"].port} - assert isinstance(a.services["dashboard"], BokehWorker) + assert s.workers[a.address].services == {"dashboard": a.http_server.port} @gen_cluster(client=True) @@ -166,17 +157,14 @@ def test_CommunicatingStream(c, s, a, b): @gen_cluster( - client=True, - clean_kwargs={"threads": False}, - worker_kwargs={"services": {("dashboard", 0): BokehWorker}}, + client=True, clean_kwargs={"threads": False}, worker_kwargs={"dashboard": True}, ) def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") - assert s.workers[a.address].services == {"dashboard": a.services["dashboard"].port} http_client = AsyncHTTPClient() for suffix in ["metrics"]: response = yield http_client.fetch( - "http://localhost:%d/%s" % (a.services["dashboard"].port, suffix) + "http://localhost:%d/%s" % (a.http_server.port, suffix) ) assert response.code == 200 diff --git a/distributed/dashboard/utils.py b/distributed/dashboard/utils.py index 394e016a4da..0de536a6050 100644 --- a/distributed/dashboard/utils.py +++ b/distributed/dashboard/utils.py @@ -1,10 +1,8 @@ from distutils.version import LooseVersion -import os from numbers import Number import bokeh from bokeh.io import curdoc -from tornado import web from tlz import partition from tlz.curried import first @@ -15,7 +13,6 @@ BOKEH_VERSION = LooseVersion(bokeh.__version__) -dirname = os.path.dirname(__file__) PROFILING = False @@ -45,23 +42,6 @@ def transpose(lod): return {k: [d[k] for d in lod] for k in keys} -class RequestHandler(web.RequestHandler): - def initialize(self, server=None, extra=None): - self.server = server - self.extra = extra or {} - - def get_template_path(self): - return os.path.join(dirname, "templates") - - -def redirect(path): - class Redirect(RequestHandler): - def get(self): - self.redirect(path) - - return Redirect - - @without_property_validation def update(source, data): """ Update source with data diff --git a/distributed/dashboard/worker.py b/distributed/dashboard/worker.py index 
54b3a0a4a51..ff9ae3b2f7d 100644 --- a/distributed/dashboard/worker.py +++ b/distributed/dashboard/worker.py @@ -1,10 +1,3 @@ -from functools import partial -import logging -import os - -from bokeh.themes import Theme -from tlz import merge - from .components.worker import ( status_doc, crossfilter_doc, @@ -13,177 +6,28 @@ profile_doc, profile_server_doc, ) -from .core import BokehServer -from .utils import RequestHandler, redirect - - -logger = logging.getLogger(__name__) - -with open(os.path.join(os.path.dirname(__file__), "templates", "base.html")) as f: - template_source = f.read() - -from jinja2 import Environment, FileSystemLoader - -env = Environment( - loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "templates")) -) +from .core import BokehApplication +from tornado.ioloop import IOLoop -BOKEH_THEME = Theme(os.path.join(os.path.dirname(__file__), "theme.yaml")) template_variables = { "pages": ["status", "system", "profile", "crossfilter", "profile-server"] } -class _PrometheusCollector: - def __init__(self, server): - self.worker = server - self.logger = logging.getLogger("distributed.dask_worker") - self.crick_available = True - try: - import crick # noqa: F401 - except ImportError: - self.crick_available = False - self.logger.info( - "Not all prometheus metrics available are exported. Digest-based metrics require crick to be installed" - ) - - def collect(self): - from prometheus_client.core import GaugeMetricFamily - - tasks = GaugeMetricFamily( - "dask_worker_tasks", "Number of tasks at worker.", labels=["state"] - ) - tasks.add_metric(["stored"], len(self.worker.data)) - tasks.add_metric(["executing"], len(self.worker.executing)) - tasks.add_metric(["ready"], len(self.worker.ready)) - tasks.add_metric(["waiting"], len(self.worker.waiting_for_data)) - tasks.add_metric(["serving"], len(self.worker._comms)) - yield tasks - - yield GaugeMetricFamily( - "dask_worker_connections", - "Number of task connections to other workers.", - value=len(self.worker.in_flight_workers), - ) - - yield GaugeMetricFamily( - "dask_worker_threads", - "Number of worker threads.", - value=self.worker.nthreads, - ) - - yield GaugeMetricFamily( - "dask_worker_latency_seconds", - "Latency of worker connection.", - value=self.worker.latency, - ) - - # all metrics using digests require crick to be installed - # the following metrics will export NaN, if the corresponding digests are None - if self.crick_available: - yield GaugeMetricFamily( - "dask_worker_tick_duration_median_seconds", - "Median tick duration at worker.", - value=self.worker.digests["tick-duration"].components[1].quantile(50), - ) - - yield GaugeMetricFamily( - "dask_worker_task_duration_median_seconds", - "Median task runtime at worker.", - value=self.worker.digests["task-duration"].components[1].quantile(50), - ) - - yield GaugeMetricFamily( - "dask_worker_transfer_bandwidth_median_bytes", - "Bandwidth for transfer at worker in Bytes.", - value=self.worker.digests["transfer-bandwidth"] - .components[1] - .quantile(50), - ) - - -class PrometheusHandler(RequestHandler): - _initialized = False - - def __init__(self, *args, **kwargs): - import prometheus_client - - super(PrometheusHandler, self).__init__(*args, **kwargs) +def connect(application, http_server, worker, prefix=""): + bokeh_app = BokehApplication( + applications, worker, prefix=prefix, template_variables=template_variables + ) + application.add_application(bokeh_app) + bokeh_app.initialize(IOLoop.current()) - if PrometheusHandler._initialized: - return - 
prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) - - PrometheusHandler._initialized = True - - def get(self): - import prometheus_client - - self.write(prometheus_client.generate_latest()) - self.set_header("Content-Type", "text/plain; version=0.0.4") - - -class HealthHandler(RequestHandler): - def get(self): - self.write("ok") - self.set_header("Content-Type", "text/plain") - - -routes = [ - (r"metrics", PrometheusHandler), - (r"health", HealthHandler), - (r"main", redirect("/status")), -] - - -def get_handlers(server): - return [(url, cls, {"server": server}) for url, cls in routes] - - -class BokehWorker(BokehServer): - def __init__(self, worker, io_loop=None, prefix="", **kwargs): - self.worker = worker - self.server_kwargs = kwargs - self.server_kwargs["prefix"] = prefix or None - prefix = prefix or "" - prefix = prefix.rstrip("/") - if prefix and not prefix.startswith("/"): - prefix = "/" + prefix - self.prefix = prefix - - self.apps = { - "/status": status_doc, - "/counters": counters_doc, - "/crossfilter": crossfilter_doc, - "/system": systemmonitor_doc, - "/profile": profile_doc, - "/profile-server": profile_server_doc, - } - self.apps = {k: partial(v, worker, self.extra) for k, v in self.apps.items()} - - self.loop = io_loop or worker.loop - self.server = None - - @property - def extra(self): - return merge({"prefix": self.prefix}, template_variables) - - @property - def my_server(self): - return self.worker - - def listen(self, *args, **kwargs): - super(BokehWorker, self).listen(*args, **kwargs) - - handlers = [ - ( - self.prefix + "/" + url, - cls, - {"server": self.my_server, "extra": self.extra}, - ) - for url, cls in routes - ] - - self.server._tornado.add_handlers(r".*", handlers) +applications = { + "/status": status_doc, + "/counters": counters_doc, + "/crossfilter": crossfilter_doc, + "/system": systemmonitor_doc, + "/profile": profile_doc, + "/profile-server": profile_server_doc, +} diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 81f3d578fb2..8082d278483 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -48,7 +48,7 @@ class Cluster: _supports_scaling = True def __init__(self, asynchronous): - self.scheduler_info = {} + self.scheduler_info = {"workers": {}} self.periodic_callbacks = {} self._asynchronous = asynchronous diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index d1744ed32c0..cd33ad26a12 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -6,6 +6,7 @@ from dask.utils import factors from dask.system import CPU_COUNT +import toolz from .spec import SpecCluster from ..nanny import Nanny @@ -110,6 +111,7 @@ def __init__( blocked_handlers=None, interface=None, worker_class=None, + scheduler_kwargs=None, **worker_kwargs ): if ip is not None: @@ -172,6 +174,7 @@ def __init__( "nthreads": threads_per_worker, "services": worker_services, "dashboard_address": worker_dashboard_address, + "dashboard": worker_dashboard_address is not None, "interface": interface, "protocol": protocol, "security": security, @@ -181,16 +184,20 @@ def __init__( scheduler = { "cls": Scheduler, - "options": dict( - host=host, - services=services, - service_kwargs=service_kwargs, - security=security, - port=scheduler_port, - interface=interface, - protocol=protocol, - dashboard_address=dashboard_address, - blocked_handlers=blocked_handlers, + "options": toolz.merge( + dict( + host=host, + services=services, + service_kwargs=service_kwargs, + security=security, + 
port=scheduler_port, + interface=interface, + protocol=protocol, + dashboard=dashboard_address is not None, + dashboard_address=dashboard_address, + blocked_handlers=blocked_handlers, + ), + scheduler_kwargs or {}, ), } diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 17b1af28148..99ab70d2de1 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -268,12 +268,12 @@ async def _start(self): if self.scheduler_spec is None: try: - from distributed.dashboard import BokehScheduler + import distributed.dashboard # noqa: F401 except ImportError: - services = {} + pass else: - services = {("dashboard", 8787): BokehScheduler} - self.scheduler_spec = {"cls": Scheduler, "options": {"services": services}} + options = {"dashboard": True} + self.scheduler_spec = {"cls": Scheduler, "options": options} cls = self.scheduler_spec["cls"] if isinstance(cls, str): diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 403beb3aa41..31fbcebd3b8 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -12,6 +12,7 @@ from tornado.ioloop import IOLoop from tornado import gen import tornado +from tornado.httpclient import AsyncHTTPClient import pytest from dask.system import CPU_COUNT @@ -196,11 +197,18 @@ def test_Client_solo(loop): @gen_test() -def test_duplicate_clients(): +async def test_duplicate_clients(): pytest.importorskip("bokeh") - c1 = yield Client(processes=False, silence_logs=False, dashboard_address=9876) + c1 = await Client( + processes=False, silence_logs=False, dashboard_address=9876, asynchronous=True + ) with pytest.warns(Warning) as info: - c2 = yield Client(processes=False, silence_logs=False, dashboard_address=9876) + c2 = await Client( + processes=False, + silence_logs=False, + dashboard_address=9876, + asynchronous=True, + ) assert "dashboard" in c1.cluster.scheduler.services assert "dashboard" in c2.cluster.scheduler.services @@ -212,8 +220,8 @@ def test_duplicate_clients(): ) for msg in info.list ) - yield c1.close() - yield c2.close() + await c1.close() + await c2.close() def test_Client_kwargs(loop): @@ -405,7 +413,7 @@ def test_bokeh(loop, processes): processes=processes, dashboard_address=0, ) as c: - bokeh_port = c.scheduler.services["dashboard"].port + bokeh_port = c.scheduler.http_server.port url = "http://127.0.0.1:%d/status/" % bokeh_port start = time() while True: @@ -543,19 +551,22 @@ def test_death_timeout_raises(loop): @pytest.mark.skipif(sys.version_info < (3, 6), reason="Unknown") -def test_bokeh_kwargs(loop): +@pytest.mark.asyncio +async def test_bokeh_kwargs(cleanup): pytest.importorskip("bokeh") - with LocalCluster( + async with LocalCluster( n_workers=0, scheduler_port=0, silence_logs=False, - loop=loop, dashboard_address=0, - service_kwargs={"dashboard": {"prefix": "/foo"}}, + asynchronous=True, + scheduler_kwargs={"http_prefix": "/foo"}, ) as c: - - bs = c.scheduler.services["dashboard"] - assert bs.prefix == "/foo" + client = AsyncHTTPClient() + response = await client.fetch( + "http://localhost:{}/foo/status".format(c.scheduler.http_server.port) + ) + assert "bokeh" in response.body.decode() def test_io_loop_periodic_callbacks(loop): diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index ae24e7400e2..c9482f5da56 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -4,6 +4,7 @@ import dask from dask.distributed import 
SpecCluster, Worker, Client, Scheduler, Nanny +from distributed.compatibility import WINDOWS from distributed.deploy.spec import close_clusters, ProcessInterface, run_spec from distributed.metrics import time from distributed.utils_test import loop, cleanup # noqa: F401 @@ -84,9 +85,10 @@ def test_spec_sync(loop): def test_loop_started(): - cluster = SpecCluster( + with SpecCluster( worker_spec, scheduler={"cls": Scheduler, "options": {"port": 0}} - ) + ) as cluster: + pass @pytest.mark.asyncio @@ -212,6 +214,7 @@ async def test_restart(cleanup): assert len(cluster.workers) == 2 +@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out") @pytest.mark.asyncio async def test_broken_worker(): with pytest.raises(Exception) as info: @@ -225,6 +228,7 @@ async def test_broken_worker(): assert "Broken" in str(info.value) +@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out") @pytest.mark.slow def test_spec_close_clusters(loop): workers = {0: {"cls": Worker}} diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index bd1dd85c57e..b11270f4704 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -38,9 +38,22 @@ distributed: ca-file: null key: null cert: null + bokeh-application: # keywords to pass to BokehTornado application + allow_websocket_origin: ["*"] + keep_alive_milliseconds: 500 + check_unused_sessions_milliseconds: 500 locks: lease-validation-interval: 10s # The time to wait until an acquired semaphore is released if the Client goes out of scope + http: + routes: + - distributed.http.scheduler.prometheus + - distributed.http.scheduler.info + - distributed.http.scheduler.json + - distributed.http.health + - distributed.http.proxy + - distributed.http.statics + worker: blocked-handlers: [] multiprocessing-method: spawn @@ -71,6 +84,12 @@ distributed: pause: 0.80 # fraction at which we pause worker threads terminate: 0.95 # fraction at which we terminate the worker + http: + routes: + - distributed.http.worker.prometheus + - distributed.http.health + - distributed.http.statics + nanny: preload: [] # Run custom modules with Nanny preload-argv: [] # See https://docs.dask.org/en/latest/setup/custom-startup.html diff --git a/distributed/http/__init__.py b/distributed/http/__init__.py new file mode 100644 index 00000000000..b41a454ed2f --- /dev/null +++ b/distributed/http/__init__.py @@ -0,0 +1 @@ +from .utils import get_handlers diff --git a/distributed/http/health.py b/distributed/http/health.py new file mode 100644 index 00000000000..2a45c4abf77 --- /dev/null +++ b/distributed/http/health.py @@ -0,0 +1,12 @@ +from tornado import web + + +class HealthHandler(web.RequestHandler): + def get(self): + self.write("ok") + self.set_header("Content-Type", "text/plain") + + +routes = [ + ("/health", HealthHandler, {}), +] diff --git a/distributed/dashboard/proxy.py b/distributed/http/proxy.py similarity index 93% rename from distributed/dashboard/proxy.py rename to distributed/http/proxy.py index 3e76ba11c0e..c1f437d9b5f 100644 --- a/distributed/dashboard/proxy.py +++ b/distributed/http/proxy.py @@ -13,8 +13,8 @@ class GlobalProxyHandler(ProxyHandler): from a port to any valid endpoint'. 
""" - def initialize(self, server=None, extra=None): - self.scheduler = server + def initialize(self, dask_server=None, extra=None): + self.scheduler = dask_server self.extra = extra or {} async def http_get(self, port, host, proxied_path): @@ -77,8 +77,8 @@ class GlobalProxyHandler(web.RequestHandler): """Minimal Proxy handler when jupyter-server-proxy is not installed """ - def initialize(self, server=None, extra=None): - self.server = server + def initialize(self, dask_server=None, extra=None): + self.server = dask_server self.extra = extra or {} def get(self, port, host, proxied_path): @@ -128,3 +128,8 @@ def check_worker_dashboard_exits(scheduler, worker): if addr == w.host and port == str(bokeh_port): return True return False + + +routes = [ + (r"proxy/(\d+)/(.*?)/(.*)", GlobalProxyHandler, {}), +] diff --git a/distributed/http/routing.py b/distributed/http/routing.py new file mode 100644 index 00000000000..ac51086493d --- /dev/null +++ b/distributed/http/routing.py @@ -0,0 +1,22 @@ +from tornado import web +import tornado.httputil + + +class RoutingApplication(web.Application): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.applications = [] + + def find_handler(self, request: tornado.httputil.HTTPServerRequest, **kwargs): + handler = super().find_handler(request, **kwargs) + if handler and not issubclass(handler.handler_class, web.ErrorHandler): + return handler + else: + for app in self.applications: + handler = app.find_handler(request, **kwargs) or handler + if handler and not issubclass(handler.handler_class, web.ErrorHandler): + break + return handler + + def add_application(self, application: web.Application): + self.applications.append(application) diff --git a/distributed/http/scheduler/__init__.py b/distributed/http/scheduler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/distributed/http/scheduler/info.py b/distributed/http/scheduler/info.py new file mode 100644 index 00000000000..6e5a222dd23 --- /dev/null +++ b/distributed/http/scheduler/info.py @@ -0,0 +1,203 @@ +from datetime import datetime +import json +import logging +import os +import os.path + +from dask.utils import format_bytes + +from tornado import escape +from tornado.websocket import WebSocketHandler +from tlz import first, merge + +from ..utils import RequestHandler, redirect +from ...diagnostics.websocket import WebsocketPlugin +from ...metrics import time +from ...utils import log_errors, format_time + +ns = { + func.__name__: func + for func in [format_bytes, format_time, datetime.fromtimestamp, time] +} + +rel_path_statics = {"rel_path_statics": "../../.."} + + +logger = logging.getLogger(__name__) + + +class Workers(RequestHandler): + def get(self): + with log_errors(): + self.render( + "workers.html", + title="Workers", + scheduler=self.server, + **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + ) + + +class Worker(RequestHandler): + def get(self, worker): + worker = escape.url_unescape(worker) + if worker not in self.server.workers: + self.send_error(404) + return + with log_errors(): + self.render( + "worker.html", + title="Worker: " + worker, + scheduler=self.server, + Worker=worker, + **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + ) + + +class Task(RequestHandler): + def get(self, task): + task = escape.url_unescape(task) + if task not in self.server.tasks: + self.send_error(404) + return + with log_errors(): + self.render( + "task.html", + title="Task: " + task, + Task=task, + 
scheduler=self.server, + **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + ) + + +class Logs(RequestHandler): + def get(self): + with log_errors(): + logs = self.server.get_logs() + self.render( + "logs.html", + title="Logs", + logs=logs, + **merge(self.extra, rel_path_statics), + ) + + +class WorkerLogs(RequestHandler): + async def get(self, worker): + with log_errors(): + worker = escape.url_unescape(worker) + logs = await self.server.get_worker_logs(workers=[worker]) + logs = logs[worker] + self.render( + "logs.html", + title="Logs: " + worker, + logs=logs, + **merge(self.extra, rel_path_statics), + ) + + +class WorkerCallStacks(RequestHandler): + async def get(self, worker): + with log_errors(): + worker = escape.url_unescape(worker) + keys = self.server.processing[worker] + call_stack = await self.server.get_call_stack(keys=keys) + self.render( + "call-stack.html", + title="Call Stacks: " + worker, + call_stack=call_stack, + **merge(self.extra, rel_path_statics), + ) + + +class TaskCallStack(RequestHandler): + async def get(self, key): + with log_errors(): + key = escape.url_unescape(key) + call_stack = await self.server.get_call_stack(keys=[key]) + if not call_stack: + self.write( + "
        <p>Task not actively running. " + "It may be finished or not yet started</p>
        " + ) + else: + self.render( + "call-stack.html", + title="Call Stack: " + key, + call_stack=call_stack, + **merge(self.extra, rel_path_statics), + ) + + +class IndividualPlots(RequestHandler): + def get(self): + try: + from bokeh.server.tornado import BokehTornado + + bokeh_application = first( + app + for app in self.server.http_application.applications + if isinstance(app, BokehTornado) + ) + individual_bokeh = { + uri.strip("/").replace("-", " ").title(): uri + for uri in bokeh_application.app_paths + if uri.lstrip("/").startswith("individual-") + and not uri.endswith(".json") + } + individual_static = { + uri.strip("/") + .replace(".html", "") + .replace("-", " ") + .title(): "/statics/" + + uri + for uri in os.listdir( + os.path.join(os.path.dirname(__file__), "..", "static") + ) + if uri.lstrip("/").startswith("individual-") and uri.endswith(".html") + } + result = {**individual_bokeh, **individual_static} + self.write(result) + except (ImportError, StopIteration): + self.write({}) + + +class EventstreamHandler(WebSocketHandler): + def initialize(self, dask_server=None, extra=None): + self.server = dask_server + self.extra = extra or {} + self.plugin = WebsocketPlugin(self, self.server) + self.server.add_plugin(self.plugin) + + def send(self, name, data): + data["name"] = name + for k in list(data): + # Drop bytes objects for now + if isinstance(data[k], bytes): + del data[k] + self.write_message(data) + + def open(self): + for worker in self.server.workers: + self.plugin.add_worker(self.server, worker) + + def on_message(self, message): + message = json.loads(message) + if message["name"] == "ping": + self.send("pong", {"timestamp": str(datetime.now())}) + + def on_close(self): + self.server.remove_plugin(self.plugin) + + +routes = [ + (r"info", redirect("info/main/workers.html"), {}), + (r"info/main/workers.html", Workers, {}), + (r"info/worker/(.*).html", Worker, {}), + (r"info/task/(.*).html", Task, {}), + (r"info/main/logs.html", Logs, {}), + (r"info/call-stacks/(.*).html", WorkerCallStacks, {}), + (r"info/call-stack/(.*).html", TaskCallStack, {}), + (r"info/logs/(.*).html", WorkerLogs, {}), + (r"individual-plots.json", IndividualPlots, {}), + (r"eventstream", EventstreamHandler, {}), +] diff --git a/distributed/http/scheduler/json.py b/distributed/http/scheduler/json.py new file mode 100644 index 00000000000..5dc09b4b6fe --- /dev/null +++ b/distributed/http/scheduler/json.py @@ -0,0 +1,72 @@ +from ..utils import RequestHandler +from ...utils import log_errors + + +class CountsJSON(RequestHandler): + def get(self): + scheduler = self.server + erred = 0 + nbytes = 0 + nthreads = 0 + memory = 0 + processing = 0 + released = 0 + waiting = 0 + waiting_data = 0 + desired_workers = scheduler.adaptive_target() + + for ts in scheduler.tasks.values(): + if ts.exception_blame is not None: + erred += 1 + elif ts.state == "released": + released += 1 + if ts.waiting_on: + waiting += 1 + if ts.waiters: + waiting_data += 1 + for ws in scheduler.workers.values(): + nthreads += ws.nthreads + memory += len(ws.has_what) + nbytes += ws.nbytes + processing += len(ws.processing) + + response = { + "bytes": nbytes, + "clients": len(scheduler.clients), + "cores": nthreads, + "erred": erred, + "hosts": len(scheduler.host_info), + "idle": len(scheduler.idle), + "memory": memory, + "processing": processing, + "released": released, + "saturated": len(scheduler.saturated), + "tasks": len(scheduler.tasks), + "unrunnable": len(scheduler.unrunnable), + "waiting": waiting, + "waiting_data": 
waiting_data, + "workers": len(scheduler.workers), + "desired_workers": desired_workers, + } + self.write(response) + + +class IdentityJSON(RequestHandler): + def get(self): + self.write(self.server.identity()) + + +class IndexJSON(RequestHandler): + def get(self): + with log_errors(): + r = [url[5:] for url, _, _ in routes if url.endswith(".json")] + self.render( + "json-index.html", routes=r, title="Index of JSON routes", **self.extra + ) + + +routes = [ + (r"json/counts.json", CountsJSON, {}), + (r"json/identity.json", IdentityJSON, {}), + (r"json/index.html", IndexJSON, {}), +] diff --git a/distributed/http/scheduler/prometheus.py b/distributed/http/scheduler/prometheus.py new file mode 100644 index 00000000000..0f1f9c3c14f --- /dev/null +++ b/distributed/http/scheduler/prometheus.py @@ -0,0 +1,99 @@ +import toolz + +from ..utils import RequestHandler +from ...scheduler import ALL_TASK_STATES + + +class _PrometheusCollector: + def __init__(self, dask_server): + self.server = dask_server + + def collect(self): + from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily + + yield GaugeMetricFamily( + "dask_scheduler_clients", + "Number of clients connected.", + value=len(self.server.clients), + ) + + yield GaugeMetricFamily( + "dask_scheduler_desired_workers", + "Number of workers scheduler needs for task graph.", + value=self.server.adaptive_target(), + ) + + worker_states = GaugeMetricFamily( + "dask_scheduler_workers", + "Number of workers known by scheduler.", + labels=["state"], + ) + worker_states.add_metric(["connected"], len(self.server.workers)) + worker_states.add_metric(["saturated"], len(self.server.saturated)) + worker_states.add_metric(["idle"], len(self.server.idle)) + yield worker_states + + tasks = GaugeMetricFamily( + "dask_scheduler_tasks", + "Number of tasks known by scheduler.", + labels=["state"], + ) + + task_counter = toolz.merge_with( + sum, (tp.states for tp in self.server.task_prefixes.values()) + ) + + suspicious_tasks = CounterMetricFamily( + "dask_scheduler_tasks_suspicious", + "Total number of times a task has been marked suspicious", + labels=["task_prefix_name"], + ) + + for tp in self.server.task_prefixes.values(): + suspicious_tasks.add_metric([tp.name], tp.suspicious) + yield suspicious_tasks + + yield CounterMetricFamily( + "dask_scheduler_tasks_forgotten", + ( + "Total number of processed tasks no longer in memory and already " + "removed from the scheduler job queue. Note task groups on the " + "scheduler which have all tasks in the forgotten state are not included." 
+ ), + value=task_counter.get("forgotten", 0.0), + ) + + for state in ALL_TASK_STATES: + tasks.add_metric([state], task_counter.get(state, 0.0)) + yield tasks + + +class PrometheusHandler(RequestHandler): + _collector = None + + def __init__(self, *args, dask_server=None, **kwargs): + import prometheus_client + + super(PrometheusHandler, self).__init__( + *args, dask_server=dask_server, **kwargs + ) + + if PrometheusHandler._collector: + # Especially during testing, multiple schedulers are started + # sequentially in the same python process + PrometheusHandler._collector.server = self.server + return + + PrometheusHandler._collector = _PrometheusCollector(self.server) + prometheus_client.REGISTRY.register(PrometheusHandler._collector) + + def get(self): + import prometheus_client + + self.write(prometheus_client.generate_latest()) + self.set_header("Content-Type", "text/plain; version=0.0.4") + + +routes = [ + ("/metrics", PrometheusHandler, {}), +] diff --git a/distributed/dashboard/tests/test_scheduler_bokeh_html.py b/distributed/http/scheduler/tests/test_scheduler_http.py similarity index 72% rename from distributed/dashboard/tests/test_scheduler_bokeh_html.py rename to distributed/http/scheduler/tests/test_scheduler_http.py index de71b12a0d1..f1c3a8ed064 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh_html.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -6,20 +6,14 @@ pytest.importorskip("bokeh") from tornado.escape import url_escape -from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest -from tornado.websocket import websocket_connect +from tornado.httpclient import AsyncHTTPClient, HTTPClientError from dask.sizeof import sizeof from distributed.utils import is_valid_xml from distributed.utils_test import gen_cluster, slowinc, inc -from distributed.dashboard import BokehScheduler, BokehWorker -@gen_cluster( - client=True, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, - worker_kwargs={"services": {"dashboard": BokehWorker}}, -) +@gen_cluster(client=True) async def test_connect(c, s, a, b): future = c.submit(lambda x: x + 1, 1) x = c.submit(slowinc, 1, delay=1, retries=5) @@ -39,7 +33,7 @@ async def test_connect(c, s, a, b): "individual-plots.json", ]: response = await http_client.fetch( - "http://localhost:%d/%s" % (s.services["dashboard"].port, suffix) + "http://localhost:%d/%s" % (s.http_server.port, suffix) ) assert response.code == 200 body = response.body.decode() @@ -50,36 +44,27 @@ async def test_connect(c, s, a, b): assert not re.search("href=./", body) # no absolute links -@gen_cluster( - client=True, - nthreads=[], - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, -) +@gen_cluster(client=True, nthreads=[]) async def test_worker_404(c, s): http_client = AsyncHTTPClient() with pytest.raises(HTTPClientError) as err: await http_client.fetch( - "http://localhost:%d/info/worker/unknown" % s.services["dashboard"].port + "http://localhost:%d/info/worker/unknown" % s.http_server.port ) assert err.value.code == 404 with pytest.raises(HTTPClientError) as err: await http_client.fetch( - "http://localhost:%d/info/task/unknown" % s.services["dashboard"].port + "http://localhost:%d/info/task/unknown" % s.http_server.port ) assert err.value.code == 404 -@gen_cluster( - client=True, - scheduler_kwargs={ - "services": {("dashboard", 0): (BokehScheduler, {"prefix": "/foo"})} - }, -) +@gen_cluster(client=True, scheduler_kwargs={"http_prefix": "/foo", "dashboard": True}) async def test_prefix(c, s, 
a, b): http_client = AsyncHTTPClient() for suffix in ["foo/info/main/workers.html", "foo/json/index.html", "foo/system"]: response = await http_client.fetch( - "http://localhost:%d/%s" % (s.services["dashboard"].port, suffix) + "http://localhost:%d/%s" % (s.http_server.port, suffix) ) assert response.code == 200 body = response.body.decode() @@ -89,11 +74,7 @@ async def test_prefix(c, s, a, b): assert is_valid_xml(body) -@gen_cluster( - client=True, - clean_kwargs={"threads": False}, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, -) +@gen_cluster(client=True, clean_kwargs={"threads": False}) async def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families @@ -104,7 +85,7 @@ async def test_prometheus(c, s, a, b): # prometheus_client errors for _ in range(2): response = await http_client.fetch( - "http://localhost:%d/metrics" % s.services["dashboard"].port + "http://localhost:%d/metrics" % s.http_server.port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain; version=0.0.4" @@ -114,11 +95,7 @@ async def test_prometheus(c, s, a, b): assert "dask_scheduler_workers" in families -@gen_cluster( - client=True, - clean_kwargs={"threads": False}, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, -) +@gen_cluster(client=True, clean_kwargs={"threads": False}) async def test_prometheus_collect_task_states(c, s, a, b): pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families @@ -126,11 +103,8 @@ async def test_prometheus_collect_task_states(c, s, a, b): http_client = AsyncHTTPClient() async def fetch_metrics(): - bokeh_scheduler = s.services["dashboard"] - assert s.services["dashboard"].scheduler is s - response = await http_client.fetch( - f"http://{bokeh_scheduler.server.address}:{bokeh_scheduler.port}/metrics" - ) + port = s.http_server.port + response = await http_client.fetch(f"http://localhost:{port}/metrics") txt = response.body.decode("utf8") families = { family.name: family for family in text_string_to_metric_families(txt) @@ -174,16 +148,12 @@ async def fetch_metrics(): assert sum(forgotten_tasks) == 0.0 -@gen_cluster( - client=True, - clean_kwargs={"threads": False}, - scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}}, -) +@gen_cluster(client=True, clean_kwargs={"threads": False}) async def test_health(c, s, a, b): http_client = AsyncHTTPClient() response = await http_client.fetch( - "http://localhost:%d/health" % s.services["dashboard"].port + "http://localhost:%d/health" % s.http_server.port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain" @@ -192,9 +162,7 @@ async def test_health(c, s, a, b): assert txt == "ok" -@gen_cluster( - client=True, scheduler_kwargs={"services": {("dashboard", 0): BokehScheduler}} -) +@gen_cluster(client=True) async def test_task_page(c, s, a, b): future = c.submit(lambda x: x + 1, 1, workers=a.address) x = c.submit(inc, 1) @@ -203,7 +171,7 @@ async def test_task_page(c, s, a, b): "info/task/" + url_escape(future.key) + ".html", response = await http_client.fetch( - "http://localhost:%d/info/task/" % s.services["dashboard"].port + "http://localhost:%d/info/task/" % s.http_server.port + url_escape(future.key) + ".html" ) @@ -218,22 +186,34 @@ async def test_task_page(c, s, a, b): @gen_cluster( client=True, - scheduler_kwargs={ - "services": { - ("dashboard", 0): ( - BokehScheduler, - 
{"allow_websocket_origin": ["good.invalid"]}, - ) - } + scheduler_kwargs={"dashboard": True}, + config={ + "distributed.scheduler.dashboard.bokeh-application.allow_websocket_origin": [ + "good.invalid" + ] }, ) async def test_allow_websocket_origin(c, s, a, b): + from tornado.httpclient import HTTPRequest + from tornado.websocket import websocket_connect + url = ( "ws://localhost:%d/status/ws?bokeh-protocol-version=1.0&bokeh-session-id=1" - % s.services["dashboard"].port + % s.http_server.port ) with pytest.raises(HTTPClientError) as err: await websocket_connect( HTTPRequest(url, headers={"Origin": "http://evil.invalid"}) ) assert err.value.code == 403 + + +@gen_cluster(client=True) +async def test_eventstream(c, s, a, b): + from tornado.websocket import websocket_connect + + ws_client = await websocket_connect( + "ws://localhost:%d/%s" % (s.http_server.port, "eventstream") + ) + assert "websocket" in str(s.plugins).lower() + ws_client.close() diff --git a/distributed/dashboard/static/css/base.css b/distributed/http/static/css/base.css similarity index 100% rename from distributed/dashboard/static/css/base.css rename to distributed/http/static/css/base.css diff --git a/distributed/dashboard/static/css/individual-cluster-map.css b/distributed/http/static/css/individual-cluster-map.css similarity index 100% rename from distributed/dashboard/static/css/individual-cluster-map.css rename to distributed/http/static/css/individual-cluster-map.css diff --git a/distributed/dashboard/static/css/status.css b/distributed/http/static/css/status.css similarity index 100% rename from distributed/dashboard/static/css/status.css rename to distributed/http/static/css/status.css diff --git a/distributed/dashboard/static/css/system.css b/distributed/http/static/css/system.css similarity index 100% rename from distributed/dashboard/static/css/system.css rename to distributed/http/static/css/system.css diff --git a/distributed/dashboard/static/images/dask-logo.svg b/distributed/http/static/images/dask-logo.svg similarity index 100% rename from distributed/dashboard/static/images/dask-logo.svg rename to distributed/http/static/images/dask-logo.svg diff --git a/distributed/dashboard/static/images/fa-bars.svg b/distributed/http/static/images/fa-bars.svg similarity index 100% rename from distributed/dashboard/static/images/fa-bars.svg rename to distributed/http/static/images/fa-bars.svg diff --git a/distributed/dashboard/static/images/favicon.ico b/distributed/http/static/images/favicon.ico similarity index 100% rename from distributed/dashboard/static/images/favicon.ico rename to distributed/http/static/images/favicon.ico diff --git a/distributed/dashboard/static/individual-cluster-map.html b/distributed/http/static/individual-cluster-map.html similarity index 100% rename from distributed/dashboard/static/individual-cluster-map.html rename to distributed/http/static/individual-cluster-map.html diff --git a/distributed/dashboard/static/js/anime.min.js b/distributed/http/static/js/anime.min.js similarity index 100% rename from distributed/dashboard/static/js/anime.min.js rename to distributed/http/static/js/anime.min.js diff --git a/distributed/dashboard/static/js/individual-cluster-map.js b/distributed/http/static/js/individual-cluster-map.js similarity index 100% rename from distributed/dashboard/static/js/individual-cluster-map.js rename to distributed/http/static/js/individual-cluster-map.js diff --git a/distributed/dashboard/static/js/reconnecting-websocket.min.js 
b/distributed/http/static/js/reconnecting-websocket.min.js similarity index 100% rename from distributed/dashboard/static/js/reconnecting-websocket.min.js rename to distributed/http/static/js/reconnecting-websocket.min.js diff --git a/distributed/http/statics.py b/distributed/http/statics.py new file mode 100644 index 00000000000..4a8a60298fe --- /dev/null +++ b/distributed/http/statics.py @@ -0,0 +1,10 @@ +from tornado import web +import os + +routes = [ + ( + r"/statics/(.*)", + web.StaticFileHandler, + {"path": os.path.join(os.path.dirname(__file__), "static")}, + ), +] diff --git a/distributed/dashboard/templates/base.html b/distributed/http/templates/base.html similarity index 100% rename from distributed/dashboard/templates/base.html rename to distributed/http/templates/base.html diff --git a/distributed/dashboard/templates/call-stack.html b/distributed/http/templates/call-stack.html similarity index 100% rename from distributed/dashboard/templates/call-stack.html rename to distributed/http/templates/call-stack.html diff --git a/distributed/dashboard/templates/json-index.html b/distributed/http/templates/json-index.html similarity index 100% rename from distributed/dashboard/templates/json-index.html rename to distributed/http/templates/json-index.html diff --git a/distributed/dashboard/templates/logs.html b/distributed/http/templates/logs.html similarity index 100% rename from distributed/dashboard/templates/logs.html rename to distributed/http/templates/logs.html diff --git a/distributed/dashboard/templates/main.html b/distributed/http/templates/main.html similarity index 100% rename from distributed/dashboard/templates/main.html rename to distributed/http/templates/main.html diff --git a/distributed/dashboard/templates/simple.html b/distributed/http/templates/simple.html similarity index 100% rename from distributed/dashboard/templates/simple.html rename to distributed/http/templates/simple.html diff --git a/distributed/dashboard/templates/status.html b/distributed/http/templates/status.html similarity index 100% rename from distributed/dashboard/templates/status.html rename to distributed/http/templates/status.html diff --git a/distributed/dashboard/templates/system.html b/distributed/http/templates/system.html similarity index 100% rename from distributed/dashboard/templates/system.html rename to distributed/http/templates/system.html diff --git a/distributed/dashboard/templates/task.html b/distributed/http/templates/task.html similarity index 100% rename from distributed/dashboard/templates/task.html rename to distributed/http/templates/task.html diff --git a/distributed/dashboard/templates/worker-table.html b/distributed/http/templates/worker-table.html similarity index 100% rename from distributed/dashboard/templates/worker-table.html rename to distributed/http/templates/worker-table.html diff --git a/distributed/dashboard/templates/worker.html b/distributed/http/templates/worker.html similarity index 100% rename from distributed/dashboard/templates/worker.html rename to distributed/http/templates/worker.html diff --git a/distributed/dashboard/templates/workers.html b/distributed/http/templates/workers.html similarity index 100% rename from distributed/dashboard/templates/workers.html rename to distributed/http/templates/workers.html diff --git a/distributed/http/tests/__init__.py b/distributed/http/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/distributed/http/tests/test_core.py b/distributed/http/tests/test_core.py new file mode 100644 
index 00000000000..c1bffedb72e --- /dev/null +++ b/distributed/http/tests/test_core.py @@ -0,0 +1,11 @@ +from distributed.utils_test import gen_cluster +from tornado.httpclient import AsyncHTTPClient + + +@gen_cluster(client=True) +async def test_scheduler(c, s, a, b): + client = AsyncHTTPClient() + response = await client.fetch( + "http://localhost:{}/health".format(s.http_server.port) + ) + assert response.code == 200 diff --git a/distributed/http/tests/test_routing.py b/distributed/http/tests/test_routing.py new file mode 100644 index 00000000000..481cfb3a209 --- /dev/null +++ b/distributed/http/tests/test_routing.py @@ -0,0 +1,38 @@ +from tornado import web +from tornado.httpclient import AsyncHTTPClient, HTTPClientError +import pytest + +from distributed.http.routing import RoutingApplication + + +class OneHandler(web.RequestHandler): + def get(self): + self.write("one") + + +class TwoHandler(web.RequestHandler): + def get(self): + self.write("two") + + +@pytest.mark.asyncio +async def test_basic(): + application = RoutingApplication([(r"/one", OneHandler),]) + two = web.Application([(r"/two", TwoHandler),]) + server = application.listen(1234) + + client = AsyncHTTPClient("http://localhost:1234") + response = await client.fetch("http://localhost:1234/one") + assert response.body.decode() == "one" + + with pytest.raises(HTTPClientError): + response = await client.fetch("http://localhost:1234/two") + + application.applications.append(two) + + response = await client.fetch("http://localhost:1234/two") + assert response.body.decode() == "two" + + application.add_handlers(".*", [(r"/three", OneHandler, {})]) + response = await client.fetch("http://localhost:1234/three") + assert response.body.decode() == "one" diff --git a/distributed/http/utils.py b/distributed/http/utils.py new file mode 100644 index 00000000000..5977ccd5bad --- /dev/null +++ b/distributed/http/utils.py @@ -0,0 +1,51 @@ +import importlib +import os +from typing import List + +from tornado import web +import toolz + +from ..utils import has_keyword + + +dirname = os.path.dirname(__file__) + + +class RequestHandler(web.RequestHandler): + def initialize(self, dask_server=None, extra=None): + self.server = dask_server + self.extra = extra or {} + + def get_template_path(self): + return os.path.join(dirname, "templates") + + +def redirect(path): + class Redirect(RequestHandler): + def get(self): + self.redirect(path) + + return Redirect + + +def get_handlers(server, modules: List[str], prefix="/"): + prefix = prefix or "" + prefix = "/" + prefix.strip("/") + + if not prefix.endswith("/"): + prefix = prefix + "/" + + _routes = [] + for module_name in modules: + module = importlib.import_module(module_name) + _routes.extend(module.routes) + + routes = [] + + for url, cls, kwargs in _routes: + if has_keyword(cls.initialize, "dask_server"): + kwargs = toolz.assoc(kwargs, "dask_server", server) + + routes.append((prefix + url.lstrip("/"), cls, kwargs)) + + return routes diff --git a/distributed/http/worker/__init__.py b/distributed/http/worker/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/distributed/http/worker/prometheus.py b/distributed/http/worker/prometheus.py new file mode 100644 index 00000000000..a60de3a6b64 --- /dev/null +++ b/distributed/http/worker/prometheus.py @@ -0,0 +1,98 @@ +from ..utils import RequestHandler + +import logging + + +class _PrometheusCollector: + def __init__(self, server): + self.worker = server + self.logger = logging.getLogger("distributed.dask_worker") + 
self.crick_available = True + try: + import crick # noqa: F401 + except ImportError: + self.crick_available = False + self.logger.info( + "Not all prometheus metrics available are exported. Digest-based metrics require crick to be installed" + ) + + def collect(self): + from prometheus_client.core import GaugeMetricFamily + + tasks = GaugeMetricFamily( + "dask_worker_tasks", "Number of tasks at worker.", labels=["state"] + ) + tasks.add_metric(["stored"], len(self.worker.data)) + tasks.add_metric(["executing"], len(self.worker.executing)) + tasks.add_metric(["ready"], len(self.worker.ready)) + tasks.add_metric(["waiting"], len(self.worker.waiting_for_data)) + tasks.add_metric(["serving"], len(self.worker._comms)) + yield tasks + + yield GaugeMetricFamily( + "dask_worker_connections", + "Number of task connections to other workers.", + value=len(self.worker.in_flight_workers), + ) + + yield GaugeMetricFamily( + "dask_worker_threads", + "Number of worker threads.", + value=self.worker.nthreads, + ) + + yield GaugeMetricFamily( + "dask_worker_latency_seconds", + "Latency of worker connection.", + value=self.worker.latency, + ) + + # all metrics using digests require crick to be installed + # the following metrics will export NaN, if the corresponding digests are None + if self.crick_available: + yield GaugeMetricFamily( + "dask_worker_tick_duration_median_seconds", + "Median tick duration at worker.", + value=self.worker.digests["tick-duration"].components[1].quantile(50), + ) + + yield GaugeMetricFamily( + "dask_worker_task_duration_median_seconds", + "Median task runtime at worker.", + value=self.worker.digests["task-duration"].components[1].quantile(50), + ) + + yield GaugeMetricFamily( + "dask_worker_transfer_bandwidth_median_bytes", + "Bandwidth for transfer at worker in Bytes.", + value=self.worker.digests["transfer-bandwidth"] + .components[1] + .quantile(50), + ) + + +class PrometheusHandler(RequestHandler): + _initialized = False + + def __init__(self, *args, **kwargs): + import prometheus_client + + super(PrometheusHandler, self).__init__(*args, **kwargs) + + if PrometheusHandler._initialized: + return + + prometheus_client.REGISTRY.register(_PrometheusCollector(self.server)) + + PrometheusHandler._initialized = True + + def get(self): + import prometheus_client + + self.write(prometheus_client.generate_latest()) + self.set_header("Content-Type", "text/plain; version=0.0.4") + + +routes = [ + (r"metrics", PrometheusHandler, {}), +] diff --git a/distributed/dashboard/tests/test_worker_bokeh_html.py b/distributed/http/worker/tests/test_worker_http.py similarity index 73% rename from distributed/dashboard/tests/test_worker_bokeh_html.py rename to distributed/http/worker/tests/test_worker_http.py index 7a4d70a037c..0a4135fba7f 100644 --- a/distributed/dashboard/tests/test_worker_bokeh_html.py +++ b/distributed/http/worker/tests/test_worker_http.py @@ -1,13 +1,10 @@ import pytest -pytest.importorskip("bokeh") - from tornado.httpclient import AsyncHTTPClient from distributed.utils_test import gen_cluster -from distributed.dashboard import BokehWorker -@gen_cluster(client=True, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}) +@gen_cluster(client=True) def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families @@ -18,7 +15,7 @@ def test_prometheus(c, s, a, b): # prometheus_client errors for _ in range(2): response = yield http_client.fetch( - "http://localhost:%d/metrics" % 
a.services["dashboard"].port + "http://localhost:%d/metrics" % a.http_server.port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain; version=0.0.4" @@ -28,12 +25,12 @@ def test_prometheus(c, s, a, b): assert "dask_worker_latency_seconds" in families -@gen_cluster(client=True, worker_kwargs={"services": {("dashboard", 0): BokehWorker}}) +@gen_cluster(client=True) def test_health(c, s, a, b): http_client = AsyncHTTPClient() response = yield http_client.fetch( - "http://localhost:%d/health" % a.services["dashboard"].port + "http://localhost:%d/health" % a.http_server.port ) assert response.code == 200 assert response.headers["Content-Type"] == "text/plain" diff --git a/distributed/node.py b/distributed/node.py index af15b5a409f..11645e86317 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -4,12 +4,17 @@ import weakref from tornado.ioloop import IOLoop +from tornado.httpserver import HTTPServer from tornado import gen +import tlz import dask +from .comm import get_tcp_server_address +from .comm import get_address_host from .core import Server, ConnectionPool +from .http.routing import RoutingApplication from .versions import get_versions -from .utils import DequeHandler, TimeoutError +from .utils import DequeHandler, TimeoutError, clean_dashboard_address, ignoring class Node: @@ -189,3 +194,53 @@ async def start(self): # subclasses should implement their own start method whichs calls super().start() await Node.start(self) return self + + def start_http_server( + self, routes, dashboard_address, default_port=0, ssl_options=None, + ): + """ This creates an HTTP Server running on this node """ + + self.http_application = RoutingApplication(routes,) + + # TLS configuration + tls_key = dask.config.get("distributed.scheduler.dashboard.tls.key") + tls_cert = dask.config.get("distributed.scheduler.dashboard.tls.cert") + tls_ca_file = dask.config.get("distributed.scheduler.dashboard.tls.ca-file") + if tls_cert: + import ssl + + ssl_options = ssl.create_default_context( + cafile=tls_ca_file, purpose=ssl.Purpose.SERVER_AUTH + ) + ssl_options.load_cert_chain(tls_cert, keyfile=tls_key) + # We don't care about auth here, just encryption + ssl_options.check_hostname = False + ssl_options.verify_mode = ssl.CERT_NONE + + self.http_server = HTTPServer(self.http_application, ssl_options=ssl_options) + http_address = clean_dashboard_address(dashboard_address or default_port) + + if not http_address["address"]: + address = self._start_address + if isinstance(address, (list, tuple)): + address = address[0] + if address: + with ignoring(ValueError): + http_address["address"] = get_address_host(address) + changed_port = False + try: + self.http_server.listen(**http_address) + except Exception: + changed_port = True + self.http_server.listen(**tlz.merge(http_address, {"port": 0})) + self.http_server.port = get_tcp_server_address(self.http_server)[1] + self.services["dashboard"] = self.http_server + + if changed_port and dashboard_address: + warnings.warn( + "Port {} is already in use.\n" + "Perhaps you already have a cluster running?\n" + "Hosting the HTTP server on port {} instead".format( + http_address["port"], self.http_server.port + ) + ) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 55692b875ab..86dd6b9203e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -36,6 +36,7 @@ import dask +from . 
import profile from .batched import BatchedSend from .comm import ( normalize_address, @@ -46,7 +47,8 @@ from .comm.addressing import addresses_from_user_args from .core import rpc, connect, send_recv, clean_exception, CommClosedError from .diagnostics.plugin import SchedulerPlugin -from . import profile + +from .http import get_handlers from .metrics import time from .node import ServerNode from . import preloading @@ -1060,6 +1062,8 @@ def __init__( port=0, protocol=None, dashboard_address=None, + dashboard=None, + http_prefix="/", preload=None, preload_argv=(), plugins=(), @@ -1112,15 +1116,30 @@ def __init__( assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("scheduler") - if dashboard_address is not None: + self._start_address = addresses_from_user_args( + host=host, + port=port, + interface=interface, + protocol=protocol, + security=security, + default_port=self.default_port, + ) + + routes = get_handlers( + server=self, + modules=dask.config.get("distributed.scheduler.http.routes"), + prefix=http_prefix, + ) + self.start_http_server(routes, dashboard_address, default_port=8787) + + if dashboard: try: - from distributed.dashboard import BokehScheduler + import distributed.dashboard.scheduler except ImportError: logger.debug("To start diagnostics web server please install Bokeh") else: - self.service_specs[("dashboard", dashboard_address)] = ( - BokehScheduler, - (service_kwargs or {}).get("dashboard", {}), + distributed.dashboard.scheduler.connect( + self.http_application, self.http_server, self, prefix=http_prefix, ) # Communication state @@ -1327,15 +1346,6 @@ def __init__( connection_limit = get_fileno_limit() / 2 - self._start_address = addresses_from_user_args( - host=host, - port=port, - interface=interface, - protocol=protocol, - security=security, - default_port=self.default_port, - ) - super(Scheduler, self).__init__( handlers=self.handlers, stream_handlers=merge(worker_handlers, client_handlers), @@ -1506,6 +1516,7 @@ async def close(self, comm=None, fast=False, close_workers=False): self.periodic_callbacks.clear() self.stop_services() + for ext in self.extensions.values(): with ignoring(AttributeError): ext.teardown() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 68507d889f0..297394631ae 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5247,14 +5247,9 @@ def test_quiet_scheduler_loss(c, s): def test_dashboard_link(loop, monkeypatch): - pytest.importorskip("bokeh") - from distributed.dashboard import BokehScheduler - monkeypatch.setenv("USER", "myusername") - with cluster( - scheduler_kwargs={"services": {("dashboard", 12355): BokehScheduler}} - ) as (s, [a, b]): + with cluster(scheduler_kwargs={"dashboard_address": ":12355"}) as (s, [a, b]): with Client(s["address"], loop=loop) as c: with dask.config.set( {"distributed.dashboard.link": "{scheme}://foo-{USER}:{port}/status"} diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a0e1b11de2b..2b48fa030e4 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1190,19 +1190,14 @@ def test_correct_bad_time_estimate(c, s, *workers): @gen_test() async def test_service_hosts(): - pytest.importorskip("bokeh") - from distributed.dashboard import BokehScheduler - port = 0 for url, expected in [ ("tcp://0.0.0.0", ("::", "0.0.0.0")), ("tcp://127.0.0.1", "127.0.0.1"), ("tcp://127.0.0.1:38275", "127.0.0.1"), ]: - services = 
{("dashboard", port): BokehScheduler} - - async with Scheduler(host=url, services=services) as s: - sock = first(s.services["dashboard"].server._http._sockets.values()) + async with Scheduler(host=url) as s: + sock = first(s.http_server._sockets.values()) if isinstance(expected, tuple): assert sock.getsockname()[0] in expected else: @@ -1210,10 +1205,8 @@ async def test_service_hosts(): port = ("127.0.0.1", 0) for url in ["tcp://0.0.0.0", "tcp://127.0.0.1", "tcp://127.0.0.1:38275"]: - services = {("dashboard", port): BokehScheduler} - - async with Scheduler(services=services, host=url) as s: - sock = first(s.services["dashboard"].server._http._sockets.values()) + async with Scheduler(dashboard_address="127.0.0.1:0", host=url) as s: + sock = first(s.http_server._sockets.values()) assert sock.getsockname()[0] == "127.0.0.1" diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0bda344fd96..a5e364ec0cd 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -5,7 +5,6 @@ from operator import add import os import psutil -import shutil import sys from time import sleep import traceback @@ -58,12 +57,11 @@ ) -def test_worker_nthreads(): - w = Worker("127.0.0.1", 8019) - try: - assert w.executor._max_workers == CPU_COUNT - finally: - shutil.rmtree(w.local_directory) +@pytest.mark.asyncio +async def test_worker_nthreads(cleanup): + async with Scheduler() as s: + async with Worker(s.address) as w: + assert w.executor._max_workers == CPU_COUNT @gen_cluster() @@ -75,13 +73,15 @@ def test_str(s, a, b): assert str(len(a.executing)) in repr(a) -def test_identity(): - w = Worker("127.0.0.1", 8019) - ident = w.identity(None) - assert "Worker" in ident["type"] - assert ident["scheduler"] == "tcp://127.0.0.1:8019" - assert isinstance(ident["nthreads"], int) - assert isinstance(ident["memory_limit"], Number) +@pytest.mark.asyncio +async def test_identity(cleanup): + async with Scheduler() as s: + async with Worker(s.address) as w: + ident = w.identity(None) + assert "Worker" in ident["type"] + assert ident["scheduler"] == s.address + assert isinstance(ident["nthreads"], int) + assert isinstance(ident["memory_limit"], Number) @gen_cluster(client=True) @@ -320,20 +320,17 @@ def test_worker_with_port_zero(): @pytest.mark.slow -def test_worker_waits_for_scheduler(loop): - @gen.coroutine - def f(): - w = Worker("127.0.0.1", 8007) - try: - yield asyncio.wait_for(w, 3) - except TimeoutError: - pass - else: - assert False - assert w.status not in ("closed", "running") - yield w.close(timeout=0.1) - - loop.run_sync(f) +@pytest.mark.asyncio +async def test_worker_waits_for_scheduler(cleanup): + w = Worker("127.0.0.1:8724") + try: + await asyncio.wait_for(w, 3) + except TimeoutError: + pass + else: + assert False + assert w.status not in ("closed", "running") + await w.close(timeout=0.1) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) @@ -532,19 +529,21 @@ def test_close_on_disconnect(s, w): assert time() < start + 5 -def test_memory_limit_auto(): - a = Worker("127.0.0.1", 8099, nthreads=1) - b = Worker("127.0.0.1", 8099, nthreads=2) - c = Worker("127.0.0.1", 8099, nthreads=100) - d = Worker("127.0.0.1", 8099, nthreads=200) - - assert isinstance(a.memory_limit, Number) - assert isinstance(b.memory_limit, Number) +@pytest.mark.asyncio +async def test_memory_limit_auto(): + async with Scheduler() as s: + async with Worker(s.address, nthreads=1) as a, Worker( + s.address, nthreads=2 + ) as b, Worker(s.address, nthreads=100) as c, Worker( + s.address, 
nthreads=200 + ) as d: + assert isinstance(a.memory_limit, Number) + assert isinstance(b.memory_limit, Number) - if CPU_COUNT > 1: - assert a.memory_limit < b.memory_limit + if CPU_COUNT > 1: + assert a.memory_limit < b.memory_limit - assert c.memory_limit == d.memory_limit + assert c.memory_limit == d.memory_limit @gen_cluster(client=True) @@ -782,13 +781,13 @@ def test_hold_onto_dependents(c, s, a, b): @pytest.mark.slow @gen_cluster(client=False, nthreads=[]) -def test_worker_death_timeout(s): +async def test_worker_death_timeout(s): with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): - yield s.close() + await s.close() w = Worker(s.address, death_timeout=1) with pytest.raises(TimeoutError) as info: - yield w + await w assert "Worker" in str(info.value) assert "timed out" in str(info.value) or "failed to start" in str(info.value) @@ -1024,39 +1023,25 @@ def test_worker_fds(s): @gen_cluster(nthreads=[]) async def test_service_hosts_match_worker(s): - pytest.importorskip("bokeh") - from distributed.dashboard import BokehWorker - - async with Worker( - s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://0.0.0.0" - ) as w: - sock = first(w.services["dashboard"].server._http._sockets.values()) + async with Worker(s.address, host="tcp://0.0.0.0") as w: + sock = first(w.http_server._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") async with Worker( - s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://127.0.0.1" + s.address, host="tcp://127.0.0.1", dashboard_address="0.0.0.0:0" ) as w: - sock = first(w.services["dashboard"].server._http._sockets.values()) + sock = first(w.http_server._sockets.values()) assert sock.getsockname()[0] in ("::", "0.0.0.0") - async with Worker( - s.address, services={("dashboard", 0): BokehWorker}, host="tcp://127.0.0.1" - ) as w: - sock = first(w.services["dashboard"].server._http._sockets.values()) + async with Worker(s.address, host="tcp://127.0.0.1") as w: + sock = first(w.http_server._sockets.values()) assert sock.getsockname()[0] == "127.0.0.1" @gen_cluster(nthreads=[]) -def test_start_services(s): - pytest.importorskip("bokeh") - from distributed.dashboard import BokehWorker - - services = {("dashboard", ":1234"): BokehWorker} - - w = yield Worker(s.address, services=services) - - assert w.services["dashboard"].server.port == 1234 - yield w.close() +async def test_start_services(s): + async with Worker(s.address, dashboard_address=1234) as w: + assert w.http_server.port == 1234 @gen_test() @@ -1234,16 +1219,18 @@ def f(x): assert all(f.key in b.data for f in futures) -def test_deque_handler(): +@pytest.mark.asyncio +async def test_deque_handler(cleanup): from distributed.worker import logger - w = Worker("127.0.0.1", 8019) - deque_handler = w._deque_handler - logger.info("foo456") - assert deque_handler.deque - msg = deque_handler.deque[-1] - assert "distributed.worker" in deque_handler.format(msg) - assert any(msg.msg == "foo456" for msg in deque_handler.deque) + async with Scheduler() as s: + async with Worker(s.address) as w: + deque_handler = w._deque_handler + logger.info("foo456") + assert deque_handler.deque + msg = deque_handler.deque[-1] + assert "distributed.worker" in deque_handler.format(msg) + assert any(msg.msg == "foo456" for msg in deque_handler.deque) @gen_cluster(nthreads=[], client=True) diff --git a/distributed/utils.py b/distributed/utils.py index eb622f7b837..adc20d4f368 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1523,3 +1523,45 @@ def 
__setitem__(self, key, value): if len(self) >= self.maxsize: self.data.popitem(last=False) super().__setitem__(key, value) + + +def clean_dashboard_address(addr, default_listen_ip=""): + """ + + Examples + -------- + >>> clean_dashboard_address(8787) + {'address': '', 'port': 8787} + >>> clean_dashboard_address(":8787") + {'address': '', 'port': 8787} + >>> clean_dashboard_address("8787") + {'address': '', 'port': 8787} + >>> clean_dashboard_address("8787") + {'address': '', 'port': 8787} + >>> clean_dashboard_address("foo:8787") + {'address': 'foo', 'port': 8787} + """ + + if default_listen_ip == "0.0.0.0": + default_listen_ip = "" # for IPV6 + + try: + addr = int(addr) + except (TypeError, ValueError): + pass + + if isinstance(addr, str): + addr = addr.split(":") + + if isinstance(addr, (tuple, list)): + if len(addr) == 2: + host, port = (addr[0], int(addr[1])) + elif len(addr) == 1: + [host], port = addr, 0 + else: + raise ValueError(addr) + elif isinstance(addr, int): + host = default_listen_ip + port = addr + + return {"address": host, "port": port} diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 983eaac48f5..e1db066b732 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1513,7 +1513,9 @@ def check_instances(): } # assert not list(SpecCluster._instances) # TODO - assert all(c.status == "closed" for c in SpecCluster._instances) + assert all(c.status == "closed" for c in SpecCluster._instances), list( + SpecCluster._instances + ) SpecCluster._instances.clear() Nanny._instances.clear() diff --git a/distributed/worker.py b/distributed/worker.py index f2f832f6d83..a50103bacab 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -32,6 +32,7 @@ from .comm.addressing import address_from_user_args from .core import error_message, CommClosedError, send_recv, pingpong, coerce_to_address from .diskutils import WorkSpace +from .http import get_handlers from .metrics import time from .node import ServerNode from . 
import preloading @@ -318,6 +319,8 @@ def __init__( port=None, protocol=None, dashboard_address=None, + dashboard=False, + http_prefix="/", nanny=None, plugins=(), low_level_profiler=dask.config.get("distributed.worker.profile.low-level"), @@ -587,15 +590,21 @@ def __init__( self.services = {} self.service_specs = services or {} - if dashboard_address is not None: + routes = get_handlers( + server=self, + modules=dask.config.get("distributed.worker.http.routes"), + prefix=http_prefix, + ) + self.start_http_server(routes, dashboard_address) + + if dashboard: try: - from distributed.dashboard import BokehWorker + import distributed.dashboard.worker except ImportError: logger.debug("To start diagnostics web server please install Bokeh") else: - self.service_specs[("dashboard", dashboard_address)] = ( - BokehWorker, - (service_kwargs or {}).get("dashboard", {}), + distributed.dashboard.worker.connect( + self.http_application, self.http_server, self, prefix=http_prefix, ) self.metrics = dict(metrics) if metrics else {} @@ -1116,8 +1125,7 @@ async def close( await self.scheduler.close_rpc() self._workdir.release() - for k, v in self.services.items(): - v.stop() + self.stop_services() if ( self.batched_stream diff --git a/setup.py b/setup.py index 155ae0c0274..6b70d0638ce 100755 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ license="BSD", package_data={ "": ["templates/index.html", "template.html"], - "distributed": ["dashboard/templates/*.html"], + "distributed": ["http/templates/*.html"], }, include_package_data=True, install_requires=install_requires, From 74f10aaef2827194c58514451bea9dcd55167ae1 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 9 Apr 2020 19:04:16 -0500 Subject: [PATCH 0782/1550] Add Client.wait_to_workers to Client autosummary table (#3692) --- docs/source/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index da9a76eed9b..c036fdf96db 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -43,6 +43,7 @@ API Client.submit Client.unpublish_dataset Client.upload_file + Client.wait_for_workers Client.who_has .. 
currentmodule:: distributed From 22dbe7147fcb5d44dae4c37dc3f026b59369e5c1 Mon Sep 17 00:00:00 2001 From: Abdulelah Bin Mahfoodh Date: Mon, 13 Apr 2020 16:37:00 +0300 Subject: [PATCH 0783/1550] Fix propagating inherit config in SSHCluster for non-bash shells (#3688) * Fix propagating inherit config in SSHCluster for non-bash shells --- distributed/deploy/ssh.py | 71 ++++++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index a7b3526bcba..fc629f8445c 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -86,22 +86,38 @@ async def start(self): import asyncssh # import now to avoid adding to module startup time self.connection = await asyncssh.connect(self.address, **self.connect_options) - self.proc = await self.connection.create_process( - " ".join( - [ - 'DASK_INTERNAL_INHERIT_CONFIG="%s"' - % serialize_for_cli(dask.config.global_config), - sys.executable, - "-m", - self.worker_module, - self.scheduler, - "--name", - str(self.name), - ] - + cli_keywords(self.kwargs, cls=_Worker, cmd=self.worker_module) + + result = await self.connection.run("uname") + if result.exit_status == 0: + set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format( + serialize_for_cli(dask.config.global_config) ) + else: + result = await self.connection.run("cmd /c ver") + if result.exit_status == 0: + set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format( + serialize_for_cli(dask.config.global_config) + ) + else: + raise Exception( + "Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable " + ) + + cmd = " ".join( + [ + set_env, + sys.executable, + "-m", + self.worker_module, + self.scheduler, + "--name", + str(self.name), + ] + + cli_keywords(self.kwargs, cls=_Worker, cmd=self.worker_module) ) + self.proc = await self.connection.create_process(cmd) + # We watch stderr in order to get the address, then we return while True: line = await self.proc.stderr.readline() @@ -144,18 +160,27 @@ async def start(self): self.connection = await asyncssh.connect(self.address, **self.connect_options) - self.proc = await self.connection.create_process( - " ".join( - [ - 'DASK_INTERNAL_INHERIT_CONFIG="%s"' - % serialize_for_cli(dask.config.global_config), - sys.executable, - "-m", - "distributed.cli.dask_scheduler", - ] - + cli_keywords(self.kwargs, cls=_Scheduler) + result = await self.connection.run("uname") + if result.exit_status == 0: + set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format( + serialize_for_cli(dask.config.global_config) ) + else: + result = await self.connection.run("cmd /c ver") + if result.exit_status == 0: + set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format( + serialize_for_cli(dask.config.global_config) + ) + else: + raise Exception( + "Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable " + ) + + cmd = " ".join( + [set_env, sys.executable, "-m", "distributed.cli.dask_scheduler",] + + cli_keywords(self.kwargs, cls=_Scheduler) ) + self.proc = await self.connection.create_process(cmd) # We watch stderr in order to get the address, then we return while True: From 8a0efe4ad8e05833e905f55f2e329a4aa02d2711 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 13 Apr 2020 07:18:36 -0700 Subject: [PATCH 0784/1550] Add Cluster __enter__ and __exit__ methods (#3699) These just call sync on the async versions --- distributed/deploy/cluster.py | 6 ++++++ distributed/deploy/spec.py | 7 ++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git 
a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 8082d278483..65199c48bd5 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -352,6 +352,12 @@ def _ipython_display_(self, **kwargs): data = {"text/plain": repr(self), "text/html": self._repr_html_()} display(data, raw=True) + def __enter__(self): + return self.sync(self.__aenter__) + + def __exit__(self, typ, value, traceback): + return self.sync(self.__aexit__, typ, value, traceback) + async def __aenter__(self): await self return self diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 99ab70d2de1..c6338d3b93f 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -396,13 +396,14 @@ async def _close(self): await super()._close() - def __enter__(self): - self.sync(self._correct_state) + async def __aenter__(self): + await self + await self._correct_state() assert self.status == "running" return self def __exit__(self, typ, value, traceback): - self.close() + super().__exit__(typ, value, traceback) self._loop_runner.stop() def _threads_per_worker(self) -> int: From 20c3e29e49798d4d446e1991b1ea6042763fd3ab Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 13 Apr 2020 07:22:48 -0700 Subject: [PATCH 0785/1550] Replace Example with Examples in docstrings (#3697) This was causing warnings in sphinx, and doesn't get rendered properly --- distributed/batched.py | 4 ++-- distributed/deploy/cluster.py | 4 ++-- distributed/profile.py | 4 ++-- distributed/protocol/serialize.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/distributed/batched.py b/distributed/batched.py index 13c241d1e1b..07eb8e41014 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -22,8 +22,8 @@ class BatchedSend: Batching several messages at once helps performance when sending a myriad of tiny messages. - Example - ------- + Examples + -------- >>> stream = yield connect(address) >>> bstream = BatchedSend(interval='10 ms') >>> bstream.start(stream) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 65199c48bd5..7164b17b076 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -135,8 +135,8 @@ def scale(self, n: int) -> None: n: int Target number of workers - Example - ------- + Examples + -------- >>> cluster.scale(10) # scale cluster to ten workers """ raise NotImplementedError() diff --git a/distributed/profile.py b/distributed/profile.py index 1bf81ad6ff0..33eba502ef9 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -79,8 +79,8 @@ def process(frame, child, state, stop=None, omit=None): This recursively adds counts to the existing state dictionary and creates new entries for new functions. 
- Example - ------- + Examples + -------- >>> import sys, threading >>> ident = threading.get_ident() # replace with your thread of interest >>> frame = sys._current_frames()[ident] diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 6db7ca70c13..f20bc490752 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -272,8 +272,8 @@ def deserialize(header, frames, deserializers=None): class Serialize: """ Mark an object that should be serialized - Example - ------- + Examples + -------- >>> msg = {'op': 'update', 'data': to_serialize(123)} >>> msg # doctest: +SKIP {'op': 'update', 'data': } From fee5c42ca5dab297d5e3829f1ccd94689e9d5b6c Mon Sep 17 00:00:00 2001 From: Abdulelah Bin Mahfoodh Date: Mon, 13 Apr 2020 17:51:32 +0300 Subject: [PATCH 0786/1550] Add remote_python keyword to the new SSHCluster (#3701) * Fix dask-ssh after removing local-directory keyword from dask_scheduler * black changes * Add a test to dask-ssh with local directory parameter * Add python_remote keyword to SSHCluster * Modify docstring for Scheduler and Worker * Rename to remote_python --- distributed/deploy/ssh.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index fc629f8445c..4f0e713ffa9 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -58,6 +58,8 @@ class Worker(Process): The python module to run to start the worker. connect_options: dict kwargs to be passed to asyncssh connections + remote_python: str + Path to Python on remote node to run this worker. kwargs: dict These will be passed through the dask-worker CLI to the dask.distributed.Worker class @@ -70,6 +72,7 @@ def __init__( connect_options: dict, kwargs: dict, worker_module="distributed.cli.dask_worker", + remote_python=None, loop=None, name=None, ): @@ -81,6 +84,7 @@ def __init__( self.connect_options = connect_options self.kwargs = kwargs self.name = name + self.remote_python = remote_python async def start(self): import asyncssh # import now to avoid adding to module startup time @@ -141,17 +145,22 @@ class Scheduler(Process): The hostname where we should run this worker connect_options: dict kwargs to be passed to asyncssh connections + remote_python: str + Path to Python on remote node to run this scheduler. kwargs: dict These will be passed through the dask-scheduler CLI to the dask.distributed.Scheduler class """ - def __init__(self, address: str, connect_options: dict, kwargs: dict): + def __init__( + self, address: str, connect_options: dict, kwargs: dict, remote_python=None + ): super().__init__() self.address = address self.kwargs = kwargs self.connect_options = connect_options + self.remote_python = remote_python async def start(self): import asyncssh # import now to avoid adding to module startup time @@ -220,6 +229,7 @@ def SSHCluster( worker_options: dict = {}, scheduler_options: dict = {}, worker_module: str = "distributed.cli.dask_worker", + remote_python: str = None, **kwargs ): """ Deploy a Dask cluster using SSH @@ -254,6 +264,8 @@ def SSHCluster( Keywords to pass on to scheduler. worker_module: str, optional Python module to call to start the worker. + remote_python: str, optional + Path to Python on remote nodes. 
Examples -------- @@ -300,6 +312,7 @@ def SSHCluster( "address": hosts[0], "connect_options": connect_options, "kwargs": scheduler_options, + "remote_python": remote_python, }, } workers = { @@ -310,6 +323,7 @@ def SSHCluster( "connect_options": connect_options, "kwargs": worker_options, "worker_module": worker_module, + "remote_python": remote_python, }, } for i, host in enumerate(hosts[1:]) From 549528434dc4805d76cc2cfae74f7bff4ee6b645 Mon Sep 17 00:00:00 2001 From: "Jonathan J. Helmus" Date: Tue, 14 Apr 2020 10:00:17 -0500 Subject: [PATCH 0787/1550] do not log an error on unset variable delete (#3652) If a Variable is never set or accessed no entries are made in the tracking attributes of VariableExtension. Therefore there is no need to raise or log an error when the variable is deleted. --- distributed/tests/test_variable.py | 13 +++++++++++++ distributed/variable.py | 6 ++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 64765d808c7..0e450aa7a02 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -2,6 +2,7 @@ import random from time import sleep import sys +import logging import pytest from tornado import gen @@ -12,6 +13,7 @@ from distributed.compatibility import WINDOWS from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 +from distributed.utils_test import captured_logger @gen_cluster(client=True) @@ -39,6 +41,17 @@ def test_variable(c, s, a, b): assert time() < start + 5 +@gen_cluster(client=True) +async def test_delete_unset_variable(c, s, a, b): + x = Variable() + assert x.client is c + with captured_logger(logging.getLogger("distributed.utils")) as logger: + x.delete() + await c.close() + text = logger.getvalue() + assert "KeyError" not in text + + @gen_cluster(client=True) def test_queue_with_data(c, s, a, b): x = Variable("x") diff --git a/distributed/variable.py b/distributed/variable.py index 3c6cc931166..a47064b1397 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -113,8 +113,10 @@ async def delete(self, stream=None, name=None, client=None): else: if old["type"] == "Future": await self.release(old["value"], name) - del self.waiting_conditions[name] - del self.variables[name] + with ignoring(KeyError): + del self.waiting_conditions[name] + with ignoring(KeyError): + del self.variables[name] class Variable: From 3a70aa6dc84cc29c2318b566d2d064ff84fed940 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 14 Apr 2020 18:08:53 +0200 Subject: [PATCH 0788/1550] Allow modification of distributed.comm.retry at runtime (#3705) --- distributed/tests/test_scheduler.py | 28 +++++++++++++--------------- distributed/utils_comm.py | 17 ++++++++--------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 2b48fa030e4..d64c88fceea 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -7,7 +7,6 @@ import re import sys from time import sleep -from unittest import mock import logging import dask @@ -1900,7 +1899,7 @@ async def test_gather_failing_cnn_recover(c, s, a, b): x = await c.scatter({"x": 1}, workers=a.address) s.rpc = await FlakyConnectionPool(failing_connections=1) - with mock.patch("distributed.utils_comm.retry_count", 1): + with dask.config.set({"distributed.comm.retry.count": 1}): res = await 
s.gather(keys=["x"]) assert res["status"] == "OK" @@ -1963,20 +1962,19 @@ def reducer(x, y): s.rpc = await FlakyConnectionPool(failing_connections=4) - with captured_logger( - logging.getLogger("distributed.scheduler") - ) as sched_logger, captured_logger( - logging.getLogger("distributed.client") - ) as client_logger, captured_logger( - logging.getLogger("distributed.utils_comm") - ) as utils_comm_logger, mock.patch( - "distributed.utils_comm.retry_count", 3 - ), mock.patch( - "distributed.utils_comm.retry_delay_min", 0.5 + with dask.config.set( + {"distributed.comm.retry.delay_min": 0.5, "distributed.comm.retry.count": 3,} ): - # Gather using the client (as an ordinary user would) - # Upon a missing key, the client will reschedule the computations - res = await c.gather(z) + with captured_logger( + logging.getLogger("distributed.scheduler") + ) as sched_logger, captured_logger( + logging.getLogger("distributed.client") + ) as client_logger, captured_logger( + logging.getLogger("distributed.utils_comm") + ) as utils_comm_logger: + # Gather using the client (as an ordinary user would) + # Upon a missing key, the client will reschedule the computations + res = await c.gather(z) assert res == 5 diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 42404754527..b7e33656ab8 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -312,15 +312,6 @@ def subs_multiple(o, d): return o -retry_count = dask.config.get("distributed.comm.retry.count") -retry_delay_min = parse_timedelta( - dask.config.get("distributed.comm.retry.delay.min"), default="s" -) -retry_delay_max = parse_timedelta( - dask.config.get("distributed.comm.retry.delay.max"), default="s" -) - - async def retry( coro, count, @@ -383,6 +374,14 @@ async def retry_operation(coro, *args, operation=None, **kwargs): """ Retry an operation using the configuration values for the retry parameters """ + + retry_count = dask.config.get("distributed.comm.retry.count") + retry_delay_min = parse_timedelta( + dask.config.get("distributed.comm.retry.delay.min"), default="s" + ) + retry_delay_max = parse_timedelta( + dask.config.get("distributed.comm.retry.delay.max"), default="s" + ) return await retry( partial(coro, *args, **kwargs), count=retry_count, From 65c4a4989350ae79c46e98a7f6ea79708b289e94 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 15 Apr 2020 16:00:28 -0500 Subject: [PATCH 0789/1550] Avoid DeprecationWarning from pandas (#3712) --- distributed/protocol/tests/test_pandas.py | 51 +++++++++++++++++------ distributed/tests/test_collections.py | 19 ++++++--- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/distributed/protocol/tests/test_pandas.py b/distributed/protocol/tests/test_pandas.py index 104151fb55a..b1f96bfd486 100644 --- a/distributed/protocol/tests/test_pandas.py +++ b/distributed/protocol/tests/test_pandas.py @@ -1,5 +1,5 @@ +import numpy as np import pandas as pd -import pandas.util.testing as tm import pytest from dask.dataframe.utils import assert_eq @@ -22,18 +22,43 @@ pd.DataFrame({"x": [b"a", b"b", b"c"]}), pd.DataFrame({"x": pd.Categorical(["a", "b", "a"], ordered=True)}), pd.DataFrame({"x": pd.Categorical(["a", "b", "a"], ordered=False)}), - tm.makeCategoricalIndex(), - tm.makeCustomDataframe(5, 3), - tm.makeDataFrame(), - tm.makeDateIndex(), - tm.makeMissingDataframe(), - tm.makeMixedDataFrame(), - tm.makeObjectSeries(), - tm.makePeriodFrame(), - tm.makeRangeIndex(), - tm.makeTimeDataFrame(), - tm.makeTimeSeries(), - tm.makeUnicodeIndex(), + 
pd.Index(pd.Categorical(["a"], categories=["a", "b"], ordered=True)), + pd.date_range("2000", periods=12, freq="B"), + pd.RangeIndex(10), + pd.DataFrame( + "a", + index=pd.Index(["a", "b", "c", "d"], name="a"), + columns=pd.Index(["A", "B", "C", "D"], name="columns"), + ), + pd.DataFrame( + np.random.randn(10, 5), columns=list("ABCDE"), index=list("abcdefghij") + ), + pd.DataFrame( + np.random.randn(10, 5), columns=list("ABCDE"), index=list("abcdefghij") + ).where(lambda x: x > 0), + pd.DataFrame( + { + "a": [0.0, 0.1], + "B": [0.0, 1.0], + "C": ["a", "b"], + "D": pd.to_datetime(["2000", "2001"]), + } + ), + pd.Series(["a", "b", "c"], index=["a", "b", "c"]), + pd.DataFrame( + np.random.randn(10, 5), + columns=list("ABCDE"), + index=pd.period_range("2000", periods=10, freq="B"), + ), + pd.DataFrame( + np.random.randn(10, 5), + columns=list("ABCDE"), + index=pd.date_range("2000", periods=10, freq="B"), + ), + pd.Series( + np.random.randn(10), name="a", index=pd.date_range("2000", periods=10, freq="B") + ), + pd.Index(["סשםקה7ךשץא", "8טלכז6לרפל"]), ] diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 0843d711761..61424c68f38 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -11,7 +11,7 @@ from distributed.utils_test import client, cluster_fixture, loop # noqa F401 import numpy as np import pandas as pd -import pandas.util.testing as tm +import pandas.testing as tm dfs = [ @@ -126,28 +126,37 @@ def test_dataframe_set_index_sync(wait, client): assert len(df2) +def make_time_dataframe(): + return pd.DataFrame( + np.random.randn(30, 4), + columns=list("ABCD"), + index=pd.date_range("2000", periods=30, freq="B"), + ) + + def test_loc_sync(client): - df = pd.util.testing.makeTimeDataFrame() + df = make_time_dataframe() ddf = dd.from_pandas(df, npartitions=10) ddf.loc["2000-01-17":"2000-01-24"].compute() def test_rolling_sync(client): - df = pd.util.testing.makeTimeDataFrame() + df = make_time_dataframe() ddf = dd.from_pandas(df, npartitions=10) ddf.A.rolling(2).mean().compute() @gen_cluster(client=True) def test_loc(c, s, a, b): - df = pd.util.testing.makeTimeDataFrame() + df = make_time_dataframe() ddf = dd.from_pandas(df, npartitions=10) future = c.compute(ddf.loc["2000-01-17":"2000-01-24"]) yield future def test_dataframe_groupby_tasks(client): - df = pd.util.testing.makeTimeDataFrame() + df = make_time_dataframe() + df["A"] = df.A // 0.1 df["B"] = df.B // 0.1 ddf = dd.from_pandas(df, npartitions=10) From 5243d23df9efb58d4f0d643c29a40062b365b65e Mon Sep 17 00:00:00 2001 From: jakirkham Date: Wed, 15 Apr 2020 14:05:48 -0700 Subject: [PATCH 0790/1550] Always use `readinto` in TCP (#3711) As of Tornado 5.0.0+, `read_into` is always available. Given this is our minimum requirement for Tornado, there is no need to handle earlier Tornado versions that don't have this feature. So drop that code path to simplify maintenance. 
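
For illustration, a minimal sketch of the simplified read path, mirroring the
new code in the diff below (`read_frame` is a hypothetical helper name used
only for this example, not part of the patch):

```python
from tornado.iostream import IOStream


async def read_frame(stream: IOStream, length: int) -> bytearray:
    # Preallocate a mutable buffer of the known frame length and let Tornado
    # fill it directly; read_into() exists unconditionally in Tornado >= 5.0.
    frame = bytearray(length)
    if length:
        n = await stream.read_into(frame)
        assert n == length, (n, length)
    return frame
```

Reading into a preallocated bytearray also avoids the extra copy implied by
`read_bytes()`, which returns a new immutable bytes object.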
--- distributed/comm/tcp.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 7003053ce06..769e9132abe 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -13,7 +13,7 @@ import dask from tornado import netutil -from tornado.iostream import StreamClosedError, IOStream +from tornado.iostream import StreamClosedError from tornado.tcpclient import TCPClient from tornado.tcpserver import TCPServer @@ -132,10 +132,6 @@ class TCP(Comm): An established communication based on an underlying Tornado IOStream. """ - # IOStream.read_into() currently proposed in - # https://github.com/tornadoweb/tornado/pull/2193 - _iostream_has_read_into = hasattr(IOStream, "read_into") - def __init__(self, stream, local_addr, peer_addr, deserialize=True): Comm.__init__(self) self._local_addr = local_addr @@ -192,15 +188,10 @@ async def read(self, deserializers=None): frames = [] for length in lengths: + frame = bytearray(length) if length: - if self._iostream_has_read_into: - frame = bytearray(length) - n = await stream.read_into(frame) - assert n == length, (n, length) - else: - frame = await stream.read_bytes(length) - else: - frame = b"" + n = await stream.read_into(frame) + assert n == length, (n, length) frames.append(frame) except StreamClosedError as e: self.stream = None From ee8cff496da9d26c2e140df549067d1683cab8ea Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 16 Apr 2020 18:05:17 +0200 Subject: [PATCH 0791/1550] Idempotent semaphore acquire with retries (#3690) Semaphore.acquire now performs idempotent acquire requests and retries in case of connection failures. Each lease is now unique and is assigned a unique timeout which is controlled using the configuration option ``distributed.scheduler.locks.lease-timeout`` --- distributed/distributed.yaml | 1 + distributed/semaphore.py | 280 +++++++++++++++++++--------- distributed/tests/test_semaphore.py | 235 +++++++++++++++++++---- 3 files changed, 388 insertions(+), 128 deletions(-) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index b11270f4704..71ecd840a10 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -44,6 +44,7 @@ distributed: check_unused_sessions_milliseconds: 500 locks: lease-validation-interval: 10s # The time to wait until an acquired semaphore is released if the Client goes out of scope + lease-timeout: 30s # The timeout after which a lease will be released if not refreshed http: routes: diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 2d506c8ed0e..0727c279a3a 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -1,6 +1,5 @@ import uuid from collections import defaultdict, deque -from functools import partial import asyncio import dask from asyncio import TimeoutError @@ -9,6 +8,7 @@ from .metrics import time import warnings import logging +from distributed.utils_comm import retry_operation logger = logging.getLogger(__name__) @@ -21,12 +21,14 @@ def __init__(self, duration=None): def start(self): self.started_at = time() + def elapsed(self): + return time() - self.started_at + def leftover(self): if self.duration is None: return None else: - elapsed = time() - self.started_at - return max(0, self.duration - elapsed) + return max(0, self.duration - self.elapsed()) class SemaphoreExtension: @@ -37,41 +39,49 @@ class SemaphoreExtension: * semaphore_acquire * semaphore_release * semaphore_create + * semaphore_close + * 
semaphore_refresh_leases """ def __init__(self, scheduler): self.scheduler = scheduler - self.leases = defaultdict(deque) + + # {semaphore_name: asyncio.Event} self.events = defaultdict(asyncio.Event) + # {semaphore_name: max_leases} self.max_leases = dict() - self.leases_per_client = defaultdict(partial(defaultdict, deque)) + # {semaphore_name: {lease_id: lease_last_seen_timestamp}} + self.leases = defaultdict(dict) + self.scheduler.handlers.update( { "semaphore_create": self.create, "semaphore_acquire": self.acquire, "semaphore_release": self.release, "semaphore_close": self.close, + "semaphore_refresh_leases": self.refresh_leases, } ) self.scheduler.extensions["semaphores"] = self - self.pc_validate_leases = PeriodicCallback( - self._validate_leases, - 1000 - * parse_timedelta( - dask.config.get( - "distributed.scheduler.locks.lease-validation-interval" - ), - default="s", - ), + + validation_callback_time = 1000 * parse_timedelta( + dask.config.get("distributed.scheduler.locks.lease-validation-interval"), + default="s", + ) + self._pc_lease_timeout = PeriodicCallback( + self._check_lease_timeout, + validation_callback_time, io_loop=self.scheduler.loop, ) - self.pc_validate_leases.start() - self._validation_running = False + self._pc_lease_timeout.start() + self.lease_timeout = parse_timedelta( + dask.config.get("distributed.scheduler.locks.lease-timeout"), default="s", + ) # `comm` here is required by the handler interface def create(self, comm=None, name=None, max_leases=None): - # We use `self.max_leases.keys()` as the point of truth to find out if a semaphore with a specific + # We use `self.max_leases` as the point of truth to find out if a semaphore with a specific # `name` has been created. if name not in self.max_leases: assert isinstance(max_leases, int), max_leases @@ -83,11 +93,32 @@ def create(self, comm=None, name=None, max_leases=None): % (max_leases, self.max_leases[name]) ) - def _get_lease(self, client, name, identifier): + def refresh_leases(self, comm=None, name=None, lease_ids=None): + with log_errors(): + now = time() + logger.debug( + "Refresh leases for %s with ids %s at %s", name, lease_ids, now + ) + for id_ in lease_ids: + if id_ not in self.leases[name]: + logger.critical( + f"Trying to refresh an unknown lease ID {id_} for {name}. This might be due to leases " + f"timing out and may cause overbooking of the semaphore!" + f"This is often caused by long-running GIL-holding in the task which acquired the lease." 
+ ) + self.leases[name][id_] = now + + def _get_lease(self, name, lease_id): result = True - if len(self.leases[name]) < self.max_leases[name]: - self.leases[name].append(identifier) - self.leases_per_client[client][name].append(identifier) + + if ( + # This allows request idempotency + lease_id in self.leases[name] + or len(self.leases[name]) < self.max_leases[name] + ): + now = time() + logger.info("Acquire lease %s for %s at %s", lease_id, name, now) + self.leases[name][lease_id] = now else: result = False return result @@ -97,9 +128,7 @@ def _semaphore_exists(self, name): return False return True - async def acquire( - self, comm=None, name=None, client=None, timeout=None, identifier=None - ): + async def acquire(self, comm=None, name=None, timeout=None, lease_id=None): with log_errors(): if not self._semaphore_exists(name): raise RuntimeError(f"Semaphore `{name}` not known or already closed.") @@ -110,11 +139,17 @@ async def acquire( w.start() while True: + logger.info( + "Trying to acquire %s for %s with %ss left.", + lease_id, + name, + w.leftover(), + ) # Reset the event and try to get a release. The event will be set if the state # is changed and helps to identify when it is worth to retry an acquire self.events[name].clear() - result = self._get_lease(client, name, identifier) + result = self._get_lease(name, lease_id) # If acquiring fails, we wait for the event to be set, i.e. something has # been released and we can try to acquire again (continue loop) @@ -127,9 +162,16 @@ async def acquire( continue except TimeoutError: result = False + logger.info( + "Acquisition of lease %s for %s is %s after waiting for %ss.", + lease_id, + name, + result, + w.elapsed(), + ) return result - def release(self, comm=None, name=None, client=None, identifier=None): + def release(self, comm=None, name=None, lease_id=None): with log_errors(): if not self._semaphore_exists(name): logger.warning( @@ -138,37 +180,41 @@ def release(self, comm=None, name=None, client=None, identifier=None): return if isinstance(name, list): name = tuple(name) - if name in self.leases and identifier in self.leases[name]: - self._release_value(name, client, identifier) + if name in self.leases and lease_id in self.leases[name]: + self._release_value(name, lease_id) else: - raise ValueError( + logger.warning( f"Tried to release semaphore but it was already released: " - f"client={client}, name={name}, identifier={identifier}" + f"name={name}, lease_id={lease_id}. This can happen if the semaphore timed out before." ) - def _release_value(self, name, client, identifier): + def _release_value(self, name, lease_id): + logger.info("Releasing %s for %s", lease_id, name) # Everything needs to be atomic here. 
- self.leases_per_client[client][name].remove(identifier) - self.leases[name].remove(identifier) + del self.leases[name][lease_id] self.events[name].set() - def _release_client(self, client): - semaphore_names = list(self.leases_per_client[client]) + def _check_lease_timeout(self): + now = time() + semaphore_names = list(self.leases.keys()) for name in semaphore_names: - ids = list(self.leases_per_client[client][name]) - for _id in list(ids): - self._release_value(name=name, client=client, identifier=_id) - del self.leases_per_client[client] - - def _validate_leases(self): - if not self._validation_running: - self._validation_running = True - known_clients_with_leases = set(self.leases_per_client.keys()) - scheduler_clients = set(self.scheduler.clients.keys()) - for dead_client in known_clients_with_leases - scheduler_clients: - self._release_client(dead_client) - else: - self._validation_running = False + ids = list(self.leases[name]) + logger.debug( + "Validating leases for %s at time %s. Currently known %s", + name, + now, + self.leases[name], + ) + for _id in ids: + time_since_refresh = now - self.leases[name][_id] + if time_since_refresh > self.lease_timeout: + logger.info( + "Lease %s for %s timed out after %ss.", + _id, + name, + time_since_refresh, + ) + self._release_value(name=name, lease_id=_id) def close(self, comm=None, name=None): """Hard close the semaphore without warning clients which still hold a lease.""" @@ -180,15 +226,12 @@ def close(self, comm=None, name=None): if name in self.events: del self.events[name] if name in self.leases: - del self.leases[name] - - for client, client_leases in self.leases_per_client.items(): - if name in client_leases: + if self.leases[name]: warnings.warn( - f"Closing semaphore `{name}` but client `{client}` still has a lease open.", + f"Closing semaphore {name} but there remain unreleased leases {sorted(self.leases[name])}", RuntimeWarning, ) - del client_leases[name] + del self.leases[name] class Semaphore: @@ -200,19 +243,31 @@ class Semaphore: already acquired, it is not possible to acquire more and the caller waits until another lease has been released. - The lifetime of a lease is coupled to the ``Client`` it was acquired with. - Once the Client goes out of scope, the leases associated to it are freed. - This behavior can be controlled with the - ``distributed.scheduler.locks.lease-validation-interval`` configuration - option. + The lifetime or leases are controlled using a timeout. This timeout is + refreshed in regular intervals by the ``Client`` of this instance and + provides protection from deadlocks or resource starvation in case of worker + failure. + The timeout can be controlled using the configuration option + ``distributed.scheduler.locks.lease-timeout`` and the interval in which the + scheduler verifies the timeout is set using the option + ``distributed.scheduler.locks.lease-validation-interval``. A noticeable difference to the Semaphore of the python standard library is that this implementation does not allow to release more often than it was acquired. If this happens, a warning is emitted but the internal state is not modified. - This implementation is still in an experimental state and subtle changes in - behavior may occur without any change in the major version of this library. + .. warning:: + + This implementation is still in an experimental state and subtle + changes in behavior may occur without any change in the major version + of this library. + + .. 
warning:: + + This implementation is susceptible to lease overbooking in case of + lease timeouts. It is advised to monitor log information and adjust + above configuration options to suitable values for the user application. Parameters ---------- @@ -263,25 +318,37 @@ class Semaphore: """ def __init__(self, max_leases=1, name=None, client=None): - # NOTE: the `id` of the `Semaphore` instance will always be unique, even among different - # instances for the same resource. The actual attribute that identifies a specific resource is `name`, - # which will be the same for all instances of this class which limit the same resource. self.client = client or get_client() - self.id = uuid.uuid4().hex self.name = name or "semaphore-" + uuid.uuid4().hex self.max_leases = max_leases + self.id = uuid.uuid4().hex + self._leases = deque() - if self.client.asynchronous: - self._started = self.client.scheduler.semaphore_create( - name=self.name, max_leases=max_leases - ) - else: - self.client.sync( - self.client.scheduler.semaphore_create, - name=self.name, - max_leases=max_leases, + self._started = self.client.sync( + self.client.scheduler.semaphore_create, + name=self.name, + max_leases=max_leases, + ) + # this should give ample time to refresh without introducing another + # config parameter since this *must* be smaller than the timeout anyhow + refresh_leases_interval = ( + parse_timedelta( + dask.config.get("distributed.scheduler.locks.lease-timeout"), + default="s", ) - self._started = asyncio.sleep(0) + / 5 + ) + self._refreshing_leases = False + pc = PeriodicCallback( + self._refresh_leases, + callback_time=1000 * refresh_leases_interval, + io_loop=self.client.io_loop, + ) + self.refresh_callback = pc + # Registering the pc to the client here is important for proper cleanup + self._periodic_callback_name = f"refresh_semaphores_{self.id}" + self.client._periodic_callbacks[self._periodic_callback_name] = pc + pc.start() def __await__(self): async def create_semaphore(): @@ -290,6 +357,40 @@ async def create_semaphore(): return create_semaphore().__await__() + async def _refresh_leases(self): + if self.client.scheduler is not None and not self._refreshing_leases: + self._refreshing_leases = True + if self._leases: + logger.debug( + "%s refreshing leases for %s with IDs %s", + self.client.id, + self.name, + self._leases, + ) + await self.client.scheduler.semaphore_refresh_leases( + lease_ids=list(self._leases), name=self.name + ) + self._refreshing_leases = False + + async def _acquire(self, timeout=None): + lease_id = uuid.uuid4().hex + logger.info( + "%s requests lease for %s with ID %s", self.client.id, self.name, lease_id, + ) + + # Using a unique lease id generated here allows us to retry since the + # server handle is idempotent + + result = await retry_operation( + self.client.scheduler.semaphore_acquire, + name=self.name, + timeout=timeout, + lease_id=lease_id, + ) + if result: + self._leases.append(lease_id) + return result + def acquire(self, timeout=None): """ Acquire a semaphore. @@ -297,16 +398,7 @@ def acquire(self, timeout=None): If the internal counter is greater than zero, decrement it by one and return True immediately. If it is zero, wait until a release() is called and return True. """ - # TODO: This (may?) keep the HTTP request open until timeout runs out (forever if None). - # Can do this in batches of smaller timeouts. - # TODO: what if connection breaks up? 
- return self.client.sync( - self.client.scheduler.semaphore_acquire, - name=self.name, - timeout=timeout, - client=self.client.id, - identifier=self.id, - ) + return self.client.sync(self._acquire, timeout=timeout) def release(self): """ @@ -316,12 +408,13 @@ def release(self): """ """ Release the lock if already acquired """ - # TODO: what if connection breaks up? + if not self._leases: + raise RuntimeError("Released too often") + # popleft to release the oldest lease first + lease_id = self._leases.popleft() + logger.info("%s releases %s for %s", self.client.id, lease_id, self.name) return self.client.sync( - self.client.scheduler.semaphore_release, - name=self.name, - client=self.client.id, - identifier=self.id, + self.client.scheduler.semaphore_release, name=self.name, lease_id=lease_id, ) def __enter__(self): @@ -350,3 +443,8 @@ def __setstate__(self, state): def close(self): return self.client.sync(self.client.scheduler.semaphore_close, name=self.name) + + def __del__(self): + if self._periodic_callback_name in self.client._periodic_callbacks: + self.client._periodic_callbacks[self._periodic_callback_name].stop() + del self.client._periodic_callbacks[self._periodic_callback_name] diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 9d94b83515a..3c68f685eff 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -1,17 +1,24 @@ import pickle - import dask +import pytest from dask.distributed import Client from distributed import Semaphore +from distributed.comm import Comm +from distributed.core import ConnectionPool from distributed.metrics import time -from distributed.utils_test import cluster, gen_cluster -from distributed.utils_test import client, loop, cluster_fixture # noqa: F401 -import pytest +from distributed.utils_test import ( # noqa: F401 + client, + cluster, + cluster_fixture, + gen_cluster, + slowidentity, + loop, +) @gen_cluster(client=True) -async def test_semaphore(c, s, a, b): +async def test_semaphore_trivial(c, s, a, b): semaphore = await Semaphore(max_leases=2, name="resource_we_want_to_limit") result = await semaphore.acquire() # allowed_leases: 2 - 1 -> 1 @@ -81,36 +88,37 @@ def test_timeout_sync(client): assert s.acquire(timeout=0.025) is False -@pytest.mark.slow -@gen_cluster(client=True, timeout=20) +@gen_cluster( + client=True, + timeout=20, + config={ + "distributed.scheduler.locks.lease-validation-interval": "500ms", + "distributed.scheduler.locks.lease-timeout": "500ms", + }, +) async def test_release_semaphore_after_timeout(c, s, a, b): - with dask.config.set( - {"distributed.scheduler.locks.lease-validation-interval": "50ms"} - ): - sem = await Semaphore(name="x", max_leases=2) - await sem.acquire() # leases: 2 - 1 = 1 - semY = await Semaphore(name="y") - - async with Client(s.address, asynchronous=True, name="ClientB") as clientB: - semB = await Semaphore(name="x", max_leases=2, client=clientB) - semYB = await Semaphore(name="y", client=clientB) - - assert await semB.acquire() # leases: 1 - 1 = 0 - assert await semYB.acquire() + sem = await Semaphore(name="x", max_leases=2) + await sem.acquire() # leases: 2 - 1 = 1 + semY = await Semaphore(name="y") - assert not (await sem.acquire(timeout=0.01)) - assert not (await semB.acquire(timeout=0.01)) - assert not (await semYB.acquire(timeout=0.01)) + async with Client(s.address, asynchronous=True, name="ClientB") as clientB: + semB = await Semaphore(name="x", max_leases=2, client=clientB) + semYB = await Semaphore(name="y", 
client=clientB) - # `ClientB` goes out of scope, leases should be released - # At this point, we should be able to acquire x and y once - assert await sem.acquire() - assert await semY.acquire() + assert await semB.acquire() # leases: 1 - 1 = 0 + assert await semYB.acquire() - assert not (await semY.acquire(timeout=0.01)) assert not (await sem.acquire(timeout=0.01)) + assert not (await semB.acquire(timeout=0.01)) + assert not (await semYB.acquire(timeout=0.01)) - assert clientB.id not in s.extensions["semaphores"].leases_per_client + # `ClientB` goes out of scope, leases should be released + # At this point, we should be able to acquire x and y once + assert await sem.acquire() + assert await semY.acquire() + + assert not (await semY.acquire(timeout=0.5)) + assert not (await sem.acquire(timeout=0.5)) @gen_cluster() @@ -185,7 +193,8 @@ async def test_close_async(c, s, a, b): assert await sem.acquire() with pytest.warns( - RuntimeWarning, match="Closing semaphore .* but client .* still has a lease" + RuntimeWarning, + match="Closing semaphore .* but there remain unreleased leases .*", ): await sem.close() @@ -198,7 +207,6 @@ async def test_close_async(c, s, a, b): assert not semaphore_object.max_leases assert not semaphore_object.leases assert not semaphore_object.events - assert not any(semaphore_object.leases_per_client.values()) def test_close_sync(client): @@ -215,9 +223,7 @@ async def test_release_once_too_many(c, s, a, b): assert await sem.acquire() await sem.release() - with pytest.raises( - ValueError, match="Tried to release semaphore but it was already released" - ): + with pytest.raises(RuntimeError, match="Released too often"): await sem.release() assert await sem.acquire() @@ -229,9 +235,7 @@ async def test_release_once_too_many_resilience(c, s, a, b): def f(x, sem): sem.acquire() sem.release() - with pytest.raises( - ValueError, match="Tried to release semaphore but it was already released" - ): + with pytest.raises(RuntimeError, match="Released too often"): sem.release() return x @@ -244,3 +248,160 @@ def f(x, sem): assert not s.extensions["semaphores"].leases["x"] await sem.acquire() assert len(s.extensions["semaphores"].leases["x"]) == 1 + + +class BrokenComm(Comm): + peer_address = None + local_address = None + + def close(self): + pass + + def closed(self): + return True + + def abort(self): + pass + + def read(self, deserializers=None): + raise EnvironmentError + + def write(self, msg, serializers=None, on_error=None): + raise EnvironmentError + + +class FlakyConnectionPool(ConnectionPool): + def __init__(self, *args, failing_connections=0, **kwargs): + self.cnn_count = 0 + self.failing_connections = failing_connections + self._flaky_active = False + super().__init__(*args, **kwargs) + + def activate(self): + self._flaky_active = True + + async def connect(self, *args, **kwargs): + if self.cnn_count >= self.failing_connections or not self._flaky_active: + return await super().connect(*args, **kwargs) + else: + self.cnn_count += 1 + return BrokenComm() + + +@gen_cluster(client=True) +async def test_retry_acquire(c, s, a, b): + with dask.config.set({"distributed.comm.retry.count": 1}): + + pool = await FlakyConnectionPool(failing_connections=1) + rpc = pool(s.address) + c.scheduler = rpc + semaphore = await Semaphore( + max_leases=2, name="resource_we_want_to_limit", client=c + ) + pool.activate() + + result = await semaphore.acquire() + assert result is True + + second = await semaphore.acquire() + assert second is True + start = time() + result = await 
semaphore.acquire(timeout=0.025) + stop = time() + assert stop - start < 0.2 + assert result is False + + +@gen_cluster( + client=True, + config={ + "distributed.scheduler.locks.lease-timeout": "100ms", + "distributed.scheduler.locks.lease-validation-interval": "10ms", + }, +) +async def test_oversubscribing_leases(c, s, a, b): + """ + This test ensures that we detect oversubscription scenarios and will not + accept new leases as long as the semaphore is oversubscribed. + + Oversubscription may occur if tasks hold the GIL for a longer time than the + lease-timeout is configured causing the lease refreshs to go stale and + timeout. + + We cannot protect ourselves entirely from this but we can ensure that while + a task with a timed out lease is still running, we block further + acquisitions until we return to normal. + + An example would be a task which continuously locks the GIL for a longer + time than the lease timeout but this continous lock only makes up a + fraction of the tasks runtime. + + """ + # GH3705 + + from distributed.worker import Worker, get_client + + # Using the metadata as a crude "asyncio.Event" since the proper event + # implementation cannot be serialized. For the purpose of this test a + # metadata check with a sleep loop is not elegant but practical. + await c.set_metadata("release", False) + sem = await Semaphore() + + def guaranteed_lease_timeout(x, sem): + """ + This function simulates a payload computation with some GIL + locking in the beginning. + + To simulate this we will manually disable the refresh callback, i.e. + all leases will eventually timeout. The function will only + release/return once the "Event" is set, i.e. our observer is done. + """ + sem.refresh_callback.stop() + client = get_client() + + with sem: + # This simulates a task which holds the GIL for longer than the + # lease-timeout. + slowidentity(delay=0.2) + old_value = client.set_metadata(x, "locked") + + # Now the GIL is free again, i.e. we enable the callback again + sem.refresh_callback.start() + + # This is the poormans Event.wait() + while not client.get_metadata("release"): + slowidentity(delay=0.02) + return x + + def observe_state(sem): + """ + This function is 100% artificial and acts as an observer to verify + our assumptions. The function will wait until both payload tasks are + executing, i.e. we're in an oversubscription scenario. It will then + try to acquire and hopefully fail showing that the semaphore is + protected if the oversubscription is recognized. + """ + client = get_client() + x_locked = False + y_locked = False + # We wait until we're in an oversubscribed state, i.e. both tasks + # are executed although there should only be one allowed + while not x_locked and y_locked: + slowidentity(delay=0.005) + x_locked = client.get_metadata(0) == "locked" + y_locked = client.get_metadata(1) == "locked" + + # Once we're in an oversubscribed state, we must not be able to + # acquire a lease. 
+ assert not sem.acquire(timeout=0.05) + client.set_metadata("release", True) + + observer = await Worker(s.address) + + futures = c.map( + guaranteed_lease_timeout, range(2), sem=sem, workers=[a.address, b.address] + ) + fut_observe = c.submit(observe_state, sem=sem, workers=[observer.address]) + + payload, observer = await c.gather([futures, fut_observe]) + assert sorted(payload) == [0, 1] From 6a3dc40891b9df0e2b94073a945641ebb39bfb5a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 17 Apr 2020 00:36:27 +0100 Subject: [PATCH 0792/1550] Force threads_per_worker (#3715) --- distributed/deploy/tests/test_adaptive.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 9c68e6ddf53..7fb91292540 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -1,3 +1,4 @@ +import gc import math from time import sleep @@ -151,6 +152,7 @@ def test_min_max(): processes=False, dashboard_address=None, asynchronous=True, + threads_per_worker=1, ) try: adapt = cluster.adapt(minimum=1, maximum=2, interval="20 ms", wait_count=10) @@ -179,6 +181,7 @@ def test_min_max(): assert len(adapt.log) == 2 and all(d["status"] == "up" for _, d in adapt.log) del futures + gc.collect() start = time() while len(cluster.scheduler.workers) != 1: From d5cb312496708d84b01a423ce5cd4867240cc4d3 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Fri, 17 Apr 2020 10:16:02 -0500 Subject: [PATCH 0793/1550] Dask-serialize dicts longer than five elements (#3689) --- distributed/protocol/serialize.py | 34 +++++++++++++++-- distributed/protocol/tests/test_serialize.py | 40 ++++++++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index f20bc490752..a1b35ec4463 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -90,6 +90,20 @@ def register_serialization_family(name, dumps, loads): register_serialization_family("error", None, serialization_error_loads) +def check_dask_serializable(x): + if type(x) in (list, set, tuple) and len(x): + return check_dask_serializable(next(iter(x))) + elif type(x) is dict and len(x): + return check_dask_serializable(next(iter(x.items()))[1]) + else: + try: + dask_serialize.dispatch(type(x)) + return True + except TypeError: + pass + return False + + def serialize(x, serializers=None, on_error="message", context=None): r""" Convert object to a header and list of bytestrings @@ -132,8 +146,22 @@ def serialize(x, serializers=None, on_error="message", context=None): if isinstance(x, Serialized): return x.header, x.frames + if type(x) in (list, set, tuple, dict): + iterate_collection = False + if type(x) is list and "msgpack" in serializers: + # Note: "msgpack" will always convert lists to tuples + # (see GitHub #3716), so we should iterate + # through the list if "msgpack" comes before "pickle" + # in the list of serializers. 
+ iterate_collection = ("pickle" not in serializers) or ( + serializers.index("pickle") > serializers.index("msgpack") + ) + if not iterate_collection: + # Check for "dask"-serializable data in dict/list/set + iterate_collection = check_dask_serializable(x) + # Determine whether keys are safe to be serialized with msgpack - if type(x) is dict and len(x) <= 5: + if type(x) is dict and iterate_collection: try: msgpack.dumps(list(x.keys())) except Exception: @@ -143,9 +171,9 @@ def serialize(x, serializers=None, on_error="message", context=None): if ( type(x) in (list, set, tuple) - and len(x) <= 5 + and iterate_collection or type(x) is dict - and len(x) <= 5 + and iterate_collection and dict_safe ): if isinstance(x, dict): diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 41e2af51b70..dd23e5e635d 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -21,6 +21,7 @@ register_serialization_family, dask_serialize, ) +from distributed.protocol.serialize import check_dask_serializable from distributed.utils import nbytes from distributed.utils_test import inc, gen_test from distributed.comm.utils import to_frames, from_frames @@ -388,3 +389,42 @@ def _(x): header, frames = serialize([MyObj(), MyObj()]) assert header["compression"] == [False, False] + + +@pytest.mark.parametrize( + "data,is_serializable", + [ + ([], False), + ({}, False), + ({i: i for i in range(10)}, False), + (set(range(10)), False), + (tuple(range(100)), False), + ({"x": MyObj(5)}, True), + ({"x": {"y": MyObj(5)}}, True), + pytest.param( + [1, MyObj(5)], + True, + marks=pytest.mark.xfail(reason="Only checks 0th element for now."), + ), + ([MyObj([0, 1, 2]), 1], True), + (tuple([MyObj(None)]), True), + ({("x", i): MyObj(5) for i in range(100)}, True), + ], +) +def test_check_dask_serializable(data, is_serializable): + result = check_dask_serializable(data) + expected = is_serializable + + assert result == expected + + +@pytest.mark.parametrize( + "serializers", + [["msgpack"], ["pickle"], ["msgpack", "pickle"], ["pickle", "msgpack"]], +) +def test_serialize_lists(serializers): + data_in = ["a", 2, "c", None, "e", 6] + header, frames = serialize(data_in, serializers=serializers) + data_out = deserialize(header, frames) + + assert data_in == data_out From 07b0cfeef4d2515361c1ee89222c18d30ad26a67 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 17 Apr 2020 18:49:28 +0200 Subject: [PATCH 0794/1550] Adjust semaphore test timeouts (#3720) --- distributed/tests/test_semaphore.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 3c68f685eff..2f08fe45751 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -361,7 +361,8 @@ def guaranteed_lease_timeout(x, sem): with sem: # This simulates a task which holds the GIL for longer than the - # lease-timeout. + # lease-timeout. 
This is twice the lease timeout to ensurre that the + # leases are actually timed out slowidentity(delay=0.2) old_value = client.set_metadata(x, "locked") @@ -391,9 +392,12 @@ def observe_state(sem): x_locked = client.get_metadata(0) == "locked" y_locked = client.get_metadata(1) == "locked" + # Once both are locked we should give the refresh time to notify the scheduler + # This parameter should be larger than ``lease-validation-interval`` + slowidentity(delay=0.15) # Once we're in an oversubscribed state, we must not be able to # acquire a lease. - assert not sem.acquire(timeout=0.05) + assert not sem.acquire(timeout=0) client.set_metadata("release", True) observer = await Worker(s.address) @@ -405,3 +409,16 @@ def observe_state(sem): payload, observer = await c.gather([futures, fut_observe]) assert sorted(payload) == [0, 1] + + +@gen_cluster(client=True,) +async def test_timeout_zero(c, s, a, b): + # Depending on the internals a timeout zero cannot work, e.g. when the + # initial try already includes a wait. Since some test cases use this, it is + # worth testing against. + + sem = await Semaphore() + + assert await sem.acquire(timeout=0) + assert not await sem.acquire(timeout=0) + await sem.release() From 41f1568975d51815e87df006f70589a8a4f6b84d Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 17 Apr 2020 15:35:32 -0500 Subject: [PATCH 0795/1550] Add batch_size to Client.map (#3650) An informal benchmark ```python In [9]: %time _ = wait(c.map(inc, range(100_000), pure=False, batch_size=1_000)) CPU times: user 31.8 s, sys: 1.07 s, total: 32.9 s Wall time: 33.3 s In [10]: %time _ = wait(c.map(inc, range(100_000), pure=False)) CPU times: user 45.3 s, sys: 2.13 s, total: 47.5 s Wall time: 48.4 s ``` The difference likely increases in the size of the iterable. Closes https://github.com/dask/distributed/issues/2181 --- distributed/client.py | 47 +++++++++++++++++++++++++++----- distributed/tests/test_client.py | 16 +++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 6545e938511..607ca9f5b48 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -30,7 +30,7 @@ from dask.compatibility import apply from dask.utils import ensure_dict, format_bytes, funcname -from tlz import first, groupby, merge, valmap, keymap +from tlz import first, groupby, merge, valmap, keymap, partition_all try: from dask.delayed import single_key @@ -1044,7 +1044,7 @@ async def _ensure_connected(self, timeout=None): try: comm = await connect( - self.scheduler.address, timeout=timeout, **self.connection_args, + self.scheduler.address, timeout=timeout, **self.connection_args ) comm.name = "Client->Scheduler" if timeout is not None: @@ -1540,6 +1540,7 @@ def map( actor=False, actors=False, pure=None, + batch_size=None, **kwargs, ): """ Map a function on a sequence of arguments @@ -1579,6 +1580,11 @@ def map( See :doc:`actors` for additional details. actors: bool (default False) Alias for `actor` + batch_size : int, optional + Submit tasks to the scheduler in batches of (at most) ``batch_size``. + Larger batch sizes can be useful for very large ``iterables``, + as the cluster can start processing tasks while later ones are + submitted asynchronously. **kwargs: dict Extra keywords to send to the function. Large values will be included explicitly in the task graph. 
@@ -1596,11 +1602,6 @@ def map( -------- Client.submit: Submit a single function """ - key = key or funcname(func) - actor = actor or actors - if pure is None: - pure = not actor - if not callable(func): raise TypeError("First input to map must be a callable function") @@ -1611,6 +1612,38 @@ def map( "Dask no longer supports mapping over Iterators or Queues." "Consider using a normal for loop and Client.submit" ) + total_length = sum(len(x) for x in iterables) + + if batch_size and batch_size > 1 and total_length > batch_size: + batches = list( + zip(*[partition_all(batch_size, iterable) for iterable in iterables]) + ) + return sum( + [ + self.map( + func, + *batch, + key=key, + workers=workers, + retries=retries, + priority=priority, + allow_other_workers=allow_other_workers, + fifo_timeout=fifo_timeout, + resources=resources, + actor=actor, + actors=actors, + pure=pure, + **kwargs, + ) + for batch in batches + ], + [], + ) + + key = key or funcname(func) + actor = actor or actors + if pure is None: + pure = not actor if allow_other_workers and workers is None: raise ValueError("Only use allow_other_workers= if using workers=") diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 297394631ae..17aea54f78a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -218,6 +218,22 @@ def test_map_retries(c, s, a, b): yield z +@gen_cluster(client=True) +async def test_map_batch_size(c, s, a, b): + result = c.map(inc, range(100), batch_size=10) + result = await c.gather(result) + assert result == list(range(1, 101)) + + result = c.map(add, range(100), range(100), batch_size=10) + result = await c.gather(result) + assert result == list(range(0, 200, 2)) + + # mismatch shape + result = c.map(add, range(100, 200), range(10), batch_size=2) + result = await c.gather(result) + assert result == list(range(100, 120, 2)) + + @gen_cluster(client=True) def test_compute_retries(c, s, a, b): args = [ZeroDivisionError("one"), ZeroDivisionError("two"), 3] From 5c027a6fe08e387362ab4e59264b80b652d22377 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Sat, 18 Apr 2020 17:09:10 +0200 Subject: [PATCH 0796/1550] Fix flaky test_oversubscribing_leases (#3726) * Add get_value method to semaphore * Introduce refresh_leases to control lease refreshing --- distributed/semaphore.py | 36 ++++++++++------- distributed/tests/test_semaphore.py | 61 ++++++++++++++++++++--------- 2 files changed, 64 insertions(+), 33 deletions(-) diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 0727c279a3a..976f54704c4 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -60,6 +60,7 @@ def __init__(self, scheduler): "semaphore_release": self.release, "semaphore_close": self.close, "semaphore_refresh_leases": self.refresh_leases, + "semaphore_value": self.get_value, } ) @@ -79,6 +80,9 @@ def __init__(self, scheduler): dask.config.get("distributed.scheduler.locks.lease-timeout"), default="s", ) + async def get_value(self, comm=None, name=None): + return len(self.leases[name]) + # `comm` here is required by the handler interface def create(self, comm=None, name=None, max_leases=None): # We use `self.max_leases` as the point of truth to find out if a semaphore with a specific @@ -102,7 +106,7 @@ def refresh_leases(self, comm=None, name=None, lease_ids=None): for id_ in lease_ids: if id_ not in self.leases[name]: logger.critical( - f"Trying to refresh an unknown lease ID {id_} for {name}. 
This might be due to leases " + f"Refreshing an unknown lease ID {id_} for {name}. This might be due to leases " f"timing out and may cause overbooking of the semaphore!" f"This is often caused by long-running GIL-holding in the task which acquired the lease." ) @@ -349,6 +353,7 @@ def __init__(self, max_leases=1, name=None, client=None): self._periodic_callback_name = f"refresh_semaphores_{self.id}" self.client._periodic_callbacks[self._periodic_callback_name] = pc pc.start() + self.refresh_leases = True def __await__(self): async def create_semaphore(): @@ -358,19 +363,16 @@ async def create_semaphore(): return create_semaphore().__await__() async def _refresh_leases(self): - if self.client.scheduler is not None and not self._refreshing_leases: - self._refreshing_leases = True - if self._leases: - logger.debug( - "%s refreshing leases for %s with IDs %s", - self.client.id, - self.name, - self._leases, - ) - await self.client.scheduler.semaphore_refresh_leases( - lease_ids=list(self._leases), name=self.name - ) - self._refreshing_leases = False + if self.refresh_leases and self._leases: + logger.debug( + "%s refreshing leases for %s with IDs %s", + self.client.id, + self.name, + self._leases, + ) + await self.client.scheduler.semaphore_refresh_leases( + lease_ids=list(self._leases), name=self.name + ) async def _acquire(self, timeout=None): lease_id = uuid.uuid4().hex @@ -417,6 +419,12 @@ def release(self): self.client.scheduler.semaphore_release, name=self.name, lease_id=lease_id, ) + def get_value(self): + """ + Return the number of currently registered leases. + """ + return self.client.sync(self.client.scheduler.semaphore_value, name=self.name) + def __enter__(self): self.acquire() return self diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 2f08fe45751..7a36431042f 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -2,7 +2,7 @@ import dask import pytest from dask.distributed import Client - +from time import sleep from distributed import Semaphore from distributed.comm import Comm from distributed.core import ConnectionPool @@ -10,6 +10,8 @@ from distributed.utils_test import ( # noqa: F401 client, cluster, + async_wait_for, + captured_logger, cluster_fixture, gen_cluster, slowidentity, @@ -316,7 +318,7 @@ async def test_retry_acquire(c, s, a, b): client=True, config={ "distributed.scheduler.locks.lease-timeout": "100ms", - "distributed.scheduler.locks.lease-validation-interval": "10ms", + "distributed.scheduler.locks.lease-validation-interval": "100ms", }, ) async def test_oversubscribing_leases(c, s, a, b): @@ -346,6 +348,7 @@ async def test_oversubscribing_leases(c, s, a, b): # metadata check with a sleep loop is not elegant but practical. await c.set_metadata("release", False) sem = await Semaphore() + sem.refresh_callback.stop() def guaranteed_lease_timeout(x, sem): """ @@ -356,7 +359,7 @@ def guaranteed_lease_timeout(x, sem): all leases will eventually timeout. The function will only release/return once the "Event" is set, i.e. our observer is done. """ - sem.refresh_callback.stop() + sem.refresh_leases = False client = get_client() with sem: @@ -364,14 +367,17 @@ def guaranteed_lease_timeout(x, sem): # lease-timeout. This is twice the lease timeout to ensurre that the # leases are actually timed out slowidentity(delay=0.2) - old_value = client.set_metadata(x, "locked") + assert sem._leases # Now the GIL is free again, i.e. 
we enable the callback again - sem.refresh_callback.start() + sem.refresh_leases = True + sleep(0.1) # This is the poormans Event.wait() - while not client.get_metadata("release"): - slowidentity(delay=0.02) + while client.get_metadata("release") is not True: + sleep(0.05) + + assert sem.get_value() >= 1 return x def observe_state(sem): @@ -382,22 +388,17 @@ def observe_state(sem): try to acquire and hopefully fail showing that the semaphore is protected if the oversubscription is recognized. """ - client = get_client() - x_locked = False - y_locked = False + sem.refresh_callback.stop() # We wait until we're in an oversubscribed state, i.e. both tasks # are executed although there should only be one allowed - while not x_locked and y_locked: - slowidentity(delay=0.005) - x_locked = client.get_metadata(0) == "locked" - y_locked = client.get_metadata(1) == "locked" - - # Once both are locked we should give the refresh time to notify the scheduler - # This parameter should be larger than ``lease-validation-interval`` - slowidentity(delay=0.15) + while not sem.get_value() > 1: + sleep(0.2) + # Once we're in an oversubscribed state, we must not be able to # acquire a lease. assert not sem.acquire(timeout=0) + + client = get_client() client.set_metadata("release", True) observer = await Worker(s.address) @@ -407,8 +408,18 @@ def observe_state(sem): ) fut_observe = c.submit(observe_state, sem=sem, workers=[observer.address]) - payload, observer = await c.gather([futures, fut_observe]) + with captured_logger("distributed.semaphore") as caplog: + payload, observer = await c.gather([futures, fut_observe]) + + logs = caplog.getvalue().split("\n") + timeouts = [log for log in logs if "timed out" in log] + refresh_unknown = [log for log in logs if "Refreshing an unknown lease ID" in log] + assert len(timeouts) == 2 + assert len(refresh_unknown) == 2 + assert sorted(payload) == [0, 1] + # Back to normal + assert await sem.get_value() == 0 @gen_cluster(client=True,) @@ -422,3 +433,15 @@ async def test_timeout_zero(c, s, a, b): assert await sem.acquire(timeout=0) assert not await sem.acquire(timeout=0) await sem.release() + + +@gen_cluster(client=True) +async def test_getvalue(c, s, a, b): + + sem = await Semaphore() + + assert await sem.get_value() == 0 + await sem.acquire() + assert await sem.get_value() == 1 + await sem.release() + assert await sem.get_value() == 0 From 9622b8f9bef1855412e9b23265378e2da1f47f2f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 18 Apr 2020 18:06:06 +0100 Subject: [PATCH 0797/1550] Replace gen.coroutine with async-await in tests (#3706) --- .gitignore | 1 + distributed/cli/tests/test_dask_scheduler.py | 17 +- distributed/cli/utils.py | 6 +- distributed/client.py | 5 +- distributed/comm/tests/test_ucx.py | 4 +- .../dashboard/tests/test_components.py | 13 +- .../dashboard/tests/test_scheduler_bokeh.py | 181 +- .../dashboard/tests/test_worker_bokeh.py | 40 +- distributed/deploy/tests/test_adaptive.py | 76 +- distributed/deploy/tests/test_local.py | 22 +- .../diagnostics/tests/test_eventstream.py | 20 +- .../diagnostics/tests/test_graph_layout.py | 30 +- .../diagnostics/tests/test_progress.py | 64 +- .../diagnostics/tests/test_progress_stream.py | 10 +- .../diagnostics/tests/test_progressbar.py | 6 +- .../tests/test_scheduler_plugin.py | 18 +- .../diagnostics/tests/test_task_stream.py | 28 +- distributed/diagnostics/tests/test_widgets.py | 34 +- .../diagnostics/tests/test_worker_plugin.py | 22 +- .../http/worker/tests/test_worker_http.py | 12 +- distributed/node.py | 2 +- 
distributed/protocol/tests/test_arrow.py | 6 +- distributed/protocol/tests/test_h5py.py | 8 +- distributed/protocol/tests/test_netcdf4.py | 4 +- distributed/protocol/tests/test_numpy.py | 4 +- distributed/protocol/tests/test_serialize.py | 44 +- distributed/tests/test_actor.py | 196 +- distributed/tests/test_as_completed.py | 49 +- distributed/tests/test_asyncprocess.py | 106 +- distributed/tests/test_client.py | 1875 ++++++++--------- distributed/tests/test_collections.py | 30 +- distributed/tests/test_core.py | 2 +- distributed/tests/test_failed_workers.py | 160 +- distributed/tests/test_locks.py | 48 +- distributed/tests/test_nanny.py | 181 +- distributed/tests/test_priorities.py | 19 +- distributed/tests/test_publish.py | 116 +- distributed/tests/test_pubsub.py | 33 +- distributed/tests/test_queues.py | 177 +- distributed/tests/test_resources.py | 107 +- distributed/tests/test_scheduler.py | 603 +++--- distributed/tests/test_security.py | 47 +- distributed/tests/test_semaphore.py | 6 +- distributed/tests/test_steal.py | 231 +- distributed/tests/test_stress.py | 64 +- distributed/tests/test_tls_functional.py | 72 +- distributed/tests/test_utils.py | 42 +- distributed/tests/test_utils_test.py | 38 +- distributed/tests/test_variable.py | 99 +- distributed/tests/test_worker.py | 499 +++-- distributed/tests/test_worker_client.py | 74 +- distributed/utils.py | 1 + distributed/utils_test.py | 54 +- distributed/worker.py | 1 + docs/source/asynchronous.rst | 29 +- 55 files changed, 2764 insertions(+), 2872 deletions(-) diff --git a/.gitignore b/.gitignore index 86ee425adff..cf6732eaa70 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ dask-worker-space/ *.swp .ycm_extra_conf.py tags +.ipynb_checkpoints diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 3e867b1f377..6f4129514b9 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -10,7 +10,6 @@ import tempfile from time import sleep -from tornado import gen from click.testing import CliRunner import distributed @@ -29,12 +28,9 @@ def test_defaults(loop): with popen(["dask-scheduler", "--no-dashboard"]) as proc: - @gen.coroutine - def f(): + async def f(): # Default behaviour is to listen on all addresses - yield [ - assert_can_connect_from_everywhere_4_6(8786, timeout=5.0) - ] # main port + await assert_can_connect_from_everywhere_4_6(8786, timeout=5.0) with Client("127.0.0.1:%d" % Scheduler.default_port, loop=loop) as c: c.sync(f) @@ -49,12 +45,9 @@ def f(): def test_hostport(loop): with popen(["dask-scheduler", "--no-dashboard", "--host", "127.0.0.1:8978"]): - @gen.coroutine - def f(): - yield [ - # The scheduler's main port can't be contacted from the outside - assert_can_connect_locally_4(8978, timeout=5.0) - ] + async def f(): + # The scheduler's main port can't be contacted from the outside + await assert_can_connect_locally_4(8978, timeout=5.0) with Client("127.0.0.1:8978", loop=loop) as c: assert len(c.nthreads()) == 0 diff --git a/distributed/cli/utils.py b/distributed/cli/utils.py index 4cfb41abe0f..c1bff051534 100644 --- a/distributed/cli/utils.py +++ b/distributed/cli/utils.py @@ -1,4 +1,3 @@ -from tornado import gen from tornado.ioloop import IOLoop @@ -51,11 +50,10 @@ def install_signal_handlers(loop=None, cleanup=None): old_handlers = {} def handle_signal(sig, frame): - @gen.coroutine - def cleanup_and_stop(): + async def cleanup_and_stop(): try: if cleanup is not None: - yield cleanup(sig) + await 
cleanup(sig) finally: loop.stop() diff --git a/distributed/client.py b/distributed/client.py index 607ca9f5b48..5ba05a84a3b 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -102,7 +102,6 @@ def _get_global_client(): return c else: del _global_clients[k] - del L return None @@ -1339,6 +1338,10 @@ def close(self, timeout=no_default): timeout = self._timeout * 2 # XXX handling of self.status here is not thread-safe if self.status == "closed": + if self.asynchronous: + future = asyncio.Future() + future.set_result(None) + return future return self.status = "closing" diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 9ac97deeb7e..7e3cb61e375 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -151,13 +151,13 @@ async def test_ping_pong_data(): @gen_test() -def test_ucx_deserialize(): +async def test_ucx_deserialize(): # Note we see this error on some systems with this test: # `socket.gaierror: [Errno -5] No address associated with hostname` # This may be due to a system configuration issue. from .test_comms import check_deserialize - yield check_deserialize("tcp://") + await check_deserialize("tcp://") @pytest.mark.asyncio diff --git a/distributed/dashboard/tests/test_components.py b/distributed/dashboard/tests/test_components.py index 3e6a696cc6b..a3e444e17e6 100644 --- a/distributed/dashboard/tests/test_components.py +++ b/distributed/dashboard/tests/test_components.py @@ -1,9 +1,10 @@ +import asyncio + import pytest pytest.importorskip("bokeh") from bokeh.models import ColumnDataSource, Model -from tornado import gen from distributed.utils_test import slowinc, gen_cluster from distributed.dashboard.components.shared import ( @@ -21,16 +22,16 @@ def test_basic(Component): @gen_cluster(client=True, clean_kwargs={"threads": False}) -def test_profile_plot(c, s, a, b): +async def test_profile_plot(c, s, a, b): p = ProfilePlot() assert not p.source.data["left"] - yield c.map(slowinc, range(10), delay=0.05) + await c.gather(c.map(slowinc, range(10), delay=0.05)) p.update(a.profile_recent) assert len(p.source.data["left"]) >= 1 @gen_cluster(client=True, clean_kwargs={"threads": False}) -def test_profile_time_plot(c, s, a, b): +async def test_profile_time_plot(c, s, a, b): from bokeh.io import curdoc sp = ProfileTimePlot(s, doc=curdoc()) @@ -42,7 +43,7 @@ def test_profile_time_plot(c, s, a, b): assert len(sp.source.data["left"]) <= 1 assert len(ap.source.data["left"]) <= 1 - yield c.map(slowinc, range(10), delay=0.05) + await c.gather(c.map(slowinc, range(10), delay=0.05)) ap.trigger_update() sp.trigger_update() - yield gen.sleep(0.05) + await asyncio.sleep(0.05) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 49bdfe448bf..8ed1bb0f8a1 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -1,3 +1,4 @@ +import asyncio import json import re import ssl @@ -8,7 +9,6 @@ pytest.importorskip("bokeh") from tlz import first -from tornado import gen from tornado.httpclient import AsyncHTTPClient, HTTPRequest import dask @@ -35,31 +35,26 @@ ProfileServer, MemoryByKey, ) -from distributed.utils_test import async_wait_for - from distributed.dashboard import scheduler scheduler.PROFILING = False -@pytest.mark.skipif( - sys.version_info[0] == 2, reason="https://github.com/bokeh/bokeh/issues/5494" -) @gen_cluster(client=True, scheduler_kwargs={"dashboard": True}) -def 
test_simple(c, s, a, b): +async def test_simple(c, s, a, b): port = s.http_server.port future = c.submit(sleep, 1) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) http_client = AsyncHTTPClient() for suffix in applications: - response = yield http_client.fetch("http://localhost:%d%s" % (port, suffix)) + response = await http_client.fetch("http://localhost:%d%s" % (port, suffix)) body = response.body.decode() assert "bokeh" in body.lower() assert not re.search("href=./", body) # no absolute links - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/individual-plots.json" % port ) response = json.loads(response.body.decode()) @@ -67,7 +62,7 @@ def test_simple(c, s, a, b): @gen_cluster(client=True, worker_kwargs={"dashboard": True}) -def test_basic(c, s, a, b): +async def test_basic(c, s, a, b): for component in [TaskStream, SystemMonitor, Occupancy, StealingTimeSeries]: ss = component(s) @@ -79,24 +74,24 @@ def test_basic(c, s, a, b): @gen_cluster(client=True) -def test_counters(c, s, a, b): +async def test_counters(c, s, a, b): pytest.importorskip("crick") while "tick-duration" not in s.digests: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) ss = Counters(s) ss.update() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) ss.update() start = time() while not len(ss.digest_sources["tick-duration"][0].data["x"]): - yield gen.sleep(1) + await asyncio.sleep(1) assert time() < start + 5 @gen_cluster(client=True) -def test_stealing_events(c, s, a, b): +async def test_stealing_events(c, s, a, b): se = StealingEvents(s) futures = c.map( @@ -104,7 +99,7 @@ def test_stealing_events(c, s, a, b): ) while not b.task_state: # will steal soon - yield gen.sleep(0.01) + await asyncio.sleep(0.01) se.update() @@ -112,7 +107,7 @@ def test_stealing_events(c, s, a, b): @gen_cluster(client=True) -def test_events(c, s, a, b): +async def test_events(c, s, a, b): e = Events(s, "all") futures = c.map( @@ -120,7 +115,7 @@ def test_events(c, s, a, b): ) while not b.task_state: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) e.update() d = dict(e.source.data) @@ -128,12 +123,12 @@ def test_events(c, s, a, b): @gen_cluster(client=True) -def test_task_stream(c, s, a, b): +async def test_task_stream(c, s, a, b): ts = TaskStream(s) futures = c.map(slowinc, range(10), delay=0.001) - yield wait(futures) + await wait(futures) ts.update() d = dict(ts.source.data) @@ -146,7 +141,7 @@ def test_task_stream(c, s, a, b): assert all(len(L) == 10 for L in d.values()) total = c.submit(sum, futures) - yield wait(total) + await wait(total) ts.update() d = dict(ts.source.data) @@ -154,21 +149,21 @@ def test_task_stream(c, s, a, b): @gen_cluster(client=True) -def test_task_stream_n_rectangles(c, s, a, b): +async def test_task_stream_n_rectangles(c, s, a, b): ts = TaskStream(s, n_rectangles=10) futures = c.map(slowinc, range(10), delay=0.001) - yield wait(futures) + await wait(futures) ts.update() assert len(ts.source.data["start"]) == 10 @gen_cluster(client=True) -def test_task_stream_second_plugin(c, s, a, b): +async def test_task_stream_second_plugin(c, s, a, b): ts = TaskStream(s, n_rectangles=10, clear_interval=10) ts.update() futures = c.map(inc, range(10)) - yield wait(futures) + await wait(futures) ts.update() ts2 = TaskStream(s, n_rectangles=5, clear_interval=10) @@ -176,21 +171,21 @@ def test_task_stream_second_plugin(c, s, a, b): @gen_cluster(client=True) -def test_task_stream_clear_interval(c, s, a, b): +async def test_task_stream_clear_interval(c, s, a, b): ts = TaskStream(s, 
clear_interval=200) - yield wait(c.map(inc, range(10))) + await wait(c.map(inc, range(10))) ts.update() - yield gen.sleep(0.010) - yield wait(c.map(dec, range(10))) + await asyncio.sleep(0.010) + await wait(c.map(dec, range(10))) ts.update() assert len(set(map(len, ts.source.data.values()))) == 1 assert ts.source.data["name"].count("inc") == 10 assert ts.source.data["name"].count("dec") == 10 - yield gen.sleep(0.300) - yield wait(c.map(inc, range(10, 20))) + await asyncio.sleep(0.300) + await wait(c.map(inc, range(10, 20))) ts.update() assert len(set(map(len, ts.source.data.values()))) == 1 @@ -199,11 +194,11 @@ def test_task_stream_clear_interval(c, s, a, b): @gen_cluster(client=True) -def test_TaskProgress(c, s, a, b): +async def test_TaskProgress(c, s, a, b): tp = TaskProgress(s) futures = c.map(slowinc, range(10), delay=0.001) - yield wait(futures) + await wait(futures) tp.update() d = dict(tp.source.data) @@ -211,7 +206,7 @@ def test_TaskProgress(c, s, a, b): assert d["name"] == ["slowinc"] futures2 = c.map(dec, range(5)) - yield wait(futures2) + await wait(futures2) tp.update() d = dict(tp.source.data) @@ -221,35 +216,35 @@ def test_TaskProgress(c, s, a, b): del futures, futures2 while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) tp.update() assert not tp.source.data["all"] @gen_cluster(client=True) -def test_TaskProgress_empty(c, s, a, b): +async def test_TaskProgress_empty(c, s, a, b): tp = TaskProgress(s) tp.update() futures = [c.submit(inc, i, key="f-" + "a" * i) for i in range(20)] - yield wait(futures) + await wait(futures) tp.update() del futures while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) tp.update() assert not any(len(v) for v in tp.source.data.values()) @gen_cluster(client=True) -def test_CurrentLoad(c, s, a, b): +async def test_CurrentLoad(c, s, a, b): cl = CurrentLoad(s) futures = c.map(slowinc, range(10), delay=0.001) - yield wait(futures) + await wait(futures) cl.update() d = dict(cl.source.data) @@ -261,34 +256,34 @@ def test_CurrentLoad(c, s, a, b): @gen_cluster(client=True) -def test_ProcessingHistogram(c, s, a, b): +async def test_ProcessingHistogram(c, s, a, b): ph = ProcessingHistogram(s) ph.update() assert (ph.source.data["top"] != 0).sum() == 1 futures = c.map(slowinc, range(10), delay=0.050) while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) ph.update() assert ph.source.data["right"][-1] > 2 @gen_cluster(client=True) -def test_NBytesHistogram(c, s, a, b): +async def test_NBytesHistogram(c, s, a, b): nh = NBytesHistogram(s) nh.update() assert (nh.source.data["top"] != 0).sum() == 1 futures = c.map(inc, range(10)) - yield wait(futures) + await wait(futures) nh.update() assert nh.source.data["right"][-1] > 5 * 20 @gen_cluster(client=True) -def test_WorkerTable(c, s, a, b): +async def test_WorkerTable(c, s, a, b): wt = WorkerTable(s) wt.update() assert all(wt.source.data.values()) @@ -307,7 +302,7 @@ def test_WorkerTable(c, s, a, b): @gen_cluster(client=True) -def test_WorkerTable_custom_metrics(c, s, a, b): +async def test_WorkerTable_custom_metrics(c, s, a, b): def metric_port(worker): return worker.port @@ -320,7 +315,7 @@ def metric_address(worker): for name, func in metrics.items(): w.metrics[name] = func - yield [a.heartbeat(), b.heartbeat()] + await asyncio.gather(a.heartbeat(), b.heartbeat()) for w in [a, b]: assert s.workers[w.address].metrics["metric_port"] == w.port @@ -341,13 +336,13 @@ def metric_address(worker): @gen_cluster(client=True) -def test_WorkerTable_different_metrics(c, s, a, b): 
+async def test_WorkerTable_different_metrics(c, s, a, b): def metric_port(worker): return worker.port a.metrics["metric_a"] = metric_port b.metrics["metric_b"] = metric_port - yield [a.heartbeat(), b.heartbeat()] + await asyncio.gather(a.heartbeat(), b.heartbeat()) assert s.workers[a.address].metrics["metric_a"] == a.port assert s.workers[b.address].metrics["metric_b"] == b.port @@ -366,12 +361,12 @@ def metric_port(worker): @gen_cluster(client=True) -def test_WorkerTable_metrics_with_different_metric_2(c, s, a, b): +async def test_WorkerTable_metrics_with_different_metric_2(c, s, a, b): def metric_port(worker): return worker.port a.metrics["metric_a"] = metric_port - yield [a.heartbeat(), b.heartbeat()] + await asyncio.gather(a.heartbeat(), b.heartbeat()) wt = WorkerTable(s) wt.update() @@ -385,13 +380,13 @@ def metric_port(worker): @gen_cluster(client=True, worker_kwargs={"metrics": {"my_port": lambda w: w.port}}) -def test_WorkerTable_add_and_remove_metrics(c, s, a, b): +async def test_WorkerTable_add_and_remove_metrics(c, s, a, b): def metric_port(worker): return worker.port a.metrics["metric_a"] = metric_port b.metrics["metric_b"] = metric_port - yield [a.heartbeat(), b.heartbeat()] + await asyncio.gather(a.heartbeat(), b.heartbeat()) assert s.workers[a.address].metrics["metric_a"] == a.port assert s.workers[b.address].metrics["metric_b"] == b.port @@ -403,14 +398,14 @@ def metric_port(worker): # Remove 'metric_b' from worker b del b.metrics["metric_b"] - yield [a.heartbeat(), b.heartbeat()] + await asyncio.gather(a.heartbeat(), b.heartbeat()) wt = WorkerTable(s) wt.update() assert "metric_a" in wt.source.data del a.metrics["metric_a"] - yield [a.heartbeat(), b.heartbeat()] + await asyncio.gather(a.heartbeat(), b.heartbeat()) wt = WorkerTable(s) wt.update() @@ -418,14 +413,14 @@ def metric_port(worker): @gen_cluster(client=True) -def test_WorkerTable_custom_metric_overlap_with_core_metric(c, s, a, b): +async def test_WorkerTable_custom_metric_overlap_with_core_metric(c, s, a, b): def metric(worker): return -999 a.metrics["executing"] = metric a.metrics["cpu"] = metric a.metrics["metric"] = metric - yield [a.heartbeat(), b.heartbeat()] + await asyncio.gather(a.heartbeat(), b.heartbeat()) assert s.workers[a.address].metrics["executing"] != -999 assert s.workers[a.address].metrics["cpu"] != -999 @@ -433,11 +428,11 @@ def metric(worker): @gen_cluster(client=True) -def test_TaskGraph(c, s, a, b): +async def test_TaskGraph(c, s, a, b): gp = TaskGraph(s) futures = c.map(inc, range(5)) total = c.submit(sum, futures) - yield total + await total gp.update() assert set(map(len, gp.node_source.data.values())) == {6} @@ -449,22 +444,22 @@ def test_TaskGraph(c, s, a, b): x = da.random.random((20, 20), chunks=(10, 10)).persist() y = (x + x.T) - x.mean(axis=0) y = y.persist() - yield wait(y) + await wait(y) gp.update() gp.update() - yield c.compute((x + y).sum()) + await c.compute((x + y).sum()) gp.update() future = c.submit(inc, 10) future2 = c.submit(inc, future) - yield wait(future2) + await wait(future2) key = future.key del future, future2 while key in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert "memory" in gp.node_source.data["state"] @@ -475,25 +470,25 @@ def test_TaskGraph(c, s, a, b): @gen_cluster(client=True) -def test_TaskGraph_clear(c, s, a, b): +async def test_TaskGraph_clear(c, s, a, b): gp = TaskGraph(s) futures = c.map(inc, range(5)) total = c.submit(sum, futures) - yield total + await total gp.update() del total, futures while s.tasks: - yield gen.sleep(0.01) 
+ await asyncio.sleep(0.01) gp.update() gp.update() start = time() while any(gp.node_source.data.values()) or any(gp.edge_source.data.values()): - yield gen.sleep(0.1) + await asyncio.sleep(0.1) gp.update() assert time() < start + 5 @@ -501,49 +496,43 @@ def test_TaskGraph_clear(c, s, a, b): @gen_cluster( client=True, config={"distributed.dashboard.graph-max-items": 2,}, ) -def test_TaskGraph_limit(c, s, a, b): +async def test_TaskGraph_limit(c, s, a, b): gp = TaskGraph(s) def func(x): return x f1 = c.submit(func, 1) - yield wait(f1) + await wait(f1) gp.update() assert len(gp.node_source.data["x"]) == 1 f2 = c.submit(func, 2) - yield wait(f2) + await wait(f2) gp.update() assert len(gp.node_source.data["x"]) == 2 f3 = c.submit(func, 3) - yield wait(f3) + await wait(f3) gp.update() assert len(gp.node_source.data["x"]) == 2 - del f1 - del f2 - del f3 - _ = c.submit(func, 1) - - async_wait_for(lambda: len(gp.node_source.data["x"]) == 1, timeout=1) @gen_cluster(client=True, timeout=30) -def test_TaskGraph_complex(c, s, a, b): +async def test_TaskGraph_complex(c, s, a, b): da = pytest.importorskip("dask.array") gp = TaskGraph(s) x = da.random.random((2000, 2000), chunks=(1000, 1000)) y = ((x + x.T) - x.mean(axis=0)).persist() - yield wait(y) + await wait(y) gp.update() assert len(gp.layout.index) == len(gp.node_source.data["x"]) assert len(gp.layout.index) == len(s.tasks) z = (x - y).sum().persist() - yield wait(z) + await wait(z) gp.update() assert len(gp.layout.index) == len(gp.node_source.data["x"]) assert len(gp.layout.index) == len(s.tasks) del z - yield gen.sleep(0.2) + await asyncio.sleep(0.2) gp.update() assert len(gp.layout.index) == sum( v == "True" for v in gp.node_source.data["visible"] @@ -559,10 +548,10 @@ def test_TaskGraph_complex(c, s, a, b): @gen_cluster(client=True) -def test_TaskGraph_order(c, s, a, b): +async def test_TaskGraph_order(c, s, a, b): x = c.submit(inc, 1) y = c.submit(div, 1, 0) - yield wait(y) + await wait(y) gp = TaskGraph(s) gp.update() @@ -577,12 +566,12 @@ def test_TaskGraph_order(c, s, a, b): "distributed.worker.profile.cycle": "50ms", }, ) -def test_profile_server(c, s, a, b): +async def test_profile_server(c, s, a, b): ptp = ProfileServer(s) start = time() - yield gen.sleep(0.100) + await asyncio.sleep(0.100) while len(ptp.ts_source.data["time"]) < 2: - yield gen.sleep(0.100) + await asyncio.sleep(0.100) ptp.trigger_update() assert time() < start + 2 @@ -590,9 +579,9 @@ def test_profile_server(c, s, a, b): @gen_cluster( client=True, scheduler_kwargs={"dashboard": True}, ) -def test_root_redirect(c, s, a, b): +async def test_root_redirect(c, s, a, b): http_client = AsyncHTTPClient() - response = yield http_client.fetch("http://localhost:%d/" % s.http_server.port) + response = await http_client.fetch("http://localhost:%d/" % s.http_server.port) assert response.code == 200 assert "/status" in response.effective_url @@ -603,7 +592,7 @@ def test_root_redirect(c, s, a, b): worker_kwargs={"dashboard": True}, timeout=180, ) -def test_proxy_to_workers(c, s, a, b): +async def test_proxy_to_workers(c, s, a, b): try: import jupyter_server_proxy # noqa: F401 @@ -613,7 +602,7 @@ def test_proxy_to_workers(c, s, a, b): dashboard_port = s.http_server.port http_client = AsyncHTTPClient() - response = yield http_client.fetch("http://localhost:%d/" % dashboard_port) + response = await http_client.fetch("http://localhost:%d/" % dashboard_port) assert response.code == 200 assert "/status" in response.effective_url @@ -627,8 +616,8 @@ def test_proxy_to_workers(c, s, a, b): ) 
direct_url = "http://localhost:%s/status" % port http_client = AsyncHTTPClient() - response_proxy = yield http_client.fetch(proxy_url) - response_direct = yield http_client.fetch(direct_url) + response_proxy = await http_client.fetch(proxy_url) + response_direct = await http_client.fetch(direct_url) assert response_proxy.code == 200 if proxy_exists: @@ -676,7 +665,7 @@ async def test_lots_of_tasks(c, s, a, b): "distributed.scheduler.dashboard.tls.ca-file": get_cert("tls-ca-cert.pem"), }, ) -def test_https_support(c, s, a, b): +async def test_https_support(c, s, a, b): port = s.http_server.port assert ( @@ -687,7 +676,7 @@ def test_https_support(c, s, a, b): ctx.load_verify_locations(get_cert("tls-ca-cert.pem")) http_client = AsyncHTTPClient() - response = yield http_client.fetch( + response = await http_client.fetch( "https://localhost:%d/individual-plots.json" % port, ssl_options=ctx ) response = json.loads(response.body.decode()) @@ -704,7 +693,7 @@ def test_https_support(c, s, a, b): req = HTTPRequest( url="https://localhost:%d/%s" % (port, suffix), ssl_options=ctx ) - response = yield http_client.fetch(req) + response = await http_client.fetch(req) assert response.code < 300 body = response.body.decode() assert not re.search("href=./", body) # no absolute links diff --git a/distributed/dashboard/tests/test_worker_bokeh.py b/distributed/dashboard/tests/test_worker_bokeh.py index 873cc1c1f3e..47ac89c6b0a 100644 --- a/distributed/dashboard/tests/test_worker_bokeh.py +++ b/distributed/dashboard/tests/test_worker_bokeh.py @@ -1,12 +1,12 @@ -from operator import add, sub +import asyncio import re +from operator import add, sub from time import sleep import pytest pytest.importorskip("bokeh") from tlz import first -from tornado import gen from tornado.httpclient import AsyncHTTPClient from distributed.client import wait @@ -28,20 +28,20 @@ worker_kwargs={"dashboard": True}, scheduler_kwargs={"dashboard": True}, ) -def test_routes(c, s, a, b): +async def test_routes(c, s, a, b): port = a.http_server.port future = c.submit(sleep, 1) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) http_client = AsyncHTTPClient() for suffix in ["status", "counters", "system", "profile", "profile-server"]: - response = yield http_client.fetch("http://localhost:%d/%s" % (port, suffix)) + response = await http_client.fetch("http://localhost:%d/%s" % (port, suffix)) body = response.body.decode() assert "bokeh" in body.lower() assert not re.search("href=./", body) # no absolute links - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/info/main/workers.html" % s.http_server.port ) @@ -49,16 +49,16 @@ def test_routes(c, s, a, b): @gen_cluster(client=True, worker_kwargs={"dashboard": True}) -def test_simple(c, s, a, b): +async def test_simple(c, s, a, b): assert s.workers[a.address].services == {"dashboard": a.http_server.port} assert s.workers[b.address].services == {"dashboard": b.http_server.port} future = c.submit(sleep, 1) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) http_client = AsyncHTTPClient() for suffix in ["crossfilter", "system"]: - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/%s" % (a.http_server.port, suffix) ) assert "bokeh" in response.body.decode().lower() @@ -67,12 +67,12 @@ def test_simple(c, s, a, b): @gen_cluster( client=True, worker_kwargs={"dashboard": True}, ) -def test_services_kwargs(c, s, a, b): +async def test_services_kwargs(c, s, a, b): assert s.workers[a.address].services == {"dashboard": 
a.http_server.port} @gen_cluster(client=True) -def test_basic(c, s, a, b): +async def test_basic(c, s, a, b): for component in [ StateTable, ExecutingTimeSeries, @@ -92,7 +92,7 @@ def slowall(*args): x = c.submit(slowall, xs, ys, 1, workers=a.address) y = c.submit(slowall, xs, ys, 2, workers=b.address) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) aa.update() bb.update() @@ -103,19 +103,19 @@ def slowall(*args): @gen_cluster(client=True) -def test_counters(c, s, a, b): +async def test_counters(c, s, a, b): pytest.importorskip("crick") while "tick-duration" not in a.digests: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) aa = Counters(a) aa.update() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) aa.update() start = time() while not len(aa.digest_sources["tick-duration"][0].data["x"]): - yield gen.sleep(1) + await asyncio.sleep(1) assert time() < start + 5 a.digests["foo"].add(1) @@ -134,7 +134,7 @@ def test_counters(c, s, a, b): @gen_cluster(client=True) -def test_CommunicatingStream(c, s, a, b): +async def test_CommunicatingStream(c, s, a, b): aa = CommunicatingStream(a) bb = CommunicatingStream(b) @@ -143,7 +143,7 @@ def test_CommunicatingStream(c, s, a, b): adds = c.map(add, xs, ys, workers=a.address) subs = c.map(sub, xs, ys, workers=b.address) - yield wait([adds, subs]) + await wait([adds, subs]) aa.update() bb.update() @@ -159,12 +159,12 @@ def test_CommunicatingStream(c, s, a, b): @gen_cluster( client=True, clean_kwargs={"threads": False}, worker_kwargs={"dashboard": True}, ) -def test_prometheus(c, s, a, b): +async def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") http_client = AsyncHTTPClient() for suffix in ["metrics"]: - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/%s" % (a.http_server.port, suffix) ) assert response.code == 200 diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 7fb91292540..651717aeee4 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -1,10 +1,10 @@ +import asyncio import gc import math from time import sleep import dask import pytest -from tornado import gen from distributed import Client, wait, Adaptive, LocalCluster, SpecCluster, Worker from distributed.utils_test import gen_test, slowinc, clean @@ -40,13 +40,13 @@ def scale_down(self, workers): future = c.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) while len(s.rprocessing) < 3: - await gen.sleep(0.001) + await asyncio.sleep(0.001) ta = cluster.adapt( interval="100 ms", scale_factor=2, Adaptive=TestAdaptive ) - await gen.sleep(0.3) + await asyncio.sleep(0.3) def test_adaptive_local_cluster(loop): @@ -91,7 +91,7 @@ async def test_adaptive_local_cluster_multi_workers(cleanup): start = time() while not cluster.scheduler.workers: - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 15, adapt.log await c.gather(futures) @@ -100,13 +100,13 @@ async def test_adaptive_local_cluster_multi_workers(cleanup): start = time() # while cluster.workers: while cluster.scheduler.workers: - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 15, adapt.log # no workers for a while for i in range(10): assert not cluster.scheduler.workers - await gen.sleep(0.05) + await asyncio.sleep(0.05) futures = c.map(slowinc, range(100), delay=0.01) await c.gather(futures) @@ -136,7 +136,7 @@ def scale_up(self, n, **kwargs): ta = cluster.adapt( min_size=2, interval=0.1, scale_factor=2, 
Adaptive=TestAdaptive ) - await gen.sleep(0.3) + await asyncio.sleep(0.3) # Assert that adaptive cycle does not reduce cluster below minimum size # as determined via override. @@ -144,8 +144,8 @@ def scale_up(self, n, **kwargs): @gen_test() -def test_min_max(): - cluster = yield LocalCluster( +async def test_min_max(): + cluster = await LocalCluster( 0, scheduler_port=0, silence_logs=False, @@ -156,14 +156,14 @@ def test_min_max(): ) try: adapt = cluster.adapt(minimum=1, maximum=2, interval="20 ms", wait_count=10) - c = yield Client(cluster, asynchronous=True) + c = await Client(cluster, asynchronous=True) start = time() while not cluster.scheduler.workers: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 - yield gen.sleep(0.2) + await asyncio.sleep(0.2) assert len(cluster.scheduler.workers) == 1 assert len(adapt.log) == 1 and adapt.log[-1][1] == {"status": "up", "n": 1} @@ -171,11 +171,11 @@ def test_min_max(): start = time() while len(cluster.scheduler.workers) < 2: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 assert len(cluster.scheduler.workers) == 2 - yield gen.sleep(0.5) + await asyncio.sleep(0.5) assert len(cluster.scheduler.workers) == 2 assert len(cluster.workers) == 2 assert len(adapt.log) == 2 and all(d["status"] == "up" for _, d in adapt.log) @@ -185,12 +185,12 @@ def test_min_max(): start = time() while len(cluster.scheduler.workers) != 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 assert adapt.log[-1][1]["status"] == "down" finally: - yield c.close() - yield cluster.close() + await c.close() + await cluster.close() @pytest.mark.asyncio @@ -213,19 +213,19 @@ async def test_avoid_churn(cleanup): for i in range(10): await client.submit(slowinc, i, delay=0.040) - await gen.sleep(0.040) + await asyncio.sleep(0.040) assert len(adapt.log) == 1 -@gen_test(timeout=None) -def test_adapt_quickly(): +@pytest.mark.asyncio +async def test_adapt_quickly(): """ We want to avoid creating and deleting workers frequently Instead we want to wait a few beats before removing a worker in case the user is taking a brief pause between work """ - cluster = yield LocalCluster( + cluster = await LocalCluster( 0, asynchronous=True, processes=False, @@ -233,46 +233,46 @@ def test_adapt_quickly(): silence_logs=False, dashboard_address=None, ) - client = yield Client(cluster, asynchronous=True) + client = await Client(cluster, asynchronous=True) adapt = cluster.adapt(interval="20 ms", wait_count=5, maximum=10) try: future = client.submit(slowinc, 1, delay=0.100) - yield wait(future) + await wait(future) assert len(adapt.log) == 1 # Scale up when there is plenty of available work futures = client.map(slowinc, range(1000), delay=0.100) while len(adapt.log) == 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(adapt.log) == 2 assert adapt.log[-1][1]["status"] == "up" d = [x for x in adapt.log[-1] if isinstance(x, dict)][0] assert 2 < d["n"] <= adapt.maximum while len(cluster.workers) < adapt.maximum: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) del futures while len(cluster.scheduler.tasks) > 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - yield cluster + await cluster while len(cluster.scheduler.workers) > 1 or len(cluster.worker_spec) > 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) # Don't scale up for large sequential computations - x = yield client.scatter(1) + x = await client.scatter(1) log = list(cluster._adaptive.log) for i in range(100): x = client.submit(slowinc, x) 
- yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert len(cluster.workers) == 1 finally: - yield client.close() - yield cluster.close() + await client.close() + await cluster.close() @gen_test(timeout=None) @@ -291,13 +291,13 @@ async def test_adapt_down(): futures = client.map(slowinc, range(1000), delay=0.1) while len(cluster.scheduler.workers) < 5: - await gen.sleep(0.1) + await asyncio.sleep(0.1) cluster.adapt(maximum=2) start = time() while len(cluster.scheduler.workers) != 2: - await gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 1 @@ -335,7 +335,7 @@ def test_basic_no_loop(loop): loop.add_callback(loop.stop) -@gen_test(timeout=None) +@pytest.mark.asyncio async def test_target_duration(): """ Ensure that redefining adapt with a lower maximum removes workers """ with dask.config.set( @@ -352,12 +352,12 @@ async def test_target_duration(): adapt = cluster.adapt(interval="20ms", minimum=2, target_duration="5s") async with Client(cluster, asynchronous=True) as client: while len(cluster.scheduler.workers) < 2: - await gen.sleep(0.01) + await asyncio.sleep(0.01) futures = client.map(slowinc, range(100), delay=0.3) while len(adapt.log) < 2: - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert adapt.log[0][1] == {"status": "up", "n": 2} assert adapt.log[1][1] == {"status": "up", "n": 20} @@ -385,7 +385,7 @@ def key(ws): await adaptive.adapt() while len(cluster.scheduler.workers) == 4: - await gen.sleep(0.01) + await asyncio.sleep(0.01) names = {ws.name for ws in cluster.scheduler.workers.values()} assert names == {"a-1", "a-2"} or names == {"b-1", "b-2"} diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 31fbcebd3b8..94a6016dd2a 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -10,7 +10,6 @@ from distutils.version import LooseVersion from tornado.ioloop import IOLoop -from tornado import gen import tornado from tornado.httpclient import AsyncHTTPClient import pytest @@ -761,13 +760,13 @@ def test_local_tls(loop, temporary): @gen_test() -def test_scale_retires_workers(): +async def test_scale_retires_workers(): class MyCluster(LocalCluster): def scale_down(self, *args, **kwargs): pass loop = IOLoop.current() - cluster = yield MyCluster( + cluster = await MyCluster( 0, scheduler_port=0, processes=False, @@ -776,26 +775,26 @@ def scale_down(self, *args, **kwargs): loop=loop, asynchronous=True, ) - c = yield Client(cluster, asynchronous=True) + c = await Client(cluster, asynchronous=True) assert not cluster.workers - yield cluster.scale(2) + await cluster.scale(2) start = time() while len(cluster.scheduler.workers) != 2: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 - yield cluster.scale(1) + await cluster.scale(1) start = time() while len(cluster.scheduler.workers) != 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 - yield c.close() - yield cluster.close() + await c.close() + await cluster.close() def test_local_tls_restart(loop): @@ -844,8 +843,7 @@ def test_asynchronous_property(loop): loop=loop, ) as cluster: - @gen.coroutine - def _(): + async def _(): assert cluster.asynchronous cluster.sync(_) diff --git a/distributed/diagnostics/tests/test_eventstream.py b/distributed/diagnostics/tests/test_eventstream.py index a111220b39e..4af97799893 100644 --- a/distributed/diagnostics/tests/test_eventstream.py +++ b/distributed/diagnostics/tests/test_eventstream.py @@ -1,7 +1,7 @@ +import asyncio import 
collections import pytest -from tornado import gen from distributed.client import wait from distributed.diagnostics.eventstream import EventStream, eventstream @@ -10,7 +10,7 @@ @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_eventstream(c, s, *workers): +async def test_eventstream(c, s, *workers): pytest.importorskip("bokeh") es = EventStream() @@ -19,8 +19,8 @@ def test_eventstream(c, s, *workers): futures = c.map(div, [1] * 10, range(10)) total = c.submit(sum, futures[1:]) - yield wait(total) - yield wait(futures) + await wait(total) + await wait(futures) assert len(es.buffer) == 11 @@ -43,13 +43,13 @@ def test_eventstream(c, s, *workers): @gen_cluster(client=True) -def test_eventstream_remote(c, s, a, b): +async def test_eventstream_remote(c, s, a, b): base_plugins = len(s.plugins) - comm = yield eventstream(s.address, interval=0.010) + comm = await eventstream(s.address, interval=0.010) start = time() while len(s.plugins) == base_plugins: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 futures = c.map(div, [1] * 10, range(10)) @@ -57,13 +57,13 @@ def test_eventstream_remote(c, s, a, b): start = time() total = [] while len(total) < 10: - msgs = yield comm.read() + msgs = await comm.read() assert isinstance(msgs, tuple) total.extend(msgs) assert time() < start + 5 - yield comm.close() + await comm.close() start = time() while len(s.plugins) > base_plugins: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 diff --git a/distributed/diagnostics/tests/test_graph_layout.py b/distributed/diagnostics/tests/test_graph_layout.py index fc8fba8d028..b63311f8432 100644 --- a/distributed/diagnostics/tests/test_graph_layout.py +++ b/distributed/diagnostics/tests/test_graph_layout.py @@ -1,18 +1,18 @@ +import asyncio import operator from distributed.utils_test import gen_cluster, inc from distributed.diagnostics import GraphLayout from distributed import wait -from tornado import gen @gen_cluster(client=True) -def test_basic(c, s, a, b): +async def test_basic(c, s, a, b): gl = GraphLayout(s) futures = c.map(inc, range(5)) total = c.submit(sum, futures) - yield total + await total assert len(gl.x) == len(gl.y) == 6 assert all(gl.x[f.key] == 0 for f in futures) @@ -21,11 +21,11 @@ def test_basic(c, s, a, b): @gen_cluster(client=True) -def test_construct_after_call(c, s, a, b): +async def test_construct_after_call(c, s, a, b): futures = c.map(inc, range(5)) total = c.submit(sum, futures) - yield total + await total gl = GraphLayout(s) @@ -36,13 +36,13 @@ def test_construct_after_call(c, s, a, b): @gen_cluster(client=True) -def test_states(c, s, a, b): +async def test_states(c, s, a, b): gl = GraphLayout(s) futures = c.map(inc, range(5)) total = c.submit(sum, futures) del futures - yield total + await total updates = {state for idx, state in gl.state_updates} assert "memory" in updates @@ -51,31 +51,31 @@ def test_states(c, s, a, b): @gen_cluster(client=True) -def test_release_tasks(c, s, a, b): +async def test_release_tasks(c, s, a, b): gl = GraphLayout(s) futures = c.map(inc, range(5)) total = c.submit(sum, futures) - yield total + await total key = total.key del total while key in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(gl.visible_updates) == 1 assert len(gl.visible_edge_updates) == 5 @gen_cluster(client=True) -def test_forget(c, s, a, b): +async def test_forget(c, s, a, b): gl = GraphLayout(s) futures = c.map(inc, range(10)) futures = c.map(inc, futures) - yield wait(futures) + await 
wait(futures) del futures while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert not gl.x assert not gl.y @@ -85,12 +85,12 @@ def test_forget(c, s, a, b): @gen_cluster(client=True) -def test_unique_positions(c, s, a, b): +async def test_unique_positions(c, s, a, b): gl = GraphLayout(s) x = c.submit(inc, 1) ys = [c.submit(operator.add, x, i) for i in range(5)] - yield wait(ys) + await wait(ys) y_positions = [(gl.x[k], gl.y[k]) for k in gl.x] assert len(y_positions) == len(set(y_positions)) diff --git a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py index 8e3ba1688cc..871dcb0c5a5 100644 --- a/distributed/diagnostics/tests/test_progress.py +++ b/distributed/diagnostics/tests/test_progress.py @@ -1,6 +1,6 @@ -import pytest +import asyncio -from tornado import gen +import pytest from distributed import Nanny from distributed.client import wait @@ -29,24 +29,24 @@ def h(*args): @nodebug @gen_cluster(client=True) -def test_many_Progress(c, s, a, b): +async def test_many_Progress(c, s, a, b): x = c.submit(f, 1) y = c.submit(g, x) z = c.submit(h, y) - bars = [Progress(keys=[z], scheduler=s) for i in range(10)] - yield [bar.setup() for bar in bars] + bars = [Progress(keys=[z], scheduler=s) for _ in range(10)] + await asyncio.gather(*(bar.setup() for bar in bars)) - yield z + await z start = time() while not all(b.status == "finished" for b in bars): - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 5 @gen_cluster(client=True) -def test_multiprogress(c, s, a, b): +async def test_multiprogress(c, s, a, b): x1 = c.submit(f, 1) x2 = c.submit(f, x1) x3 = c.submit(f, x2) @@ -54,18 +54,18 @@ def test_multiprogress(c, s, a, b): y2 = c.submit(g, y1) p = MultiProgress([y2], scheduler=s, complete=True) - yield p.setup() + await p.setup() assert p.all_keys == { "f": {f.key for f in [x1, x2, x3]}, "g": {f.key for f in [y1, y2]}, } - yield x3 + await x3 assert p.keys["f"] == set() - yield y2 + await y2 assert p.keys == {"f": set(), "g": set()} @@ -73,7 +73,7 @@ def test_multiprogress(c, s, a, b): @gen_cluster(client=True) -def test_robust_to_bad_plugin(c, s, a, b): +async def test_robust_to_bad_plugin(c, s, a, b): class Bad(SchedulerPlugin): def transition(self, key, start, finish, **kwargs): raise Exception() @@ -83,7 +83,7 @@ def transition(self, key, start, finish, **kwargs): x = c.submit(inc, 1) y = c.submit(inc, x) - result = yield y + result = await y assert result == 3 @@ -95,11 +95,11 @@ def check_bar_completed(capsys, width=40): @gen_cluster(client=True, Worker=Nanny, timeout=None) -def test_AllProgress(c, s, a, b): +async def test_AllProgress(c, s, a, b): x, y, z = c.map(inc, [1, 2, 3]) xx, yy, zz = c.map(dec, [x, y, z]) - yield wait([x, y, z]) + await wait([x, y, z]) p = AllProgress(s) assert p.all["inc"] == {x.key, y.key, z.key} assert p.state["memory"]["inc"] == {x.key, y.key, z.key} @@ -109,7 +109,7 @@ def test_AllProgress(c, s, a, b): assert isinstance(p.nbytes["inc"], int) assert p.nbytes["inc"] > 0 - yield wait([xx, yy, zz]) + await wait([xx, yy, zz]) assert p.all["dec"] == {xx.key, yy.key, zz.key} assert p.state["memory"]["dec"] == {xx.key, yy.key, zz.key} assert p.state["released"] == {} @@ -117,7 +117,7 @@ def test_AllProgress(c, s, a, b): assert p.nbytes["inc"] == p.nbytes["dec"] t = c.submit(sum, [x, y, z]) - yield t + await t keys = {x.key, y.key, z.key} del x, y, z @@ -126,7 +126,7 @@ def test_AllProgress(c, s, a, b): gc.collect() while any(k in s.who_has for k in keys): - yield gen.sleep(0.01) + 
await asyncio.sleep(0.01) assert p.state["released"]["inc"] == keys assert p.all["inc"] == keys @@ -135,7 +135,7 @@ def test_AllProgress(c, s, a, b): assert p.nbytes["inc"] == 0 xxx = c.submit(div, 1, 0) - yield wait([xxx]) + await wait([xxx]) assert p.state["erred"] == {"div": {xxx.key}} tkey = t.key @@ -145,7 +145,7 @@ def test_AllProgress(c, s, a, b): gc.collect() while tkey in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) for coll in [p.all, p.nbytes] + list(p.state.values()): assert "inc" not in coll @@ -160,47 +160,47 @@ def f(x): gc.collect() - yield gen.sleep(1) + await asyncio.sleep(1) - yield wait([future]) + await wait([future]) assert p.state["memory"] == {"f": {future.key}} - yield c._restart() + await c._restart() for coll in [p.all] + list(p.state.values()): assert not coll x = c.submit(div, 1, 2) - yield wait([x]) + await wait([x]) assert set(p.all) == {"div"} assert all(set(d) == {"div"} for d in p.state.values()) @gen_cluster(client=True, Worker=Nanny) -def test_AllProgress_lost_key(c, s, a, b, timeout=None): +async def test_AllProgress_lost_key(c, s, a, b, timeout=None): p = AllProgress(s) futures = c.map(inc, range(5)) - yield wait(futures) + await wait(futures) assert len(p.state["memory"]["inc"]) == 5 - yield a.close() - yield b.close() + await a.close() + await b.close() start = time() while len(p.state["memory"]["inc"]) > 0: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 5 @gen_cluster(client=True) -def test_GroupProgress(c, s, a, b): +async def test_GroupProgress(c, s, a, b): da = pytest.importorskip("dask.array") fp = GroupProgress(s) x = da.ones(100, chunks=10) y = x + 1 z = (x * y).sum().persist(optimize_graph=False) - yield wait(z) + await wait(z) assert 3 < len(fp.groups) < 10 for k, g in fp.groups.items(): assert fp.keys[k] @@ -212,6 +212,6 @@ def test_GroupProgress(c, s, a, b): del x, y, z while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert not fp.groups diff --git a/distributed/diagnostics/tests/test_progress_stream.py b/distributed/diagnostics/tests/test_progress_stream.py index 77b3922a42e..8f506b7a7bb 100644 --- a/distributed/diagnostics/tests/test_progress_stream.py +++ b/distributed/diagnostics/tests/test_progress_stream.py @@ -56,7 +56,7 @@ def test_progress_quads_too_many(): @gen_cluster(client=True) -def test_progress_stream(c, s, a, b): +async def test_progress_stream(c, s, a, b): futures = c.map(div, [1] * 10, range(10)) x = 1 @@ -64,10 +64,10 @@ def test_progress_stream(c, s, a, b): x = delayed(inc)(x) future = c.compute(x) - yield wait(futures + [future]) + await wait(futures + [future]) - comm = yield progress_stream(s.address, interval=0.010) - msg = yield comm.read() + comm = await progress_stream(s.address, interval=0.010) + msg = await comm.read() nbytes = msg.pop("nbytes") assert msg == { "all": {"div": 10, "inc": 5}, @@ -81,7 +81,7 @@ def test_progress_stream(c, s, a, b): assert progress_quads(msg) - yield comm.close() + await comm.close() def test_progress_quads_many_functions(): diff --git a/distributed/diagnostics/tests/test_progressbar.py b/distributed/diagnostics/tests/test_progressbar.py index 535efd0e9e2..f19dbd2df26 100644 --- a/distributed/diagnostics/tests/test_progressbar.py +++ b/distributed/diagnostics/tests/test_progressbar.py @@ -25,17 +25,17 @@ def test_text_progressbar(capsys, client): @gen_cluster(client=True) -def test_TextProgressBar_error(c, s, a, b): +async def test_TextProgressBar_error(c, s, a, b): x = c.submit(div, 1, 0) progress = 
TextProgressBar([x.key], scheduler=s.address, start=False, interval=0.01) - yield progress.listen() + await progress.listen() assert progress.status == "error" assert progress.comm.closed() progress = TextProgressBar([x.key], scheduler=s.address, start=False, interval=0.01) - yield progress.listen() + await progress.listen() assert progress.status == "error" assert progress.comm.closed() diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 6fc9e22f3df..31ada3f9e12 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -4,7 +4,7 @@ @gen_cluster(client=True) -def test_simple(c, s, a, b): +async def test_simple(c, s, a, b): class Counter(SchedulerPlugin): def start(self, scheduler): self.scheduler = scheduler @@ -25,7 +25,7 @@ def transition(self, key, start, finish, *args, **kwargs): y = c.submit(inc, x) z = c.submit(inc, y) - yield z + await z assert counter.count == 3 s.remove_plugin(counter) @@ -33,7 +33,7 @@ def transition(self, key, start, finish, *args, **kwargs): @gen_cluster(nthreads=[], client=False) -def test_add_remove_worker(s): +async def test_add_remove_worker(s): events = [] class MyPlugin(SchedulerPlugin): @@ -51,10 +51,10 @@ def remove_worker(self, worker, scheduler): a = Worker(s.address) b = Worker(s.address) - yield a - yield b - yield a.close() - yield b.close() + await a + await b + await a.close() + await b.close() assert events == [ ("add_worker", a.address), @@ -65,8 +65,8 @@ def remove_worker(self, worker, scheduler): events[:] = [] s.remove_plugin(plugin) - a = yield Worker(s.address) - yield a.close() + a = await Worker(s.address) + await a.close() assert events == [] diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index 4639c7a7a0b..4b57d18ee7a 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -13,13 +13,13 @@ @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_TaskStreamPlugin(c, s, *workers): +async def test_TaskStreamPlugin(c, s, *workers): es = TaskStreamPlugin(s) assert not es.buffer futures = c.map(div, [1] * 10, range(10)) total = c.submit(sum, futures[1:]) - yield wait(total) + await wait(total) assert len(es.buffer) == 11 @@ -45,19 +45,19 @@ def test_TaskStreamPlugin(c, s, *workers): @gen_cluster(client=True) -def test_maxlen(c, s, a, b): +async def test_maxlen(c, s, a, b): tasks = TaskStreamPlugin(s, maxlen=5) futures = c.map(inc, range(10)) - yield wait(futures) + await wait(futures) assert len(tasks.buffer) == 5 @gen_cluster(client=True) -def test_collect(c, s, a, b): +async def test_collect(c, s, a, b): tasks = TaskStreamPlugin(s) start = time() futures = c.map(slowinc, range(10), delay=0.1) - yield wait(futures) + await wait(futures) L = tasks.collect() assert len(L) == len(futures) @@ -82,15 +82,15 @@ def test_collect(c, s, a, b): @gen_cluster(client=True) -def test_client(c, s, a, b): - L = yield c.get_task_stream() +async def test_client(c, s, a, b): + L = await c.get_task_stream() assert L == () futures = c.map(slowinc, range(10), delay=0.1) - yield wait(futures) + await wait(futures) tasks = [p for p in s.plugins if isinstance(p, TaskStreamPlugin)][0] - L = yield c.get_task_stream() + L = await c.get_task_stream() assert L == tuple(tasks.buffer) @@ -105,14 +105,14 @@ def test_client_sync(client): @gen_cluster(client=True) -def 
test_get_task_stream_plot(c, s, a, b): +async def test_get_task_stream_plot(c, s, a, b): bokeh = pytest.importorskip("bokeh") - yield c.get_task_stream() + await c.get_task_stream() futures = c.map(slowinc, range(10), delay=0.1) - yield wait(futures) + await wait(futures) - data, figure = yield c.get_task_stream(plot=True) + data, figure = await c.get_task_stream(plot=True) assert isinstance(figure, bokeh.plotting.Figure) diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index c217d17e293..6064462d893 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -88,24 +88,24 @@ def record_display(*args): @gen_cluster(client=True) -def test_progressbar_widget(c, s, a, b): +async def test_progressbar_widget(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, x) z = c.submit(inc, y) - yield wait(z) + await wait(z) progress = ProgressWidget([z.key], scheduler=s.address, complete=True) - yield progress.listen() + await progress.listen() assert progress.bar.value == 1.0 assert "3 / 3" in progress.bar_text.value progress = ProgressWidget([z.key], scheduler=s.address) - yield progress.listen() + await progress.listen() @gen_cluster(client=True) -def test_multi_progressbar_widget(c, s, a, b): +async def test_multi_progressbar_widget(c, s, a, b): x1 = c.submit(inc, 1) x2 = c.submit(inc, x1) x3 = c.submit(inc, x2) @@ -113,10 +113,10 @@ def test_multi_progressbar_widget(c, s, a, b): y2 = c.submit(dec, y1) e = c.submit(throws, y2) other = c.submit(inc, 123) - yield wait([other, e]) + await wait([other, e]) p = MultiProgressWidget([e.key], scheduler=s.address, complete=True) - yield p.listen() + await p.listen() assert p.bars["inc"].value == 1.0 assert p.bars["dec"].value == 1.0 @@ -145,7 +145,7 @@ def test_multi_progressbar_widget(c, s, a, b): @gen_cluster() -def test_multi_progressbar_widget_after_close(s, a, b): +async def test_multi_progressbar_widget_after_close(s, a, b): s.update_graph( tasks=valmap( dumps_task, @@ -170,7 +170,7 @@ def test_multi_progressbar_widget_after_close(s, a, b): ) p = MultiProgressWidget(["x-1", "x-2", "x-3"], scheduler=s.address) - yield p.listen() + await p.listen() assert "x" in p.bars @@ -231,7 +231,7 @@ def test_progressbar_cancel(client): @gen_cluster() -def test_multibar_complete(s, a, b): +async def test_multibar_complete(s, a, b): s.update_graph( tasks=valmap( dumps_task, @@ -256,7 +256,7 @@ def test_multibar_complete(s, a, b): ) p = MultiProgressWidget(["e"], scheduler=s.address, complete=True) - yield p.listen() + await p.listen() assert p._last_response["all"] == {"x": 3, "y": 2, "e": 1} assert all(b.value == 1.0 for k, b in p.bars.items() if k != "e") @@ -274,28 +274,28 @@ def test_fast(client): @gen_cluster(client=True, client_kwargs={"serializers": ["msgpack"]}) -def test_serializers(c, s, a, b): +async def test_serializers(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, x) z = c.submit(inc, y) - yield wait(z) + await wait(z) progress = ProgressWidget([z], scheduler=s.address, complete=True) - yield progress.listen() + await progress.listen() assert progress.bar.value == 1.0 assert "3 / 3" in progress.bar_text.value @gen_tls_cluster(client=True) -def test_tls(c, s, a, b): +async def test_tls(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, x) z = c.submit(inc, y) - yield wait(z) + await wait(z) progress = ProgressWidget([z], scheduler=s.address, complete=True) - yield progress.listen() + await progress.listen() assert progress.bar.value == 1.0 
assert "3 / 3" in progress.bar_text.value diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index b3b919d7fe2..2ee5a28c780 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -34,26 +34,26 @@ def transition(self, key, start, finish, **kwargs): @gen_cluster(client=True, nthreads=[]) -def test_create_with_client(c, s): - yield c.register_worker_plugin(MyPlugin(123)) +async def test_create_with_client(c, s): + await c.register_worker_plugin(MyPlugin(123)) - worker = yield Worker(s.address, loop=s.loop) + worker = await Worker(s.address, loop=s.loop) assert worker._my_plugin_status == "setup" assert worker._my_plugin_data == 123 - yield worker.close() + await worker.close() assert worker._my_plugin_status == "teardown" @gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) -def test_create_on_construction(c, s, a, b): +async def test_create_on_construction(c, s, a, b): assert len(a.plugins) == len(b.plugins) == 1 assert a._my_plugin_status == "setup" assert a._my_plugin_data == 5 @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) -def test_normal_task_transitions_called(c, s, w): +async def test_normal_task_transitions_called(c, s, w): expected_transitions = [ ("task", "waiting", "ready"), ("task", "ready", "executing"), @@ -62,12 +62,12 @@ def test_normal_task_transitions_called(c, s, w): plugin = MyPlugin(1, expected_transitions=expected_transitions) - yield c.register_worker_plugin(plugin) - yield c.submit(lambda x: x, 1, key="task") + await c.register_worker_plugin(plugin) + await c.submit(lambda x: x, 1, key="task") @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) -def test_failing_task_transitions_called(c, s, w): +async def test_failing_task_transitions_called(c, s, w): def failing(x): raise Exception() @@ -79,10 +79,10 @@ def failing(x): plugin = MyPlugin(1, expected_transitions=expected_transitions) - yield c.register_worker_plugin(plugin) + await c.register_worker_plugin(plugin) with pytest.raises(Exception): - yield c.submit(failing, 1, key="task") + await c.submit(failing, 1, key="task") @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) diff --git a/distributed/http/worker/tests/test_worker_http.py b/distributed/http/worker/tests/test_worker_http.py index 0a4135fba7f..2282a4daa66 100644 --- a/distributed/http/worker/tests/test_worker_http.py +++ b/distributed/http/worker/tests/test_worker_http.py @@ -5,16 +5,16 @@ @gen_cluster(client=True) -def test_prometheus(c, s, a, b): +async def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") from prometheus_client.parser import text_string_to_metric_families http_client = AsyncHTTPClient() - # request data twice since there once was a case where metrics got registered multiple times resulting in - # prometheus_client errors + # request data twice since there once was a case where metrics got registered + # multiple times resulting in prometheus_client errors for _ in range(2): - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/metrics" % a.http_server.port ) assert response.code == 200 @@ -26,10 +26,10 @@ def test_prometheus(c, s, a, b): @gen_cluster(client=True) -def test_health(c, s, a, b): +async def test_health(c, s, a, b): http_client = AsyncHTTPClient() - response = yield http_client.fetch( + response = await http_client.fetch( "http://localhost:%d/health" % a.http_server.port ) assert response.code == 
200 diff --git a/distributed/node.py b/distributed/node.py index 11645e86317..740776bed68 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -3,9 +3,9 @@ import warnings import weakref +from tornado import gen from tornado.ioloop import IOLoop from tornado.httpserver import HTTPServer -from tornado import gen import tlz import dask diff --git a/distributed/protocol/tests/test_arrow.py b/distributed/protocol/tests/test_arrow.py index 37aff3a2644..e86bfa6f827 100644 --- a/distributed/protocol/tests/test_arrow.py +++ b/distributed/protocol/tests/test_arrow.py @@ -28,10 +28,10 @@ def echo(arg): @pytest.mark.parametrize("obj", [batch, tbl], ids=["RecordBatch", "Table"]) def test_scatter(obj): @gen_cluster(client=True) - def run_test(client, scheduler, worker1, worker2): - obj_fut = yield client.scatter(obj) + async def run_test(client, scheduler, worker1, worker2): + obj_fut = await client.scatter(obj) fut = client.submit(echo, obj_fut) - result = yield fut + result = await fut assert obj.equals(result) run_test() diff --git a/distributed/protocol/tests/test_h5py.py b/distributed/protocol/tests/test_h5py.py index 6bae5b3b8d5..80eeb2c05f5 100644 --- a/distributed/protocol/tests/test_h5py.py +++ b/distributed/protocol/tests/test_h5py.py @@ -90,7 +90,7 @@ def test_raise_error_on_serialize_write_permissions(): @silence_h5py_issue775 @gen_cluster(client=True) -def test_h5py_serialize(c, s, a, b): +async def test_h5py_serialize(c, s, a, b): from dask.utils import SerializableLock lock = SerializableLock("hdf5") @@ -102,12 +102,12 @@ def test_h5py_serialize(c, s, a, b): dset = f["/group/x"] x = da.from_array(dset, chunks=dset.chunks, lock=lock) y = c.compute(x) - y = yield y + y = await y assert (y[:] == dset[:]).all() @gen_cluster(client=True) -def test_h5py_serialize_2(c, s, a, b): +async def test_h5py_serialize_2(c, s, a, b): with tmpfile() as fn: with h5py.File(fn, mode="a") as f: x = f.create_dataset("/group/x", shape=(12,), dtype="i4", chunks=(4,)) @@ -116,5 +116,5 @@ def test_h5py_serialize_2(c, s, a, b): dset = f["/group/x"] x = da.from_array(dset, chunks=(3,)) y = c.compute(x.sum()) - y = yield y + y = await y assert y == (1 + 2 + 3 + 4) * 3 diff --git a/distributed/protocol/tests/test_netcdf4.py b/distributed/protocol/tests/test_netcdf4.py index f1ddcead3ef..1ed78508156 100644 --- a/distributed/protocol/tests/test_netcdf4.py +++ b/distributed/protocol/tests/test_netcdf4.py @@ -82,12 +82,12 @@ def test_serialize_deserialize_group(): @gen_cluster(client=True) -def test_netcdf4_serialize(c, s, a, b): +async def test_netcdf4_serialize(c, s, a, b): with tmpfile() as fn: create_test_dataset(fn) with netCDF4.Dataset(fn, mode="r") as f: dset = f.variables["x"] x = da.from_array(dset, chunks=2) y = c.compute(x) - y = yield y + y = await y assert (y[:] == dset[:]).all() diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 99a298d9694..08a7c2df244 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -233,9 +233,9 @@ def test_dont_compress_uncompressable_data(): @gen_cluster(client=True, timeout=60) -def test_dumps_large_blosc(c, s, a, b): +async def test_dumps_large_blosc(c, s, a, b): x = c.submit(np.ones, BIG_BYTES_SHARD_SIZE * 2, dtype="u1") - result = yield x + await x @pytest.mark.skipif(sys.version_info[0] < 3, reason="numpy doesnt use memoryviews") diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index dd23e5e635d..6a5af842ddd 
100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -120,34 +120,34 @@ def test_nested_deserialize(): @gen_cluster(client=True) -def test_object_in_graph(c, s, a, b): +async def test_object_in_graph(c, s, a, b): o = MyObj(123) v = delayed(o) v2 = delayed(identity)(v) future = c.compute(v2) - result = yield future + result = await future assert isinstance(result, MyObj) assert result.data == 123 @gen_cluster(client=True) -def test_scatter(c, s, a, b): +async def test_scatter(c, s, a, b): o = MyObj(123) - [future] = yield c._scatter([o]) - yield c._replicate(o) - o2 = yield c._gather(future) + [future] = await c._scatter([o]) + await c._replicate(o) + o2 = await c._gather(future) assert isinstance(o2, MyObj) assert o2.data == 123 @gen_cluster(client=True) -def test_inter_worker_comms(c, s, a, b): +async def test_inter_worker_comms(c, s, a, b): o = MyObj(123) - [future] = yield c._scatter([o], workers=a.address) + [future] = await c._scatter([o], workers=a.address) future2 = c.submit(identity, future, workers=b.address) - o2 = yield c._gather(future2) + o2 = await c._gather(future2) assert isinstance(o2, MyObj) assert o2.data == 123 @@ -249,14 +249,14 @@ def test_errors(): @gen_test() -def test_err_on_bad_deserializer(): - frames = yield to_frames({"x": to_serialize(1234)}, serializers=["pickle"]) +async def test_err_on_bad_deserializer(): + frames = await to_frames({"x": to_serialize(1234)}, serializers=["pickle"]) - result = yield from_frames(frames, deserializers=["pickle", "foo"]) + result = await from_frames(frames, deserializers=["pickle", "foo"]) assert result == {"x": 1234} - with pytest.raises(TypeError) as info: - yield from_frames(frames, deserializers=["msgpack"]) + with pytest.raises(TypeError): + await from_frames(frames, deserializers=["msgpack"]) class MyObject: @@ -290,7 +290,7 @@ def my_loads(header, frames): client_kwargs={"serializers": ["my-ser", "pickle"]}, worker_kwargs={"serializers": ["my-ser", "pickle"]}, ) -def test_context_specific_serialization(c, s, a, b): +async def test_context_specific_serialization(c, s, a, b): register_serialization_family("my-ser", my_dumps, my_loads) try: @@ -298,7 +298,7 @@ def test_context_specific_serialization(c, s, a, b): x = c.submit(MyObject, x=1, y=2, workers=a.address) y = c.submit(lambda x: x, x, workers=b.address) - yield wait(y) + await wait(y) key = y.key @@ -307,11 +307,11 @@ def check(dask_worker): my_obj = dask_worker.data[key] return my_obj.context - result = yield c.run(check, workers=[b.address]) + result = await c.run(check, workers=[b.address]) expected = {"sender": a.address, "recipient": b.address} assert result[b.address]["sender"] == a.address # see origin worker - z = yield y # bring object to local process + z = await y # bring object to local process assert z.x == 1 and z.y == 2 assert z.context["sender"] == b.address @@ -322,14 +322,14 @@ def check(dask_worker): @gen_cluster(client=True) -def test_context_specific_serialization_class(c, s, a, b): +async def test_context_specific_serialization_class(c, s, a, b): register_serialization(MyObject, my_dumps, my_loads) # Create the object on A, force communication to B x = c.submit(MyObject, x=1, y=2, workers=a.address) y = c.submit(lambda x: x, x, workers=b.address) - yield wait(y) + await wait(y) key = y.key @@ -338,11 +338,11 @@ def check(dask_worker): my_obj = dask_worker.data[key] return my_obj.context - result = yield c.run(check, workers=[b.address]) + result = await c.run(check, workers=[b.address]) 
expected = {"sender": a.address, "recipient": b.address} assert result[b.address]["sender"] == a.address # see origin worker - z = yield y # bring object to local process + z = await y # bring object to local process assert z.x == 1 and z.y == 2 assert z.context["sender"] == b.address diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index de69db5685a..89233eaca24 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -1,6 +1,6 @@ +import asyncio import operator from time import sleep -from tornado import gen import pytest @@ -50,20 +50,20 @@ def get(self, key): @pytest.mark.parametrize("direct_to_workers", [True, False]) def test_client_actions(direct_to_workers): @gen_cluster(client=True) - def test(c, s, a, b): - c = yield Client( + async def test(c, s, a, b): + c = await Client( s.address, asynchronous=True, direct_to_workers=direct_to_workers ) counter = c.submit(Counter, workers=[a.address], actor=True) assert isinstance(counter, Future) - counter = yield counter + counter = await counter assert counter._address assert hasattr(counter, "increment") assert hasattr(counter, "add") assert hasattr(counter, "n") - n = yield counter.n + n = await counter.n assert n == 0 assert counter._address == a.address @@ -71,17 +71,17 @@ def test(c, s, a, b): assert isinstance(a.actors[counter.key], Counter) assert s.tasks[counter.key].actor - yield [counter.increment(), counter.increment()] + await asyncio.gather(counter.increment(), counter.increment()) - n = yield counter.n + n = await counter.n assert n == 2 counter.add(10) - while (yield counter.n) != 10 + 2: - n = yield counter.n - yield gen.sleep(0.01) + while (await counter.n) != 10 + 2: + n = await counter.n + await asyncio.sleep(0.01) - yield c.close() + await c.close() test() @@ -89,7 +89,7 @@ def test(c, s, a, b): @pytest.mark.parametrize("separate_thread", [False, True]) def test_worker_actions(separate_thread): @gen_cluster(client=True) - def test(c, s, a, b): + async def test(c, s, a, b): counter = c.submit(Counter, workers=[a.address], actor=True) a_address = a.address @@ -106,17 +106,17 @@ def f(counter): assert end > start futures = [c.submit(f, counter, pure=False) for _ in range(10)] - yield futures + await c.gather(futures) - counter = yield counter - assert (yield counter.n) == 10 + counter = await counter + assert await counter.n == 10 test() @gen_cluster(client=True) -def test_Actor(c, s, a, b): - counter = yield c.submit(Counter, actor=True) +async def test_Actor(c, s, a, b): + counter = await c.submit(Counter, actor=True) assert counter._cls == Counter @@ -132,22 +132,22 @@ def test_Actor(c, s, a, b): + "Should rely on sending small messages rather than rpc" ) @gen_cluster(client=True) -def test_linear_access(c, s, a, b): +async def test_linear_access(c, s, a, b): start = time() future = c.submit(sleep, 0.2) actor = c.submit(List, actor=True, dummy=future) - actor = yield actor + actor = await actor for i in range(100): actor.append(i) while True: - yield gen.sleep(0.1) - L = yield actor.L + await asyncio.sleep(0.1) + L = await actor.L if len(L) == 100: break - L = yield actor.L + L = await actor.L stop = time() assert L == tuple(range(100)) @@ -155,7 +155,7 @@ def test_linear_access(c, s, a, b): @gen_cluster(client=True) -def test_exceptions_create(c, s, a, b): +async def test_exceptions_create(c, s, a, b): class Foo: x = 0 @@ -163,62 +163,62 @@ def __init__(self): raise ValueError("bar") with pytest.raises(ValueError) as info: - future = yield c.submit(Foo, 
actor=True) + await c.submit(Foo, actor=True) assert "bar" in str(info.value) @gen_cluster(client=True) -def test_exceptions_method(c, s, a, b): +async def test_exceptions_method(c, s, a, b): class Foo: def throw(self): 1 / 0 - foo = yield c.submit(Foo, actor=True) + foo = await c.submit(Foo, actor=True) with pytest.raises(ZeroDivisionError): - yield foo.throw() + await foo.throw() @gen_cluster(client=True) -def test_gc(c, s, a, b): +async def test_gc(c, s, a, b): actor = c.submit(Counter, actor=True) - yield wait(actor) + await wait(actor) del actor while a.actors or b.actors: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @gen_cluster(client=True) -def test_track_dependencies(c, s, a, b): +async def test_track_dependencies(c, s, a, b): actor = c.submit(Counter, actor=True) - yield wait(actor) + await wait(actor) x = c.submit(sleep, 0.5) y = c.submit(lambda x, y: x, x, actor) del actor - yield gen.sleep(0.3) + await asyncio.sleep(0.3) assert a.actors or b.actors @gen_cluster(client=True) -def test_future(c, s, a, b): +async def test_future(c, s, a, b): counter = c.submit(Counter, actor=True, workers=[a.address]) assert isinstance(counter, Future) - yield wait(counter) + await wait(counter) assert isinstance(a.actors[counter.key], Counter) - counter = yield counter + counter = await counter assert isinstance(counter, Actor) assert counter._address - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert counter.key in c.futures # don't lose future @gen_cluster(client=True) -def test_future_dependencies(c, s, a, b): +async def test_future_dependencies(c, s, a, b): counter = c.submit(Counter, actor=True, workers=[a.address]) def f(a): @@ -226,13 +226,13 @@ def f(a): assert a._cls == Counter x = c.submit(f, counter, workers=[b.address]) - yield x + await x assert {ts.key for ts in s.tasks[x.key].dependencies} == {counter.key} assert {ts.key for ts in s.tasks[counter.key].dependents} == {x.key} y = c.submit(f, counter, workers=[a.address], pure=False) - yield y + await y assert {ts.key for ts in s.tasks[y.key].dependencies} == {counter.key} assert {ts.key for ts in s.tasks[counter.key].dependents} == {x.key, y.key} @@ -256,15 +256,15 @@ def test_sync(client): @gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "1s"}) -def test_failed_worker(c, s, a, b): +async def test_failed_worker(c, s, a, b): future = c.submit(Counter, actor=True, workers=[a.address]) - yield wait(future) - counter = yield future + await wait(future) + counter = await future - yield a.close() + await a.close() with pytest.raises(Exception) as info: - yield counter.increment() + await counter.increment() assert "actor" in str(info.value).lower() assert "worker" in str(info.value).lower() @@ -272,45 +272,45 @@ def test_failed_worker(c, s, a, b): @gen_cluster(client=True) -def bench(c, s, a, b): - counter = yield c.submit(Counter, actor=True) +async def bench(c, s, a, b): + counter = await c.submit(Counter, actor=True) for i in range(1000): - yield counter.increment() + await counter.increment() @gen_cluster(client=True) -def test_numpy_roundtrip(c, s, a, b): +async def test_numpy_roundtrip(c, s, a, b): np = pytest.importorskip("numpy") - server = yield c.submit(ParameterServer, actor=True) + server = await c.submit(ParameterServer, actor=True) x = np.random.random(1000) - yield server.put("x", x) + await server.put("x", x) - y = yield server.get("x") + y = await server.get("x") assert (x == y).all() @gen_cluster(client=True) -def test_numpy_roundtrip_getattr(c, s, a, b): +async def 
test_numpy_roundtrip_getattr(c, s, a, b): np = pytest.importorskip("numpy") - counter = yield c.submit(Counter, actor=True) + counter = await c.submit(Counter, actor=True) x = np.random.random(1000) - yield counter.add(x) + await counter.add(x) - y = yield counter.n + y = await counter.n assert (x == y).all() @gen_cluster(client=True) -def test_repr(c, s, a, b): - counter = yield c.submit(Counter, actor=True) +async def test_repr(c, s, a, b): + counter = await c.submit(Counter, actor=True) assert "Counter" in repr(counter) assert "Actor" in repr(counter) @@ -319,8 +319,8 @@ def test_repr(c, s, a, b): @gen_cluster(client=True) -def test_dir(c, s, a, b): - counter = yield c.submit(Counter, actor=True) +async def test_dir(c, s, a, b): + counter = await c.submit(Counter, actor=True) d = set(dir(counter)) @@ -330,8 +330,8 @@ def test_dir(c, s, a, b): @gen_cluster(client=True) -def test_many_computations(c, s, a, b): - counter = yield c.submit(Counter, actor=True) +async def test_many_computations(c, s, a, b): + counter = await c.submit(Counter, actor=True) def add(n, counter): for i in range(n): @@ -342,13 +342,13 @@ def add(n, counter): while not done.done(): assert len(s.processing) <= a.nthreads + b.nthreads - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - yield done + await done @gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2) -def test_thread_safety(c, s, a, b): +async def test_thread_safety(c, s, a, b): class Unsafe: def __init__(self): self.n = 0 @@ -362,32 +362,32 @@ def f(self): assert self.n == 1 self.n = 0 - unsafe = yield c.submit(Unsafe, actor=True) + unsafe = await c.submit(Unsafe, actor=True) futures = [unsafe.f() for i in range(10)] - yield futures + await c.gather(futures) @gen_cluster(client=True) -def test_Actors_create_dependencies(c, s, a, b): - counter = yield c.submit(Counter, actor=True) +async def test_Actors_create_dependencies(c, s, a, b): + counter = await c.submit(Counter, actor=True) future = c.submit(lambda x: None, counter) - yield wait(future) + await wait(future) assert s.tasks[future.key].dependencies == {s.tasks[counter.key]} @gen_cluster(client=True) -def test_load_balance(c, s, a, b): +async def test_load_balance(c, s, a, b): class Foo: def __init__(self, x): pass b = c.submit(operator.mul, "b", 1000000) - yield wait(b) + await wait(b) [ws] = s.tasks[b.key].who_has - x = yield c.submit(Foo, b, actor=True) - y = yield c.submit(Foo, b, actor=True) + x = await c.submit(Foo, b, actor=True) + y = await c.submit(Foo, b, actor=True) assert x.key != y.key # actors assumed not pure assert s.tasks[x.key].who_has == {ws} # first went to best match @@ -395,28 +395,28 @@ def __init__(self, x): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 5) -def test_load_balance_map(c, s, *workers): +async def test_load_balance_map(c, s, *workers): class Foo: def __init__(self, x, y=None): pass b = c.submit(operator.mul, "b", 1000000) - yield wait(b) + await wait(b) actors = c.map(Foo, range(10), y=b, actor=True) - yield wait(actors) + await wait(actors) assert all(len(w.actors) == 2 for w in workers) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4, Worker=Nanny) -def bench_param_server(c, s, *workers): +async def bench_param_server(c, s, *workers): import dask.array as da import numpy as np x = da.random.random((500000, 1000), chunks=(1000, 1000)) x = x.persist() - yield wait(x) + await wait(x) class ParameterServer: data = None @@ -443,17 +443,17 @@ def f(block, ps=None): from distributed.utils import format_time start = time() - ps = yield 
c.submit(ParameterServer, x.shape[1], actor=True) + ps = await c.submit(ParameterServer, x.shape[1], actor=True) y = x.map_blocks(f, ps=ps, dtype=x.dtype) - # result = yield c.compute(y.mean()) - yield wait(y.persist()) + # result = await c.compute(y.mean()) + await wait(y.persist()) end = time() print(format_time(end - start)) @pytest.mark.xfail(reason="unknown") @gen_cluster(client=True) -def test_compute(c, s, a, b): +async def test_compute(c, s, a, b): @dask.delayed def f(n, counter): assert isinstance(counter, Actor) @@ -468,12 +468,12 @@ def check(counter, blanks): values = [f(i, counter) for i in range(5)] final = check(counter, values) - result = yield c.compute(final, actors=counter) + result = await c.compute(final, actors=counter) assert result == 0 + 1 + 2 + 3 + 4 start = time() while a.data or b.data: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @@ -509,15 +509,15 @@ def check(dask_worker): nthreads=[("127.0.0.1", 1)], config={"distributed.worker.profile.interval": "1ms"}, ) -def test_actors_in_profile(c, s, a): +async def test_actors_in_profile(c, s, a): class Sleeper: def sleep(self, time): sleep(time) - sleeper = yield c.submit(Sleeper, actor=True) + sleeper = await c.submit(Sleeper, actor=True) for i in range(5): - yield sleeper.sleep(0.200) + await sleeper.sleep(0.200) if ( list(a.profile_recent["children"])[0].startswith("sleep") or "Sleeper.sleep" in a.profile_keys @@ -527,28 +527,26 @@ def sleep(self, time): @gen_cluster(client=True) -def test_waiter(c, s, a, b): +async def test_waiter(c, s, a, b): from tornado.locks import Event class Waiter: def __init__(self): self.event = Event() - @gen.coroutine - def set(self): + async def set(self): self.event.set() - @gen.coroutine - def wait(self): - yield self.event.wait() + async def wait(self): + await self.event.wait() - waiter = yield c.submit(Waiter, actor=True) + waiter = await c.submit(Waiter, actor=True) - futures = [waiter.wait() for i in range(5)] # way more than we have actor threads + futures = [waiter.wait() for _ in range(5)] # way more than we have actor threads - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert not any(future.done() for future in futures) - yield waiter.set() + await waiter.set() - yield futures + await c.gather(futures) diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index f71c6f7492e..ae257f9bb8e 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -6,7 +6,6 @@ from time import sleep import pytest -from tornado import gen from distributed.client import _as_completed, as_completed, _first_completed, wait from distributed.metrics import time @@ -16,18 +15,18 @@ @gen_cluster(client=True) -def test__as_completed(c, s, a, b): +async def test__as_completed(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, 1) z = c.submit(inc, 2) q = queue.Queue() - yield _as_completed([x, y, z], q) + await _as_completed([x, y, z], q) assert q.qsize() == 3 assert {q.get(), q.get(), q.get()} == {x, y, z} - result = yield _first_completed([x, y, z]) + result = await _first_completed([x, y, z]) assert result in [x, y, z] @@ -129,11 +128,10 @@ def test_as_completed_cancel_last(client): x = client.submit(inc, 1) y = client.submit(inc, 0.3) - @gen.coroutine - def _(): - yield gen.sleep(0.1) - yield w.cancel(asynchronous=True) - yield y.cancel(asynchronous=True) + async def _(): + await asyncio.sleep(0.1) + await w.cancel(asynchronous=True) + await y.cancel(asynchronous=True) 
client.loop.add_callback(_) @@ -144,32 +142,23 @@ def _(): @gen_cluster(client=True) -def test_async_for_py2_equivalent(c, s, a, b): +async def test_async_for_py2_equivalent(c, s, a, b): futures = c.map(sleep, [0.01] * 3, pure=False) seq = as_completed(futures) - x = yield seq.__anext__() - y = yield seq.__anext__() - z = yield seq.__anext__() - + x, y, z = [el async for el in seq] assert x.done() assert y.done() assert z.done() assert x.key != y.key - with pytest.raises(StopAsyncIteration): - yield seq.__anext__() - @gen_cluster(client=True) -def test_as_completed_error_async(c, s, a, b): +async def test_as_completed_error_async(c, s, a, b): x = c.submit(throws, 1) y = c.submit(inc, 1) ac = as_completed([x, y]) - first = yield ac.__anext__() - second = yield ac.__anext__() - result = {first, second} - + result = {el async for el in ac} assert result == {x, y} assert x.status == "error" assert y.status == "finished" @@ -200,17 +189,16 @@ def test_as_completed_with_results(client): @gen_cluster(client=True) -def test_as_completed_with_results_async(c, s, a, b): +async def test_as_completed_with_results_async(c, s, a, b): x = c.submit(throws, 1) y = c.submit(inc, 5) z = c.submit(inc, 1) ac = as_completed([x, y, z], with_results=True) - yield y.cancel() + await y.cancel() with pytest.raises(RuntimeError) as exc: - first = yield ac.__anext__() - second = yield ac.__anext__() - third = yield ac.__anext__() + async for _ in ac: + pass assert str(exc.value) == "hello!" @@ -252,17 +240,14 @@ async def test_str(c, s, a, b): @gen_cluster(client=True) -def test_as_completed_with_results_no_raise_async(c, s, a, b): +async def test_as_completed_with_results_no_raise_async(c, s, a, b): x = c.submit(throws, 1) y = c.submit(inc, 5) z = c.submit(inc, 1) ac = as_completed([x, y, z], with_results=True, raise_errors=False) c.loop.add_callback(y.cancel) - first = yield ac.__anext__() - second = yield ac.__anext__() - third = yield ac.__anext__() - res = [first, second, third] + res = [el async for el in ac] dd = {r[0]: r[1:] for r in res} assert set(dd.keys()) == {y, x, z} diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index e496b35cb90..3923d81cf2c 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -1,11 +1,12 @@ -from datetime import timedelta +import asyncio import gc import os import signal import sys import threading -from time import sleep import weakref +from datetime import timedelta +from time import sleep import pytest from tornado import gen @@ -50,7 +51,7 @@ def threads_info(q): @pytest.mark.xfail(reason="Intermittent failure") @nodebug @gen_test() -def test_simple(): +async def test_simple(): to_child = mp_context.Queue() from_child = mp_context.Queue() @@ -67,15 +68,15 @@ def test_simple(): # join() before start() with pytest.raises(AssertionError): - yield proc.join() + await proc.join() - yield proc.start() + await proc.start() assert proc.is_alive() assert proc.pid is not None assert proc.exitcode is None t1 = time() - yield proc.join(timeout=0.02) + await proc.join(timeout=0.02) dt = time() - t1 assert 0.2 >= dt >= 0.01 assert proc.is_alive() @@ -91,7 +92,7 @@ def test_simple(): # child should be stopping now t1 = time() - yield proc.join(timeout=10) + await proc.join(timeout=10) dt = time() - t1 assert dt <= 1.0 assert not proc.is_alive() @@ -100,7 +101,7 @@ def test_simple(): # join() again t1 = time() - yield proc.join() + await proc.join() dt = time() - t1 assert dt <= 0.6 @@ -133,14 +134,14 @@ 
def test_simple(): pytest.fail("AsyncProcess should have been destroyed") t1 = time() while wr2() is not None: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) gc.collect() dt = time() - t1 assert dt < 2.0 @gen_test() -def test_exitcode(): +async def test_exitcode(): q = mp_context.Queue() proc = AsyncProcess(target=exit, kwargs={"q": q}) @@ -148,80 +149,81 @@ def test_exitcode(): assert not proc.is_alive() assert proc.exitcode is None - yield proc.start() + await proc.start() assert proc.is_alive() assert proc.exitcode is None q.put(5) - yield proc.join(timeout=3.0) + await proc.join(timeout=3.0) assert not proc.is_alive() assert proc.exitcode == 5 @pytest.mark.skipif(os.name == "nt", reason="POSIX only") @gen_test() -def test_signal(): +async def test_signal(): proc = AsyncProcess(target=exit_with_signal, args=(signal.SIGINT,)) proc.daemon = True assert not proc.is_alive() assert proc.exitcode is None - yield proc.start() - yield proc.join(timeout=3.0) + await proc.start() + await proc.join(timeout=3.0) assert not proc.is_alive() # Can be 255 with forkserver, see https://bugs.python.org/issue30589 assert proc.exitcode in (-signal.SIGINT, 255) proc = AsyncProcess(target=wait) - yield proc.start() + await proc.start() os.kill(proc.pid, signal.SIGTERM) - yield proc.join(timeout=3.0) + await proc.join(timeout=3.0) assert not proc.is_alive() assert proc.exitcode in (-signal.SIGTERM, 255) @gen_test() -def test_terminate(): +async def test_terminate(): proc = AsyncProcess(target=wait) proc.daemon = True - yield proc.start() - yield proc.terminate() + await proc.start() + await proc.terminate() - yield proc.join(timeout=3.0) + await proc.join(timeout=3.0) assert not proc.is_alive() assert proc.exitcode in (-signal.SIGTERM, 255) @gen_test() -def test_close(): +async def test_close(): proc = AsyncProcess(target=exit_now) proc.close() with pytest.raises(ValueError): - yield proc.start() + await proc.start() proc = AsyncProcess(target=exit_now) - yield proc.start() + await proc.start() proc.close() with pytest.raises(ValueError): - yield proc.terminate() + await proc.terminate() proc = AsyncProcess(target=exit_now) - yield proc.start() - yield proc.join() + await proc.start() + await proc.join() proc.close() with pytest.raises(ValueError): - yield proc.join() + await proc.join() proc.close() @gen_test() -def test_exit_callback(): +async def test_exit_callback(): to_child = mp_context.Queue() from_child = mp_context.Queue() evt = Event() + # FIXME: this breaks if changed to async def... @gen.coroutine def on_stop(_proc): assert _proc is proc @@ -234,13 +236,13 @@ def on_stop(_proc): proc.set_exit_callback(on_stop) proc.daemon = True - yield proc.start() - yield gen.sleep(0.05) + await proc.start() + await asyncio.sleep(0.05) assert proc.is_alive() assert not evt.is_set() to_child.put(None) - yield evt.wait(timedelta(seconds=3)) + await evt.wait(timedelta(seconds=3)) assert evt.is_set() assert not proc.is_alive() @@ -250,25 +252,25 @@ def on_stop(_proc): proc.set_exit_callback(on_stop) proc.daemon = True - yield proc.start() - yield gen.sleep(0.05) + await proc.start() + await asyncio.sleep(0.05) assert proc.is_alive() assert not evt.is_set() - yield proc.terminate() - yield evt.wait(timedelta(seconds=3)) + await proc.terminate() + await evt.wait(timedelta(seconds=3)) assert evt.is_set() @gen_test() -def test_child_main_thread(): +async def test_child_main_thread(): """ The main thread in the child should be called "MainThread". 
""" q = mp_context.Queue() proc = AsyncProcess(target=threads_info, args=(q,)) - yield proc.start() - yield proc.join() + await proc.start() + await proc.join() n_threads = q.get() main_name = q.get() assert n_threads <= 3 @@ -282,38 +284,38 @@ def test_child_main_thread(): sys.platform.startswith("win"), reason="num_fds not supported on windows" ) @gen_test() -def test_num_fds(): +async def test_num_fds(): psutil = pytest.importorskip("psutil") # Warm up proc = AsyncProcess(target=exit_now) proc.daemon = True - yield proc.start() - yield proc.join() + await proc.start() + await proc.join() p = psutil.Process() before = p.num_fds() proc = AsyncProcess(target=exit_now) proc.daemon = True - yield proc.start() - yield proc.join() + await proc.start() + await proc.join() assert not proc.is_alive() assert proc.exitcode == 0 start = time() while p.num_fds() > before: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) print("fds:", before, p.num_fds()) assert time() < start + 10 @gen_test() -def test_terminate_after_stop(): +async def test_terminate_after_stop(): proc = AsyncProcess(target=sleep, args=(0,)) - yield proc.start() - yield gen.sleep(0.1) - yield proc.terminate() + await proc.start() + await asyncio.sleep(0.1) + await proc.terminate() def _worker_process(worker_ready, child_pipe): @@ -342,12 +344,12 @@ def _parent_process(child_pipe): The child_alive pipe is held open for as long as the child is alive, and can be used to determine if it exited correctly. """ - def parent_process_coroutine(): + async def parent_process_coroutine(): worker_ready = mp_context.Event() worker = AsyncProcess(target=_worker_process, args=(worker_ready, child_pipe)) - yield worker.start() + await worker.start() # Wait for the child process to have started. worker_ready.wait() @@ -359,7 +361,7 @@ def parent_process_coroutine(): with pristine_loop() as loop: try: - loop.run_sync(gen.coroutine(parent_process_coroutine), timeout=10) + loop.run_sync(parent_process_coroutine(), timeout=10) finally: loop.stop() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 17aea54f78a..fd95895c84e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -20,7 +20,6 @@ import pytest from tlz import identity, isdistinct, concat, pluck, valmap, first, merge -from tornado import gen import dask from dask import delayed @@ -105,80 +104,79 @@ @gen_cluster(client=True, timeout=None) -def test_submit(c, s, a, b): +async def test_submit(c, s, a, b): x = c.submit(inc, 10) assert not x.done() assert isinstance(x, Future) assert x.client is c - result = yield x + result = await x assert result == 11 assert x.done() y = c.submit(inc, 20) z = c.submit(add, x, y) - result = yield z + result = await z assert result == 11 + 21 s.validate_state() @gen_cluster(client=True) -def test_map(c, s, a, b): +async def test_map(c, s, a, b): L1 = c.map(inc, range(5)) assert len(L1) == 5 assert isdistinct(x.key for x in L1) assert all(isinstance(x, Future) for x in L1) - result = yield L1[0] + result = await L1[0] assert result == inc(0) assert len(s.tasks) == 5 L2 = c.map(inc, L1) - result = yield L2[1] + result = await L2[1] assert result == inc(inc(1)) assert len(s.tasks) == 10 # assert L1[0].key in s.tasks[L2[0].key] total = c.submit(sum, L2) - result = yield total + result = await total assert result == sum(map(inc, map(inc, range(5)))) L3 = c.map(add, L1, L2) - result = yield L3[1] + result = await L3[1] assert result == inc(1) + inc(inc(1)) L4 = c.map(add, range(3), range(4)) - results 
= yield c.gather(L4) - if sys.version_info[0] >= 3: - assert results == list(map(add, range(3), range(4))) + results = await c.gather(L4) + assert results == list(map(add, range(3), range(4))) def f(x, y=10): return x + y L5 = c.map(f, range(5), y=5) - results = yield c.gather(L5) + results = await c.gather(L5) assert results == list(range(5, 10)) y = c.submit(f, 10) L6 = c.map(f, range(5), y=y) - results = yield c.gather(L6) + results = await c.gather(L6) assert results == list(range(20, 25)) s.validate_state() @gen_cluster(client=True) -def test_map_empty(c, s, a, b): +async def test_map_empty(c, s, a, b): L1 = c.map(inc, [], pure=False) assert len(L1) == 0 - results = yield c.gather(L1) + results = await c.gather(L1) assert results == [] @gen_cluster(client=True) -def test_map_keynames(c, s, a, b): +async def test_map_keynames(c, s, a, b): futures = c.map(inc, range(4), key="INC") assert all(f.key.startswith("INC") for f in futures) assert isdistinct(f.key for f in futures) @@ -192,7 +190,7 @@ def test_map_keynames(c, s, a, b): @gen_cluster(client=True) -def test_map_retries(c, s, a, b): +async def test_map_retries(c, s, a, b): args = [ [ZeroDivisionError("one"), 2, 3], [4, 5, 6], @@ -200,22 +198,22 @@ def test_map_retries(c, s, a, b): ] x, y, z = c.map(*map_varying(args), retries=2) - assert (yield x) == 2 - assert (yield y) == 4 - assert (yield z) == 9 + assert await x == 2 + assert await y == 4 + assert await z == 9 x, y, z = c.map(*map_varying(args), retries=1, pure=False) - assert (yield x) == 2 - assert (yield y) == 4 + assert await x == 2 + assert await y == 4 with pytest.raises(ZeroDivisionError, match="eight"): - yield z + await z x, y, z = c.map(*map_varying(args), retries=0, pure=False) with pytest.raises(ZeroDivisionError, match="one"): - yield x - assert (yield y) == 4 + await x + assert await y == 4 with pytest.raises(ZeroDivisionError, match="seven"): - yield z + await z @gen_cluster(client=True) @@ -235,25 +233,25 @@ async def test_map_batch_size(c, s, a, b): @gen_cluster(client=True) -def test_compute_retries(c, s, a, b): +async def test_compute_retries(c, s, a, b): args = [ZeroDivisionError("one"), ZeroDivisionError("two"), 3] # Sanity check for varying() use x = c.compute(delayed(varying(args))()) with pytest.raises(ZeroDivisionError, match="one"): - yield x + await x # Same retries for all x = c.compute(delayed(varying(args))(), retries=1) with pytest.raises(ZeroDivisionError, match="two"): - yield x + await x x = c.compute(delayed(varying(args))(), retries=2) - assert (yield x) == 3 + assert await x == 3 args.append(4) x = c.compute(delayed(varying(args))(), retries=2) - assert (yield x) == 3 + assert await x == 3 # Per-future retries xargs = [ZeroDivisionError("one"), ZeroDivisionError("two"), 30, 40] @@ -264,17 +262,17 @@ def test_compute_retries(c, s, a, b): x, y = c.compute([x, y], retries={x: 2}) gc.collect() - assert (yield x) == 30 + assert await x == 30 with pytest.raises(ZeroDivisionError, match="five"): - yield y + await y x, y, z = [delayed(varying(args))() for args in (xargs, yargs, zargs)] x, y, z = c.compute([x, y, z], retries={(y, z): 2}) with pytest.raises(ZeroDivisionError, match="one"): - yield x - assert (yield y) == 70 - assert (yield z) == 80 + await x + assert await y == 70 + assert await z == 80 def test_retries_get(c): @@ -289,43 +287,43 @@ def test_retries_get(c): @gen_cluster(client=True) -def test_compute_persisted_retries(c, s, a, b): +async def test_compute_persisted_retries(c, s, a, b): args = [ZeroDivisionError("one"), 
ZeroDivisionError("two"), 3] # Sanity check x = c.persist(delayed(varying(args))()) fut = c.compute(x) with pytest.raises(ZeroDivisionError, match="one"): - yield fut + await fut x = c.persist(delayed(varying(args))()) fut = c.compute(x, retries=1) with pytest.raises(ZeroDivisionError, match="two"): - yield fut + await fut x = c.persist(delayed(varying(args))()) fut = c.compute(x, retries=2) - assert (yield fut) == 3 + assert await fut == 3 args.append(4) x = c.persist(delayed(varying(args))()) fut = c.compute(x, retries=3) - assert (yield fut) == 3 + assert await fut == 3 @gen_cluster(client=True) -def test_persist_retries(c, s, a, b): +async def test_persist_retries(c, s, a, b): # Same retries for all args = [ZeroDivisionError("one"), ZeroDivisionError("two"), 3] x = c.persist(delayed(varying(args))(), retries=1) x = c.compute(x) with pytest.raises(ZeroDivisionError, match="two"): - yield x + await x x = c.persist(delayed(varying(args))(), retries=2) x = c.compute(x) - assert (yield x) == 3 + assert await x == 3 # Per-key retries xargs = [ZeroDivisionError("one"), ZeroDivisionError("two"), 30, 40] @@ -337,17 +335,17 @@ def test_persist_retries(c, s, a, b): x, y, z = c.compute([x, y, z]) with pytest.raises(ZeroDivisionError, match="one"): - yield x - assert (yield y) == 70 - assert (yield z) == 80 + await x + assert await y == 70 + assert await z == 80 @gen_cluster(client=True) -def test_retries_dask_array(c, s, a, b): +async def test_retries_dask_array(c, s, a, b): da = pytest.importorskip("dask.array") x = da.ones((10, 10), chunks=(3, 3)) future = c.compute(x.sum(), retries=2) - y = yield future + y = await future assert y == 100 @@ -370,7 +368,7 @@ async def test_future_repr(c, s, a, b): @gen_cluster(client=True) -def test_future_tuple_repr(c, s, a, b): +async def test_future_tuple_repr(c, s, a, b): da = pytest.importorskip("dask.array") y = da.arange(10, chunks=(5,)).persist() f = futures_of(y)[0] @@ -380,13 +378,13 @@ def test_future_tuple_repr(c, s, a, b): @gen_cluster(client=True) -def test_Future_exception(c, s, a, b): +async def test_Future_exception(c, s, a, b): x = c.submit(div, 1, 0) - result = yield x.exception() + result = await x.exception() assert isinstance(result, ZeroDivisionError) x = c.submit(div, 1, 1) - result = yield x.exception() + result = await x.exception() assert result is None @@ -399,23 +397,23 @@ def test_Future_exception_sync(c): @gen_cluster(client=True) -def test_Future_release(c, s, a, b): +async def test_Future_release(c, s, a, b): # Released Futures should be removed timely from the Client x = c.submit(div, 1, 1) - yield x + await x x.release() - yield gen.moment + await asyncio.sleep(0) assert not c.futures x = c.submit(slowinc, 1, delay=0.5) x.release() - yield gen.moment + await asyncio.sleep(0) assert not c.futures x = c.submit(div, 1, 0) - yield x.exception() + await x.exception() x.release() - yield gen.moment + await asyncio.sleep(0) assert not c.futures @@ -454,7 +452,7 @@ def test_short_tracebacks(loop, c): @gen_cluster(client=True) -def test_map_naming(c, s, a, b): +async def test_map_naming(c, s, a, b): L1 = c.map(inc, range(5)) L2 = c.map(inc, range(5)) @@ -468,7 +466,7 @@ def test_map_naming(c, s, a, b): @gen_cluster(client=True) -def test_submit_naming(c, s, a, b): +async def test_submit_naming(c, s, a, b): a = c.submit(inc, 1) b = c.submit(inc, 1) @@ -479,33 +477,33 @@ def test_submit_naming(c, s, a, b): @gen_cluster(client=True) -def test_exceptions(c, s, a, b): +async def test_exceptions(c, s, a, b): x = c.submit(div, 1, 2) - result = 
yield x + result = await x assert result == 1 / 2 x = c.submit(div, 1, 0) with pytest.raises(ZeroDivisionError): - result = yield x + await x x = c.submit(div, 10, 2) # continues to operate - result = yield x + result = await x assert result == 10 / 2 @gen_cluster() -def test_gc(s, a, b): - c = yield Client(s.address, asynchronous=True) +async def test_gc(s, a, b): + c = await Client(s.address, asynchronous=True) x = c.submit(inc, 10) - yield x + await x assert s.tasks[x.key].who_has x.__del__() - yield async_wait_for( + await async_wait_for( lambda: x.key not in s.tasks or not s.tasks[x.key].who_has, timeout=0.3 ) - yield c.close() + await c.close() def test_thread(c): @@ -534,27 +532,27 @@ def test_sync_exceptions(c): @gen_cluster(client=True) -def test_gather(c, s, a, b): +async def test_gather(c, s, a, b): x = c.submit(inc, 10) y = c.submit(inc, x) - result = yield c.gather(x) + result = await c.gather(x) assert result == 11 - result = yield c.gather([x]) + result = await c.gather([x]) assert result == [11] - result = yield c.gather({"x": x, "y": [y]}) + result = await c.gather({"x": x, "y": [y]}) assert result == {"x": 11, "y": [12]} @gen_cluster(client=True) -def test_gather_lost(c, s, a, b): - [x] = yield c.scatter([1], workers=a.address) +async def test_gather_lost(c, s, a, b): + [x] = await c.scatter([1], workers=a.address) y = c.submit(inc, 1, workers=b.address) - yield a.close() + await a.close() with pytest.raises(Exception): - res = yield c.gather([x, y]) + await c.gather([x, y]) def test_gather_sync(c): @@ -571,25 +569,25 @@ def test_gather_sync(c): @gen_cluster(client=True) -def test_gather_strict(c, s, a, b): +async def test_gather_strict(c, s, a, b): x = c.submit(div, 2, 1) y = c.submit(div, 1, 0) with pytest.raises(ZeroDivisionError): - yield c.gather([x, y]) + await c.gather([x, y]) - [xx] = yield c.gather([x, y], errors="skip") + [xx] = await c.gather([x, y], errors="skip") assert xx == 2 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_gather_skip(c, s, a): +async def test_gather_skip(c, s, a): x = c.submit(div, 1, 0, priority=10) y = c.submit(slowinc, 1, delay=0.5) with captured_logger(logging.getLogger("distributed.scheduler")) as sched: with captured_logger(logging.getLogger("distributed.client")) as client: - L = yield c.gather([x, y], errors="skip") + L = await c.gather([x, y], errors="skip") assert L == [2] assert not client.getvalue() @@ -597,28 +595,29 @@ def test_gather_skip(c, s, a): @gen_cluster(client=True) -def test_limit_concurrent_gathering(c, s, a, b): +async def test_limit_concurrent_gathering(c, s, a, b): futures = c.map(inc, range(100)) - results = yield futures + await c.gather(futures) assert len(a.outgoing_transfer_log) + len(b.outgoing_transfer_log) < 100 @gen_cluster(client=True, timeout=None) -def test_get(c, s, a, b): +async def test_get(c, s, a, b): future = c.get({"x": (inc, 1)}, "x", sync=False) assert isinstance(future, Future) - result = yield future + result = await future assert result == 2 futures = c.get({"x": (inc, 1)}, ["x"], sync=False) assert isinstance(futures[0], Future) - result = yield futures + result = await c.gather(futures) assert result == [2] - result = yield c.get({}, [], sync=False) + futures = c.get({}, [], sync=False) + result = await c.gather(futures) assert result == [] - result = yield c.get( + result = await c.get( {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}, ("x", 2), sync=False ) assert result == 3 @@ -650,7 +649,7 @@ def test_get_sync_optimize_graph_passes_through(c): 
@gen_cluster(client=True) -def test_gather_errors(c, s, a, b): +async def test_gather_errors(c, s, a, b): def f(a, b): raise TypeError @@ -660,20 +659,20 @@ def g(a, b): future_f = c.submit(f, 1, 2) future_g = c.submit(g, 1, 2) with pytest.raises(TypeError): - yield c.gather(future_f) + await c.gather(future_f) with pytest.raises(AttributeError): - yield c.gather(future_g) + await c.gather(future_g) - yield a.close() + await a.close() @gen_cluster(client=True) -def test_wait(c, s, a, b): +async def test_wait(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, 1) z = c.submit(inc, 2) - done, not_done = yield wait([x, y, z]) + done, not_done = await wait([x, y, z]) assert done == {x, y, z} assert not_done == set() @@ -681,12 +680,12 @@ def test_wait(c, s, a, b): @gen_cluster(client=True) -def test_wait_first_completed(c, s, a, b): +async def test_wait_first_completed(c, s, a, b): x = c.submit(slowinc, 1) y = c.submit(slowinc, 1) z = c.submit(inc, 2) - done, not_done = yield wait([x, y, z], return_when="FIRST_COMPLETED") + done, not_done = await wait([x, y, z], return_when="FIRST_COMPLETED") assert done == {z} assert not_done == {x, y} @@ -696,10 +695,10 @@ def test_wait_first_completed(c, s, a, b): @gen_cluster(client=True, timeout=2) -def test_wait_timeout(c, s, a, b): +async def test_wait_timeout(c, s, a, b): future = c.submit(sleep, 0.3) with pytest.raises(TimeoutError): - yield wait(future, timeout=0.01) + await wait(future, timeout=0.01) def test_wait_sync(c): @@ -728,31 +727,31 @@ def test_wait_informative_error_for_timeouts(c): @gen_cluster(client=True) -def test_garbage_collection(c, s, a, b): +async def test_garbage_collection(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, 1) assert c.refcount[x.key] == 2 x.__del__() - yield gen.moment + await asyncio.sleep(0) assert c.refcount[x.key] == 1 z = c.submit(inc, y) y.__del__() - yield gen.moment + await asyncio.sleep(0) - result = yield z + result = await z assert result == 3 ykey = y.key y.__del__() - yield gen.moment + await asyncio.sleep(0) assert ykey not in c.futures @gen_cluster(client=True) -def test_garbage_collection_with_scatter(c, s, a, b): - [future] = yield c.scatter([1]) +async def test_garbage_collection_with_scatter(c, s, a, b): + [future] = await c.scatter([1]) assert future.key in c.futures assert future.status == "finished" assert s.who_wants[future.key] == {c.id} @@ -760,7 +759,7 @@ def test_garbage_collection_with_scatter(c, s, a, b): key = future.key assert c.refcount[key] == 1 future.__del__() - yield gen.moment + await asyncio.sleep(0) assert c.refcount[key] == 0 start = time() @@ -769,50 +768,50 @@ def test_garbage_collection_with_scatter(c, s, a, b): break else: assert time() < start + 3 - yield gen.sleep(0.1) + await asyncio.sleep(0.1) @gen_cluster(timeout=1000, client=True) -def test_recompute_released_key(c, s, a, b): +async def test_recompute_released_key(c, s, a, b): x = c.submit(inc, 100) - result1 = yield x + result1 = await x xkey = x.key del x import gc gc.collect() - yield gen.moment + await asyncio.sleep(0) assert c.refcount[xkey] == 0 # 1 second batching needs a second action to trigger while xkey in s.tasks and s.tasks[xkey].who_has or xkey in a.data or xkey in b.data: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) x = c.submit(inc, 100) assert x.key in c.futures - result2 = yield x + result2 = await x assert result1 == result2 @pytest.mark.slow @gen_cluster(client=True) -def test_long_tasks_dont_trigger_timeout(c, s, a, b): +async def test_long_tasks_dont_trigger_timeout(c, s, a, b): from 
time import sleep x = c.submit(sleep, 3) - yield x + await x @pytest.mark.skip @gen_cluster(client=True) -def test_missing_data_heals(c, s, a, b): +async def test_missing_data_heals(c, s, a, b): a.validate = False b.validate = False x = c.submit(inc, 1) y = c.submit(inc, x) z = c.submit(inc, y) - yield wait([x, y, z]) + await wait([x, y, z]) # Secretly delete y's key if y.key in a.data: @@ -821,36 +820,36 @@ def test_missing_data_heals(c, s, a, b): if y.key in b.data: del b.data[y.key] b.release_key(y.key) - yield gen.moment + await asyncio.sleep(0) w = c.submit(add, y, z) - result = yield w + result = await w assert result == 3 + 4 @pytest.mark.skip @gen_cluster(client=True) -def test_gather_robust_to_missing_data(c, s, a, b): +async def test_gather_robust_to_missing_data(c, s, a, b): a.validate = False b.validate = False x, y, z = c.map(inc, range(3)) - yield wait([x, y, z]) # everything computed + await wait([x, y, z]) # everything computed for f in [x, y]: for w in [a, b]: if f.key in w.data: del w.data[f.key] - yield gen.moment + await asyncio.sleep(0) w.release_key(f.key) - xx, yy, zz = yield c.gather([x, y, z]) + xx, yy, zz = await c.gather([x, y, z]) assert (xx, yy, zz) == (1, 2, 3) @pytest.mark.skip @gen_cluster(client=True) -def test_gather_robust_to_nested_missing_data(c, s, a, b): +async def test_gather_robust_to_nested_missing_data(c, s, a, b): a.validate = False b.validate = False w = c.submit(inc, 1) @@ -858,22 +857,22 @@ def test_gather_robust_to_nested_missing_data(c, s, a, b): y = c.submit(inc, x) z = c.submit(inc, y) - yield wait([z]) + await wait([z]) for worker in [a, b]: for datum in [y, z]: if datum.key in worker.data: del worker.data[datum.key] - yield gen.moment + await asyncio.sleep(0) worker.release_key(datum.key) - result = yield c.gather([z]) + result = await c.gather([z]) assert result == [inc(inc(inc(inc(1))))] @gen_cluster(client=True) -def test_tokenize_on_futures(c, s, a, b): +async def test_tokenize_on_futures(c, s, a, b): x = c.submit(inc, 1) y = c.submit(inc, 1) tok = tokenize(x) @@ -889,10 +888,10 @@ def test_tokenize_on_futures(c, s, a, b): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) -def test_restrictions_submit(c, s, a, b): +async def test_restrictions_submit(c, s, a, b): x = c.submit(inc, 1, workers={a.ip}) y = c.submit(inc, x, workers={b.ip}) - yield wait([x, y]) + await wait([x, y]) assert s.host_restrictions[x.key] == {a.ip} assert x.key in a.data @@ -902,10 +901,10 @@ def test_restrictions_submit(c, s, a, b): @gen_cluster(client=True) -def test_restrictions_ip_port(c, s, a, b): +async def test_restrictions_ip_port(c, s, a, b): x = c.submit(inc, 1, workers={a.address}) y = c.submit(inc, x, workers={b.address}) - yield wait([x, y]) + await wait([x, y]) assert s.worker_restrictions[x.key] == {a.address} assert x.key in a.data @@ -918,9 +917,9 @@ def test_restrictions_ip_port(c, s, a, b): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) -def test_restrictions_map(c, s, a, b): +async def test_restrictions_map(c, s, a, b): L = c.map(inc, range(5), workers={a.ip}) - yield wait(L) + await wait(L) assert set(a.data) == {x.key for x in L} assert not b.data @@ -928,7 +927,7 @@ def test_restrictions_map(c, s, a, b): assert s.host_restrictions[x.key] == {a.ip} L = c.map(inc, [10, 11, 12], workers=[{a.ip}, {a.ip, b.ip}, {b.ip}]) - yield wait(L) + await wait(L) 
assert s.host_restrictions[L[0].key] == {a.ip} assert s.host_restrictions[L[1].key] == {a.ip, b.ip} @@ -942,22 +941,22 @@ def test_restrictions_map(c, s, a, b): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) -def test_restrictions_get(c, s, a, b): +async def test_restrictions_get(c, s, a, b): dsk = {"x": 1, "y": (inc, "x"), "z": (inc, "y")} restrictions = {"y": {a.ip}, "z": {b.ip}} futures = c.get(dsk, ["y", "z"], restrictions, sync=False) - result = yield futures + result = await c.gather(futures) assert result == [2, 3] assert "y" in a.data assert "z" in b.data @gen_cluster(client=True) -def dont_test_bad_restrictions_raise_exception(c, s, a, b): +async def dont_test_bad_restrictions_raise_exception(c, s, a, b): z = c.submit(inc, 2, workers={"bad-address"}) try: - yield z + await z assert False except ValueError as e: assert "bad-address" in str(e) @@ -965,133 +964,133 @@ def dont_test_bad_restrictions_raise_exception(c, s, a, b): @gen_cluster(client=True, timeout=None) -def test_remove_worker(c, s, a, b): +async def test_remove_worker(c, s, a, b): L = c.map(inc, range(20)) - yield wait(L) + await wait(L) - yield b.close() + await b.close() assert b.address not in s.workers - result = yield c.gather(L) + result = await c.gather(L) assert result == list(map(inc, range(20))) @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) -def test_errors_dont_block(c, s, w): +async def test_errors_dont_block(c, s, w): L = [c.submit(inc, 1), c.submit(throws, 1), c.submit(inc, 2), c.submit(throws, 2)] start = time() while not (L[0].status == L[2].status == "finished"): assert time() < start + 5 - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - result = yield c.gather([L[0], L[2]]) + result = await c.gather([L[0], L[2]]) assert result == [2, 3] @gen_cluster(client=True) -def test_submit_quotes(c, s, a, b): +async def test_submit_quotes(c, s, a, b): def assert_list(x, z=[]): return isinstance(x, list) and isinstance(z, list) x = c.submit(assert_list, [1, 2, 3]) - result = yield x + result = await x assert result x = c.submit(assert_list, [1, 2, 3], z=[4, 5, 6]) - result = yield x + result = await x assert result x = c.submit(inc, 1) y = c.submit(inc, 2) z = c.submit(assert_list, [x, y]) - result = yield z + result = await z assert result @gen_cluster(client=True) -def test_map_quotes(c, s, a, b): +async def test_map_quotes(c, s, a, b): def assert_list(x, z=[]): return isinstance(x, list) and isinstance(z, list) L = c.map(assert_list, [[1, 2, 3], [4]]) - result = yield c.gather(L) + result = await c.gather(L) assert all(result) L = c.map(assert_list, [[1, 2, 3], [4]], z=[10]) - result = yield c.gather(L) + result = await c.gather(L) assert all(result) L = c.map(assert_list, [[1, 2, 3], [4]], [[]] * 3) - result = yield c.gather(L) + result = await c.gather(L) assert all(result) @gen_cluster() -def test_two_consecutive_clients_share_results(s, a, b): - c = yield Client(s.address, asynchronous=True) +async def test_two_consecutive_clients_share_results(s, a, b): + c = await Client(s.address, asynchronous=True) x = c.submit(random.randint, 0, 1000, pure=True) - xx = yield x + xx = await x - f = yield Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) y = f.submit(random.randint, 0, 1000, pure=True) - yy = yield y + yy = await y assert xx == yy - yield c.close() - yield f.close() + await c.close() + await f.close() @gen_cluster(client=True) -def 
test_submit_then_get_with_Future(c, s, a, b): +async def test_submit_then_get_with_Future(c, s, a, b): x = c.submit(slowinc, 1) dsk = {"y": (inc, x)} - result = yield c.get(dsk, "y", sync=False) + result = await c.get(dsk, "y", sync=False) assert result == 3 @gen_cluster(client=True) -def test_aliases(c, s, a, b): +async def test_aliases(c, s, a, b): x = c.submit(inc, 1) dsk = {"y": x} - result = yield c.get(dsk, "y", sync=False) + result = await c.get(dsk, "y", sync=False) assert result == 2 @gen_cluster(client=True) -def test_aliases_2(c, s, a, b): +async def test_aliases_2(c, s, a, b): dsk_keys = [ ({"x": (inc, 1), "y": "x", "z": "x", "w": (add, "y", "z")}, ["y", "w"]), ({"x": "y", "y": 1}, ["x"]), ({"x": 1, "y": "x", "z": "y", "w": (inc, "z")}, ["w"]), ] for dsk, keys in dsk_keys: - result = yield c.get(dsk, keys, sync=False) + result = await c.gather(c.get(dsk, keys, sync=False)) assert list(result) == list(dask.get(dsk, keys)) - yield gen.moment + await asyncio.sleep(0) @gen_cluster(client=True) -def test__scatter(c, s, a, b): - d = yield c.scatter({"y": 20}) +async def test_scatter(c, s, a, b): + d = await c.scatter({"y": 20}) assert isinstance(d["y"], Future) assert a.data.get("y") == 20 or b.data.get("y") == 20 y_who_has = s.get_who_has(keys=["y"])["y"] assert a.address in y_who_has or b.address in y_who_has assert s.get_nbytes(summary=False) == {"y": sizeof(20)} - yy = yield c.gather([d["y"]]) + yy = await c.gather([d["y"]]) assert yy == [20] - [x] = yield c.scatter([10]) + [x] = await c.scatter([10]) assert isinstance(x, Future) assert a.data.get(x.key) == 10 or b.data.get(x.key) == 10 - xx = yield c.gather([x]) + xx = await c.gather([x]) x_who_has = s.get_who_has(keys=[x.key])[x.key] assert s.tasks[x.key].who_has assert ( @@ -1102,49 +1101,49 @@ def test__scatter(c, s, a, b): assert xx == [10] z = c.submit(add, x, d["y"]) # submit works on Future - result = yield z + result = await z assert result == 10 + 20 - result = yield c.gather([z, x]) + result = await c.gather([z, x]) assert result == [30, 10] @gen_cluster(client=True) -def test__scatter_types(c, s, a, b): - d = yield c.scatter({"x": 1}) +async def test_scatter_types(c, s, a, b): + d = await c.scatter({"x": 1}) assert isinstance(d, dict) assert list(d) == ["x"] for seq in [[1], (1,), {1}, frozenset([1])]: - L = yield c.scatter(seq) + L = await c.scatter(seq) assert isinstance(L, type(seq)) assert len(L) == 1 s.validate_state() - seq = yield c.scatter(range(5)) + seq = await c.scatter(range(5)) assert isinstance(seq, list) assert len(seq) == 5 s.validate_state() @gen_cluster(client=True) -def test__scatter_non_list(c, s, a, b): - x = yield c.scatter(1) +async def test_scatter_non_list(c, s, a, b): + x = await c.scatter(1) assert isinstance(x, Future) - result = yield x + result = await x assert result == 1 @gen_cluster(client=True) -def test_scatter_hash(c, s, a, b): - [a] = yield c.scatter([1]) - [b] = yield c.scatter([1]) +async def test_scatter_hash(c, s, a, b): + [a] = await c.scatter([1]) + [b] = await c.scatter([1]) assert a.key == b.key s.validate_state() @gen_cluster(client=True) -def test_scatter_tokenize_local(c, s, a, b): +async def test_scatter_tokenize_local(c, s, a, b): from dask.base import normalize_token class MyObj: @@ -1159,46 +1158,46 @@ def f(x): obj = MyObj() - future = yield c.scatter(obj) + future = await c.scatter(obj) assert L and L[0] is obj @gen_cluster(client=True) -def test_scatter_singletons(c, s, a, b): +async def test_scatter_singletons(c, s, a, b): np = pytest.importorskip("numpy") pd = 
pytest.importorskip("pandas") for x in [1, np.ones(5), pd.DataFrame({"x": [1, 2, 3]})]: - future = yield c.scatter(x) - result = yield future + future = await c.scatter(x) + result = await future assert str(result) == str(x) @gen_cluster(client=True) -def test_scatter_typename(c, s, a, b): - future = yield c.scatter(123) +async def test_scatter_typename(c, s, a, b): + future = await c.scatter(123) assert future.key.startswith("int") @gen_cluster(client=True) -def test_scatter_hash(c, s, a, b): - x = yield c.scatter(123) - y = yield c.scatter(123) +async def test_scatter_hash(c, s, a, b): + x = await c.scatter(123) + y = await c.scatter(123) assert x.key == y.key - z = yield c.scatter(123, hash=False) + z = await c.scatter(123, hash=False) assert z.key != y.key @gen_cluster(client=True) -def test_get_releases_data(c, s, a, b): - [x] = yield c.get({"x": (inc, 1)}, ["x"], sync=False) +async def test_get_releases_data(c, s, a, b): + await c.gather(c.get({"x": (inc, 1)}, ["x"], sync=False)) import gc gc.collect() start = time() while c.refcount["x"]: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 @@ -1229,26 +1228,26 @@ def test_global_clients(loop): @gen_cluster(client=True) -def test_exception_on_exception(c, s, a, b): +async def test_exception_on_exception(c, s, a, b): x = c.submit(lambda: 1 / 0) y = c.submit(inc, x) with pytest.raises(ZeroDivisionError): - yield y + await y z = c.submit(inc, y) with pytest.raises(ZeroDivisionError): - yield z + await z @gen_cluster(client=True) -def test_get_nbytes(c, s, a, b): - [x] = yield c.scatter([1]) +async def test_get_nbytes(c, s, a, b): + [x] = await c.scatter([1]) assert s.get_nbytes(summary=False) == {x.key: sizeof(1)} y = c.submit(inc, x) - yield y + await y assert s.get_nbytes(summary=False) == {x.key: sizeof(1), y.key: sizeof(2)} @@ -1257,24 +1256,24 @@ def test_get_nbytes(c, s, a, b): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) -def test_nbytes_determines_worker(c, s, a, b): +async def test_nbytes_determines_worker(c, s, a, b): x = c.submit(identity, 1, workers=[a.ip]) y = c.submit(identity, tuple(range(100)), workers=[b.ip]) - yield c.gather([x, y]) + await c.gather([x, y]) z = c.submit(lambda x, y: None, x, y) - yield z + await z assert s.tasks[z.key].who_has == {s.workers[b.address]} @gen_cluster(client=True) -def test_if_intermediates_clear_on_error(c, s, a, b): +async def test_if_intermediates_clear_on_error(c, s, a, b): x = delayed(div, pure=True)(1, 0) y = delayed(div, pure=True)(1, 2) z = delayed(add, pure=True)(x, y) f = c.compute(z) with pytest.raises(ZeroDivisionError): - yield f + await f s.validate_state() assert not any(ts.who_has for ts in s.tasks.values()) @@ -1282,7 +1281,7 @@ def test_if_intermediates_clear_on_error(c, s, a, b): @gen_cluster( client=True, config={"distributed.scheduler.default-task-durations": {"f": "1ms"}} ) -def test_pragmatic_move_small_data_to_large_data(c, s, a, b): +async def test_pragmatic_move_small_data_to_large_data(c, s, a, b): np = pytest.importorskip("numpy") lists = c.map(np.ones, [10000] * 10, pure=False) sums = c.map(np.sum, lists) @@ -1293,9 +1292,8 @@ def f(x, y): results = c.map(f, lists, [total] * 10) - yield wait([total]) - - yield wait(results) + await wait([total]) + await wait(results) assert ( sum( @@ -1307,20 +1305,20 @@ def f(x, y): @gen_cluster(client=True) -def test_get_with_non_list_key(c, s, a, b): +async def test_get_with_non_list_key(c, s, a, 
b): dsk = {("x", 0): (inc, 1), 5: (inc, 2)} - x = yield c.get(dsk, ("x", 0), sync=False) - y = yield c.get(dsk, 5, sync=False) + x = await c.get(dsk, ("x", 0), sync=False) + y = await c.get(dsk, 5, sync=False) assert x == 2 assert y == 3 @gen_cluster(client=True) -def test_get_with_error(c, s, a, b): +async def test_get_with_error(c, s, a, b): dsk = {"x": (div, 1, 0), "y": (inc, "x")} with pytest.raises(ZeroDivisionError): - yield c.get(dsk, "y", sync=False) + await c.get(dsk, "y", sync=False) def test_get_with_error_sync(c): @@ -1330,12 +1328,12 @@ def test_get_with_error_sync(c): @gen_cluster(client=True) -def test_directed_scatter(c, s, a, b): - yield c.scatter([1, 2, 3], workers=[a.address]) +async def test_directed_scatter(c, s, a, b): + await c.scatter([1, 2, 3], workers=[a.address]) assert len(a.data) == 3 assert not b.data - yield c.scatter([4, 5], workers=[b.name]) + await c.scatter([4, 5], workers=[b.name]) assert len(b.data) == 2 @@ -1347,56 +1345,56 @@ def test_directed_scatter_sync(c, s, a, b, loop): @gen_cluster(client=True) -def test_scatter_direct(c, s, a, b): - future = yield c.scatter(123, direct=True) +async def test_scatter_direct(c, s, a, b): + future = await c.scatter(123, direct=True) assert future.key in a.data or future.key in b.data assert s.tasks[future.key].who_has assert future.status == "finished" - result = yield future + result = await future assert result == 123 assert not s.counters["op"].components[0]["scatter"] - result = yield future + result = await future assert not s.counters["op"].components[0]["gather"] - result = yield c.gather(future) + result = await c.gather(future) assert not s.counters["op"].components[0]["gather"] @gen_cluster(client=True) -def test_scatter_direct_numpy(c, s, a, b): +async def test_scatter_direct_numpy(c, s, a, b): np = pytest.importorskip("numpy") x = np.ones(5) - future = yield c.scatter(x, direct=True) - result = yield future + future = await c.scatter(x, direct=True) + result = await future assert np.allclose(x, result) assert not s.counters["op"].components[0]["scatter"] @gen_cluster(client=True) -def test_scatter_direct_broadcast(c, s, a, b): - future2 = yield c.scatter(456, direct=True, broadcast=True) +async def test_scatter_direct_broadcast(c, s, a, b): + future2 = await c.scatter(456, direct=True, broadcast=True) assert future2.key in a.data assert future2.key in b.data assert s.tasks[future2.key].who_has == {s.workers[a.address], s.workers[b.address]} - result = yield future2 + result = await future2 assert result == 456 assert not s.counters["op"].components[0]["scatter"] @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_scatter_direct_balanced(c, s, *workers): - futures = yield c.scatter([1, 2, 3], direct=True) +async def test_scatter_direct_balanced(c, s, *workers): + futures = await c.scatter([1, 2, 3], direct=True) assert sorted([len(w.data) for w in workers]) == [0, 1, 1, 1] @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_scatter_direct_broadcast_target(c, s, *workers): - futures = yield c.scatter([123, 456], direct=True, workers=workers[0].address) +async def test_scatter_direct_broadcast_target(c, s, *workers): + futures = await c.scatter([123, 456], direct=True, workers=workers[0].address) assert futures[0].key in workers[0].data assert futures[1].key in workers[0].data - futures = yield c.scatter( + futures = await c.scatter( [123, 456], direct=True, broadcast=True, @@ -1410,16 +1408,16 @@ def test_scatter_direct_broadcast_target(c, s, *workers): 
@gen_cluster(client=True, nthreads=[]) -def test_scatter_direct_empty(c, s): +async def test_scatter_direct_empty(c, s): with pytest.raises((ValueError, TimeoutError)): - yield c.scatter(123, direct=True, timeout=0.1) + await c.scatter(123, direct=True, timeout=0.1) @gen_cluster(client=True, timeout=None, nthreads=[("127.0.0.1", 1)] * 5) -def test_scatter_direct_spread_evenly(c, s, *workers): +async def test_scatter_direct_spread_evenly(c, s, *workers): futures = [] for i in range(10): - future = yield c.scatter(i, direct=True) + future = await c.scatter(i, direct=True) futures.append(future) assert all(w.data for w in workers) @@ -1436,34 +1434,32 @@ def test_scatter_gather_sync(c, direct, broadcast): @gen_cluster(client=True) -def test_gather_direct(c, s, a, b): - futures = yield c.scatter([1, 2, 3]) +async def test_gather_direct(c, s, a, b): + futures = await c.scatter([1, 2, 3]) - data = yield c.gather(futures, direct=True) + data = await c.gather(futures, direct=True) assert data == [1, 2, 3] @gen_cluster(client=True) -def test_many_submits_spread_evenly(c, s, a, b): +async def test_many_submits_spread_evenly(c, s, a, b): L = [c.submit(inc, i) for i in range(10)] - yield wait(L) + await wait(L) assert a.data and b.data @gen_cluster(client=True) -def test_traceback(c, s, a, b): +async def test_traceback(c, s, a, b): x = c.submit(div, 1, 0) - tb = yield x.traceback() - - if sys.version_info[0] >= 3: - assert any("x / y" in line for line in pluck(3, traceback.extract_tb(tb))) + tb = await x.traceback() + assert any("x / y" in line for line in pluck(3, traceback.extract_tb(tb))) @gen_cluster(client=True) -def test_get_traceback(c, s, a, b): +async def test_get_traceback(c, s, a, b): try: - yield c.get({"x": (div, 1, 0)}, "x", sync=False) + await c.get({"x": (div, 1, 0)}, "x", sync=False) except ZeroDivisionError: exc_type, exc_value, exc_traceback = sys.exc_info() L = traceback.format_tb(exc_traceback) @@ -1471,10 +1467,10 @@ def test_get_traceback(c, s, a, b): @gen_cluster(client=True) -def test_gather_traceback(c, s, a, b): +async def test_gather_traceback(c, s, a, b): x = c.submit(div, 1, 0) try: - yield c.gather(x) + await c.gather(x) except ZeroDivisionError: exc_type, exc_value, exc_traceback = sys.exc_info() L = traceback.format_tb(exc_traceback) @@ -1484,12 +1480,11 @@ def test_gather_traceback(c, s, a, b): def test_traceback_sync(c): x = c.submit(div, 1, 0) tb = x.traceback() - if sys.version_info[0] >= 3: - assert any( - "x / y" in line - for line in concat(traceback.extract_tb(tb)) - if isinstance(line, str) - ) + assert any( + "x / y" in line + for line in concat(traceback.extract_tb(tb)) + if isinstance(line, str) + ) y = c.submit(inc, x) tb2 = y.traceback() @@ -1504,7 +1499,7 @@ def test_traceback_sync(c): @gen_cluster(client=True) -def test_upload_file(c, s, a, b): +async def test_upload_file(c, s, a, b): def g(): import myfile @@ -1513,21 +1508,21 @@ def g(): with save_sys_modules(): for value in [123, 456]: with tmp_text("myfile.py", "def f():\n return {}".format(value)) as fn: - yield c.upload_file(fn) + await c.upload_file(fn) x = c.submit(g, pure=False) - result = yield x + result = await x assert result == value @gen_cluster(client=True) -def test_upload_file_no_extension(c, s, a, b): +async def test_upload_file_no_extension(c, s, a, b): with tmp_text("myfile", "") as fn: - yield c.upload_file(fn) + await c.upload_file(fn) @gen_cluster(client=True) -def test_upload_file_zip(c, s, a, b): +async def test_upload_file_zip(c, s, a, b): def g(): import myfile @@ -1541,10 
+1536,10 @@ def g(): ) as fn_my_file: with zipfile.ZipFile("myfile.zip", "w") as z: z.write(fn_my_file, arcname=os.path.basename(fn_my_file)) - yield c.upload_file("myfile.zip") + await c.upload_file("myfile.zip") x = c.submit(g, pure=False) - result = yield x + result = await x assert result == value finally: if os.path.exists("myfile.zip"): @@ -1552,7 +1547,7 @@ def g(): @gen_cluster(client=True) -def test_upload_file_egg(c, s, a, b): +async def test_upload_file_egg(c, s, a, b): def g(): import package_1, package_2 @@ -1601,22 +1596,22 @@ def g(): ][0] egg_path = os.path.join(egg_root, egg_name) - yield c.upload_file(egg_path) + await c.upload_file(egg_path) os.remove(egg_path) x = c.submit(g, pure=False) - result = yield x + result = await x assert result == (value, value) @gen_cluster(client=True) -def test_upload_large_file(c, s, a, b): +async def test_upload_large_file(c, s, a, b): assert a.local_directory assert b.local_directory with tmp_text("myfile", "abc") as fn: with tmp_text("myfile2", "def") as fn2: - yield c._upload_large_file(fn, remote_filename="x") - yield c._upload_large_file(fn2) + await c._upload_large_file(fn, remote_filename="x") + await c._upload_large_file(fn2) for w in [a, b]: assert os.path.exists(os.path.join(w.local_directory, "x")) @@ -1640,10 +1635,10 @@ def g(): @gen_cluster(client=True) -def test_upload_file_exception(c, s, a, b): +async def test_upload_file_exception(c, s, a, b): with tmp_text("myfile.py", "syntax-error!") as fn: with pytest.raises(SyntaxError): - yield c.upload_file(fn) + await c.upload_file(fn) def test_upload_file_exception_sync(c): @@ -1654,29 +1649,29 @@ def test_upload_file_exception_sync(c): @pytest.mark.skip @gen_cluster() -def test_multiple_clients(s, a, b): - a = yield Client(s.address, asynchronous=True) - b = yield Client(s.address, asynchronous=True) +async def test_multiple_clients(s, a, b): + a = await Client(s.address, asynchronous=True) + b = await Client(s.address, asynchronous=True) x = a.submit(inc, 1) y = b.submit(inc, 2) assert x.client is a assert y.client is b - xx = yield x - yy = yield y + xx = await x + yy = await y assert xx == 2 assert yy == 3 z = a.submit(add, x, y) assert z.client is a - zz = yield z + zz = await z assert zz == 5 - yield a.close() - yield b.close() + await a.close() + await b.close() @gen_cluster(client=True) -def test_async_compute(c, s, a, b): +async def test_async_compute(c, s, a, b): from dask.delayed import delayed x = delayed(1) @@ -1688,7 +1683,7 @@ def test_async_compute(c, s, a, b): assert isinstance(zz, Future) assert aa == 3 - result = yield c.gather([yy, zz]) + result = await c.gather([yy, zz]) assert result == [2, 0] assert isinstance(c.compute(y), Future) @@ -1696,8 +1691,8 @@ def test_async_compute(c, s, a, b): @gen_cluster(client=True) -def test_async_compute_with_scatter(c, s, a, b): - d = yield c.scatter({("x", 1): 1, ("y", 1): 2}) +async def test_async_compute_with_scatter(c, s, a, b): + d = await c.scatter({("x", 1): 1, ("y", 1): 2}) x, y = d[("x", 1)], d[("y", 1)] from dask.delayed import delayed @@ -1705,7 +1700,7 @@ def test_async_compute_with_scatter(c, s, a, b): z = delayed(add)(delayed(inc)(x), delayed(inc)(y)) zz = c.compute(z) - [result] = yield c.gather([zz]) + [result] = await c.gather([zz]) assert result == 2 + 3 @@ -1719,22 +1714,22 @@ def test_sync_compute(c): @gen_cluster(client=True) -def test_remote_scatter_gather(c, s, a, b): - x, y, z = yield c.scatter([1, 2, 3]) +async def test_remote_scatter_gather(c, s, a, b): + x, y, z = await c.scatter([1, 2, 3]) 
assert x.key in a.data or x.key in b.data assert y.key in a.data or y.key in b.data assert z.key in a.data or z.key in b.data - xx, yy, zz = yield c.gather([x, y, z]) + xx, yy, zz = await c.gather([x, y, z]) assert (xx, yy, zz) == (1, 2, 3) @gen_cluster(timeout=1000, client=True) -def test_remote_submit_on_Future(c, s, a, b): +async def test_remote_submit_on_Future(c, s, a, b): x = c.submit(lambda x: x + 1, 1) y = c.submit(lambda x: x + 1, x) - result = yield y + result = await y assert result == 3 @@ -1748,22 +1743,22 @@ def test_start_is_idempotent(c): @gen_cluster(client=True) -def test_client_with_scheduler(c, s, a, b): +async def test_client_with_scheduler(c, s, a, b): assert s.nthreads == {a.address: a.nthreads, b.address: b.nthreads} x = c.submit(inc, 1) y = c.submit(inc, 2) z = c.submit(add, x, y) - result = yield x + result = await x assert result == 1 + 1 - result = yield z + result = await z assert result == 1 + 1 + 1 + 2 - A, B, C = yield c.scatter([1, 2, 3]) - AA, BB, xx = yield c.gather([A, B, x]) + A, B, C = await c.scatter([1, 2, 3]) + AA, BB, xx = await c.gather([A, B, x]) assert (AA, BB, xx) == (1, 2, 2) - result = yield c.get({"x": (inc, 1), "y": (add, "x", 10)}, "y", sync=False) + result = await c.get({"x": (inc, 1), "y": (add, "x", 10)}, "y", sync=False) assert result == 12 @@ -1771,33 +1766,33 @@ def test_client_with_scheduler(c, s, a, b): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) -def test_allow_restrictions(c, s, a, b): +async def test_allow_restrictions(c, s, a, b): aws = s.workers[a.address] bws = s.workers[a.address] x = c.submit(inc, 1, workers=a.ip) - yield x + await x assert s.tasks[x.key].who_has == {aws} assert not s.loose_restrictions x = c.submit(inc, 2, workers=a.ip, allow_other_workers=True) - yield x + await x assert s.tasks[x.key].who_has == {aws} assert x.key in s.loose_restrictions L = c.map(inc, range(3, 13), workers=a.ip, allow_other_workers=True) - yield wait(L) + await wait(L) assert all(s.tasks[f.key].who_has == {aws} for f in L) assert {f.key for f in L}.issubset(s.loose_restrictions) x = c.submit(inc, 15, workers="127.0.0.3", allow_other_workers=True) - yield x + await x assert s.tasks[x.key].who_has assert x.key in s.loose_restrictions L = c.map(inc, range(15, 25), workers="127.0.0.3", allow_other_workers=True) - yield wait(L) + await wait(L) assert all(s.tasks[f.key].who_has for f in L) assert {f.key for f in L}.issubset(s.loose_restrictions) @@ -1828,18 +1823,18 @@ def test_bad_address(): @gen_cluster(client=True) -def test_long_error(c, s, a, b): +async def test_long_error(c, s, a, b): def bad(x): raise ValueError("a" * 100000) x = c.submit(bad, 10) try: - yield x + await x except ValueError as e: assert len(str(e)) < 100000 - tb = yield x.traceback() + tb = await x.traceback() assert all( len(line) < 100000 for line in concat(traceback.extract_tb(tb)) @@ -1848,18 +1843,18 @@ def bad(x): @gen_cluster(client=True) -def test_map_on_futures_with_kwargs(c, s, a, b): +async def test_map_on_futures_with_kwargs(c, s, a, b): def f(x, y=10): return x + y futures = c.map(inc, range(10)) futures2 = c.map(f, futures, y=20) - results = yield c.gather(futures2) + results = await c.gather(futures2) assert results == [i + 1 + 20 for i in range(10)] future = c.submit(inc, 100) future2 = c.submit(f, future, y=200) - result = yield future2 + result = await future2 assert result == 100 + 1 + 200 @@ -1883,19 +1878,19 @@ def __setstate__(self, state): 
@gen_cluster(client=True) -def test_badly_serialized_input(c, s, a, b): +async def test_badly_serialized_input(c, s, a, b): o = BadlySerializedObject() future = c.submit(inc, o) futures = c.map(inc, range(10)) - L = yield c.gather(futures) + L = await c.gather(futures) assert list(L) == list(map(inc, range(10))) assert future.status == "error" @pytest.mark.skipif("True", reason="") -def test_badly_serialized_input_stderr(capsys, c): +async def test_badly_serialized_input_stderr(capsys, c): o = BadlySerializedObject() future = c.submit(inc, o) @@ -1928,37 +1923,37 @@ def test_repr(loop): @gen_cluster(client=True) -def test_repr_async(c, s, a, b): +async def test_repr_async(c, s, a, b): c._repr_html_() @gen_cluster(client=True, worker_kwargs={"memory_limit": None}) -def test_repr_no_memory_limit(c, s, a, b): +async def test_repr_no_memory_limit(c, s, a, b): c._repr_html_() @gen_test() -def test_repr_localcluster(): - cluster = yield LocalCluster( +async def test_repr_localcluster(): + cluster = await LocalCluster( processes=False, dashboard_address=None, asynchronous=True ) - client = yield Client(cluster, asynchronous=True) + client = await Client(cluster, asynchronous=True) try: text = client._repr_html_() assert cluster.scheduler.address in text assert is_valid_xml(client._repr_html_()) finally: - yield client.close() - yield cluster.close() + await client.close() + await cluster.close() @gen_cluster(client=True) -def test_forget_simple(c, s, a, b): +async def test_forget_simple(c, s, a, b): x = c.submit(inc, 1, retries=2) y = c.submit(inc, 2) z = c.submit(add, x, y, workers=[a.ip], allow_other_workers=True) - yield wait([x, y, z]) + await wait([x, y, z]) assert not s.waiting_data.get(x.key) assert not s.waiting_data.get(y.key) @@ -1977,14 +1972,14 @@ def test_forget_simple(c, s, a, b): @gen_cluster(client=True) -def test_forget_complex(e, s, A, B): - a, b, c, d = yield e.scatter(list(range(4))) +async def test_forget_complex(e, s, A, B): + a, b, c, d = await e.scatter(list(range(4))) ab = e.submit(add, a, b) cd = e.submit(add, c, d) ac = e.submit(add, a, c) acab = e.submit(add, ac, ab) - yield wait([a, b, c, d, ab, ac, cd, acab]) + await wait([a, b, c, d, ab, ac, cd, acab]) assert set(s.tasks) == {f.key for f in [ab, ac, cd, acab, a, b, c, d]} @@ -2000,7 +1995,7 @@ def test_forget_complex(e, s, A, B): start = time() while b.key in A.data or b.key in B.data: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 10 s.client_releases_keys(keys=[ac.key], client=e.id) @@ -2008,7 +2003,7 @@ def test_forget_complex(e, s, A, B): @gen_cluster(client=True) -def test_forget_in_flight(e, s, A, B): +async def test_forget_in_flight(e, s, A, B): delayed2 = partial(delayed, pure=True) a, b, c, d = [delayed2(slowinc)(i) for i in range(4)] ab = delayed2(slowadd)(a, b, dask_key_name="ab") @@ -2020,7 +2015,7 @@ def test_forget_in_flight(e, s, A, B): s.validate_state() for i in range(5): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) s.validate_state() s.client_releases_keys(keys=[y.key], client=e.id) @@ -2031,11 +2026,11 @@ def test_forget_in_flight(e, s, A, B): @gen_cluster(client=True) -def test_forget_errors(c, s, a, b): +async def test_forget_errors(c, s, a, b): x = c.submit(div, 1, 0) y = c.submit(inc, x) z = c.submit(inc, y) - yield wait([y]) + await wait([y]) assert x.key in s.exceptions assert x.key in s.exceptions_blame @@ -2074,21 +2069,21 @@ def test_repr_sync(c): @gen_cluster(client=True) -def test_waiting_data(c, s, a, b): +async def test_waiting_data(c, s, a, b): x = 
c.submit(inc, 1) y = c.submit(inc, 2) z = c.submit(add, x, y, workers=[a.ip], allow_other_workers=True) - yield wait([x, y, z]) + await wait([x, y, z]) assert not s.waiting_data.get(x.key) assert not s.waiting_data.get(y.key) @gen_cluster() -def test_multi_client(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_multi_client(s, a, b): + c = await Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) assert set(s.client_comms) == {c.id, f.id} @@ -2098,7 +2093,7 @@ def test_multi_client(s, a, b): assert y.key == y2.key - yield wait([x, y]) + await wait([x, y]) assert s.wants_what == { c.id: {x.key, y.key}, @@ -2107,22 +2102,22 @@ def test_multi_client(s, a, b): } assert s.who_wants == {x.key: {c.id}, y.key: {c.id, f.id}} - yield c.close() + await c.close() start = time() while c.id in s.wants_what: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 assert c.id not in s.wants_what assert c.id not in s.who_wants[y.key] assert x.key not in s.who_wants - yield f.close() + await f.close() start = time() while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2, s.tasks @@ -2135,29 +2130,29 @@ def long_running_client_connection(address): @gen_cluster() -def test_cleanup_after_broken_client_connection(s, a, b): +async def test_cleanup_after_broken_client_connection(s, a, b): proc = mp_context.Process(target=long_running_client_connection, args=(s.address,)) proc.daemon = True proc.start() start = time() while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 proc.terminate() start = time() while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @gen_cluster() -def test_multi_garbage_collection(s, a, b): - c = yield Client(s.address, asynchronous=True) +async def test_multi_garbage_collection(s, a, b): + c = await Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) x = c.submit(inc, 1) y = f.submit(inc, 2) @@ -2165,12 +2160,12 @@ def test_multi_garbage_collection(s, a, b): assert y.key == y2.key - yield wait([x, y]) + await wait([x, y]) x.__del__() start = time() while x.key in a.data or x.key in b.data: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 assert s.wants_what == {c.id: {y.key}, f.id: {y.key}, "fire-and-forget": set()} @@ -2179,10 +2174,10 @@ def test_multi_garbage_collection(s, a, b): y.__del__() start = time() while x.key in s.wants_what[f.id]: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert y.key in a.data or y.key in b.data assert s.wants_what == {c.id: {y.key}, f.id: set(), "fire-and-forget": set()} assert s.who_wants == {y.key: {c.id}} @@ -2190,32 +2185,32 @@ def test_multi_garbage_collection(s, a, b): y2.__del__() start = time() while y.key in a.data or y.key in b.data: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 assert not any(v for v in s.wants_what.values()) assert not s.who_wants - yield c.close() - yield f.close() + await c.close() + await f.close() @gen_cluster(client=True) -def test__broadcast(c, s, a, b): - x, y = yield c.scatter([1, 2], broadcast=True) +async def test__broadcast(c, s, a, b): + x, y = await c.scatter([1, 2], broadcast=True) assert a.data == b.data == {x.key: 1, y.key: 2} 
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test__broadcast_integer(c, s, *workers): - x, y = yield c.scatter([1, 2], broadcast=2) +async def test__broadcast_integer(c, s, *workers): + x, y = await c.scatter([1, 2], broadcast=2) assert len(s.tasks[x.key].who_has) == 2 assert len(s.tasks[y.key].who_has) == 2 @gen_cluster(client=True) -def test__broadcast_dict(c, s, a, b): - d = yield c.scatter({"x": 1}, broadcast=True) +async def test__broadcast_dict(c, s, a, b): + d = await c.scatter({"x": 1}, broadcast=True) assert a.data == b.data == {"x": 1} @@ -2239,20 +2234,20 @@ def test_broadcast(c, s, a, b): @gen_cluster(client=True) -def test_proxy(c, s, a, b): - msg = yield c.scheduler.proxy(msg={"op": "identity"}, worker=a.address) +async def test_proxy(c, s, a, b): + msg = await c.scheduler.proxy(msg={"op": "identity"}, worker=a.address) assert msg["id"] == a.identity()["id"] @gen_cluster(client=True) -def test__cancel(c, s, a, b): +async def test__cancel(c, s, a, b): x = c.submit(slowinc, 1) y = c.submit(slowinc, x) while y.key not in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - yield c.cancel([x]) + await c.cancel([x]) assert x.cancelled() assert "cancel" in str(x) @@ -2260,7 +2255,7 @@ def test__cancel(c, s, a, b): start = time() while not y.cancelled(): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 assert not s.tasks @@ -2268,54 +2263,56 @@ def test__cancel(c, s, a, b): @gen_cluster(client=True) -def test__cancel_tuple_key(c, s, a, b): +async def test_cancel_tuple_key(c, s, a, b): x = c.submit(inc, 1, key=("x", 0, 1)) - - result = yield x - yield c.cancel(x) + await x + await c.cancel(x) with pytest.raises(CancelledError): - yield x + await x @gen_cluster() -def test__cancel_multi_client(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_cancel_multi_client(s, a, b): + c = await Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) x = c.submit(slowinc, 1) y = f.submit(slowinc, 1) assert x.key == y.key - yield c.cancel([x]) + await c.cancel([x]) assert x.cancelled() assert not y.cancelled() start = time() while y.key not in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 - out = yield y + out = await y assert out == 2 with pytest.raises(CancelledError): - yield x + await x - yield c.close() - yield f.close() + await c.close() + await f.close() @gen_cluster(client=True) -def test__cancel_collection(c, s, a, b): +async def test_cancel_collection(c, s, a, b): L = c.map(double, [[1], [2], [3]]) x = db.Bag({("b", i): f for i, f in enumerate(L)}, "b", 3) - yield c.cancel(x) - yield c.cancel([x]) + await c.cancel(x) + await c.cancel([x]) assert all(f.cancelled() for f in L) - assert not s.tasks + start = time() + while s.tasks: + assert time() < start + 1 + await asyncio.sleep(0.01) def test_cancel(c): @@ -2337,18 +2334,18 @@ def test_cancel(c): @gen_cluster(client=True) -def test_future_type(c, s, a, b): +async def test_future_type(c, s, a, b): x = c.submit(inc, 1) - yield wait([x]) + await wait([x]) assert x.type == int assert "int" in str(x) @gen_cluster(client=True) -def test_traceback_clean(c, s, a, b): +async def test_traceback_clean(c, s, a, b): x = c.submit(div, 1, 0) try: - yield x + await x except Exception as e: f = e exc_type, exc_value, tb = sys.exc_info() @@ -2359,7 +2356,7 @@ def test_traceback_clean(c, s, a, b): @gen_cluster(client=True) -def test_map_differnet_lengths(c, 
s, a, b): +async def test_map_differnet_lengths(c, s, a, b): assert len(c.map(add, [1, 2], [1, 2, 3])) == 2 @@ -2375,7 +2372,7 @@ def test_Future_exception_sync_2(loop, capsys): @gen_cluster(timeout=60, client=True) -def test_async_persist(c, s, a, b): +async def test_async_persist(c, s, a, b): from dask.delayed import delayed, Delayed x = delayed(1) @@ -2393,13 +2390,13 @@ def test_async_persist(c, s, a, b): assert w.__dask_keys__() == ww.__dask_keys__() while y.key not in s.tasks and w.key not in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.who_wants[y.key] == {c.id} assert s.who_wants[w.key] == {c.id} yyf, wwf = c.compute([yy, ww]) - yyy, www = yield c.gather([yyf, wwf]) + yyy, www = await c.gather([yyf, wwf]) assert yyy == inc(1) assert www == add(inc(1), dec(1)) @@ -2408,7 +2405,7 @@ def test_async_persist(c, s, a, b): @gen_cluster(client=True) -def test__persist(c, s, a, b): +async def test__persist(c, s, a, b): pytest.importorskip("dask.array") import dask.array as da @@ -2424,7 +2421,7 @@ def test__persist(c, s, a, b): g, h = c.compute([y, yy]) - gg, hh = yield c.gather([g, h]) + gg, hh = await c.gather([g, h]) assert (gg == hh).all() @@ -2447,7 +2444,7 @@ def test_persist(c): @gen_cluster(timeout=60, client=True) -def test_long_traceback(c, s, a, b): +async def test_long_traceback(c, s, a, b): from distributed.protocol.pickle import dumps def deep(n): @@ -2457,22 +2454,22 @@ def deep(n): return deep(n - 1) x = c.submit(deep, 200) - yield wait([x]) + await wait([x]) assert len(dumps(c.futures[x.key].traceback)) < 10000 assert isinstance(c.futures[x.key].exception, ZeroDivisionError) @gen_cluster(client=True) -def test_wait_on_collections(c, s, a, b): +async def test_wait_on_collections(c, s, a, b): L = c.map(double, [[1], [2], [3]]) x = db.Bag({("b", i): f for i, f in enumerate(L)}, "b", 3) - yield wait(x) + await wait(x) assert all(f.key in a.data or f.key in b.data for f in L) @gen_cluster(client=True) -def test_futures_of_get(c, s, a, b): +async def test_futures_of_get(c, s, a, b): x, y, z = c.map(inc, [1, 2, 3]) assert set(futures_of(0)) == set() @@ -2498,15 +2495,15 @@ def test_futures_of_class(): @gen_cluster(client=True) -def test_futures_of_cancelled_raises(c, s, a, b): +async def test_futures_of_cancelled_raises(c, s, a, b): x = c.submit(inc, 1) - yield c.cancel([x]) + await c.cancel([x]) with pytest.raises(CancelledError): - yield x + await x with pytest.raises(CancelledError): - yield c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False) + await c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False) with pytest.raises(CancelledError): c.submit(inc, x) @@ -2522,69 +2519,69 @@ def test_futures_of_cancelled_raises(c, s, a, b): @pytest.mark.skip @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) -def test_dont_delete_recomputed_results(c, s, w): +async def test_dont_delete_recomputed_results(c, s, w): x = c.submit(inc, 1) # compute first time - yield wait([x]) + await wait([x]) x.__del__() # trigger garbage collection - yield gen.moment + await asyncio.sleep(0) xx = c.submit(inc, 1) # compute second time start = time() while xx.key not in w.data: # data shows up - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 while time() < start + (s.delete_interval + 100) / 1000: # and stays assert xx.key in w.data - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @gen_cluster(nthreads=[], client=True) -def test_fatally_serialized_input(c, s): +async def test_fatally_serialized_input(c, s): o = 
FatallySerializedObject() future = c.submit(inc, o) while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @pytest.mark.skip(reason="Use fast random selection now") @gen_cluster(client=True) -def test_balance_tasks_by_stacks(c, s, a, b): +async def test_balance_tasks_by_stacks(c, s, a, b): x = c.submit(inc, 1) - yield wait(x) + await wait(x) y = c.submit(inc, 2) - yield wait(y) + await wait(y) assert len(a.data) == len(b.data) == 1 @gen_cluster(client=True) -def test_run(c, s, a, b): - results = yield c.run(inc, 1) +async def test_run(c, s, a, b): + results = await c.run(inc, 1) assert results == {a.address: 2, b.address: 2} - results = yield c.run(inc, 1, workers=[a.address]) + results = await c.run(inc, 1, workers=[a.address]) assert results == {a.address: 2} - results = yield c.run(inc, 1, workers=[]) + results = await c.run(inc, 1, workers=[]) assert results == {} @gen_cluster(client=True) -def test_run_handles_picklable_data(c, s, a, b): +async def test_run_handles_picklable_data(c, s, a, b): futures = c.map(inc, range(10)) - yield wait(futures) + await wait(futures) def func(): return {}, set(), [], (), 1, "hello", b"100" - results = yield c.run_on_scheduler(func) + results = await c.run_on_scheduler(func) assert results == func() - results = yield c.run(func) + results = await c.run(func) assert results == {w.address: func() for w in [a, b]} @@ -2600,22 +2597,21 @@ def func(x, y=10): @gen_cluster(client=True) -def test_run_coroutine(c, s, a, b): - results = yield c.run(geninc, 1, delay=0.05) +async def test_run_coroutine(c, s, a, b): + results = await c.run(geninc, 1, delay=0.05) assert results == {a.address: 2, b.address: 2} - results = yield c.run(geninc, 1, delay=0.05, workers=[a.address]) + results = await c.run(geninc, 1, delay=0.05, workers=[a.address]) assert results == {a.address: 2} - results = yield c.run(geninc, 1, workers=[]) + results = await c.run(geninc, 1, workers=[]) assert results == {} with pytest.raises(RuntimeError, match="hello"): - yield c.run(throws, 1) + await c.run(throws, 1) - if sys.version_info >= (3, 5): - results = yield c.run(asyncinc, 2, delay=0.01) - assert results == {a.address: 3, b.address: 3} + results = await c.run(asyncinc, 2, delay=0.01) + assert results == {a.address: 3, b.address: 3} def test_run_coroutine_sync(c, s, a, b): @@ -2692,39 +2688,38 @@ def test_diagnostic_nbytes_sync(c): @gen_cluster(client=True) -def test_diagnostic_nbytes(c, s, a, b): +async def test_diagnostic_nbytes(c, s, a, b): incs = c.map(inc, [1, 2, 3]) doubles = c.map(double, [1, 2, 3]) - yield wait(incs + doubles) + await wait(incs + doubles) assert s.get_nbytes(summary=False) == {k.key: sizeof(1) for k in incs + doubles} assert s.get_nbytes(summary=True) == {"inc": sizeof(1) * 3, "double": sizeof(1) * 3} @gen_test() -def test_worker_aliases(): - s = yield Scheduler(validate=True, port=0) +async def test_worker_aliases(): + s = await Scheduler(validate=True, port=0) a = Worker(s.address, name="alice") b = Worker(s.address, name="bob") w = Worker(s.address, name=3) - yield [a, b, w] - - c = yield Client(s.address, asynchronous=True) + await asyncio.gather(a, b, w) + c = await Client(s.address, asynchronous=True) L = c.map(inc, range(10), workers="alice") - future = yield c.scatter(123, workers=3) - yield wait(L) + future = await c.scatter(123, workers=3) + await wait(L) assert len(a.data) == 10 assert len(b.data) == 0 assert dict(w.data) == {future.key: 123} for i, alias in enumerate([3, [3], "alice"]): - result = yield c.submit(lambda x: x + 1, i, 
workers=alias) + result = await c.submit(lambda x: x + 1, i, workers=alias) assert result == i + 1 - yield c.close() - yield [a.close(), b.close(), w.close()] - yield s.close() + await c.close() + await asyncio.gather(a.close(), b.close(), w.close()) + await s.close() def test_persist_get_sync(c): @@ -2741,7 +2736,7 @@ def test_persist_get_sync(c): @gen_cluster(client=True) -def test_persist_get(c, s, a, b): +async def test_persist_get(c, s, a, b): dadd = delayed(add) x, y = delayed(1), delayed(2) xx = delayed(add)(x, x) @@ -2751,17 +2746,17 @@ def test_persist_get(c, s, a, b): xxyy2 = c.persist(xxyy) xxyy3 = delayed(add)(xxyy2, 10) - yield gen.sleep(0.5) - result = yield c.get(xxyy3.dask, xxyy3.__dask_keys__(), sync=False) + await asyncio.sleep(0.5) + result = await c.gather(c.get(xxyy3.dask, xxyy3.__dask_keys__(), sync=False)) assert result[0] == ((1 + 1) + (2 + 2)) + 10 - result = yield c.compute(xxyy3) + result = await c.compute(xxyy3) assert result == ((1 + 1) + (2 + 2)) + 10 - result = yield c.compute(xxyy3) + result = await c.compute(xxyy3) assert result == ((1 + 1) + (2 + 2)) + 10 - result = yield c.compute(xxyy3) + result = await c.compute(xxyy3) assert result == ((1 + 1) + (2 + 2)) + 10 @@ -2782,12 +2777,12 @@ def test_client_num_fds(loop): @gen_cluster() -def test_startup_close_startup(s, a, b): - c = yield Client(s.address, asynchronous=True) - yield c.close() +async def test_startup_close_startup(s, a, b): + c = await Client(s.address, asynchronous=True) + await c.close() - c = yield Client(s.address, asynchronous=True) - yield c.close() + c = await Client(s.address, asynchronous=True) + await c.close() def test_startup_close_startup_sync(loop): @@ -2804,7 +2799,7 @@ def test_startup_close_startup_sync(loop): @gen_cluster(client=True) -def test_badly_serialized_exceptions(c, s, a, b): +async def test_badly_serialized_exceptions(c, s, a, b): def f(): class BadlySerializedException(Exception): def __reduce__(self): @@ -2815,7 +2810,7 @@ def __reduce__(self): x = c.submit(f) try: - result = yield x + result = await x except Exception as e: assert "hello world" in str(e) else: @@ -2823,16 +2818,16 @@ def __reduce__(self): @gen_cluster(client=True) -def test_rebalance(c, s, a, b): +async def test_rebalance(c, s, a, b): aws = s.workers[a.address] bws = s.workers[b.address] - x, y = yield c.scatter([1, 2], workers=[a.address]) + x, y = await c.scatter([1, 2], workers=[a.address]) assert len(a.data) == 2 assert len(b.data) == 0 s.validate_state() - yield c.rebalance() + await c.rebalance() s.validate_state() assert len(b.data) == 1 @@ -2845,21 +2840,21 @@ def test_rebalance(c, s, a, b): @gen_cluster(nthreads=[("127.0.0.1", 1)] * 4, client=True) -def test_rebalance_workers(e, s, a, b, c, d): - w, x, y, z = yield e.scatter([1, 2, 3, 4], workers=[a.address]) +async def test_rebalance_workers(e, s, a, b, c, d): + w, x, y, z = await e.scatter([1, 2, 3, 4], workers=[a.address]) assert len(a.data) == 4 assert len(b.data) == 0 assert len(c.data) == 0 assert len(d.data) == 0 - yield e.rebalance([x, y], workers=[a.address, c.address]) + await e.rebalance([x, y], workers=[a.address, c.address]) assert len(a.data) == 3 assert len(b.data) == 0 assert len(c.data) == 1 assert len(d.data) == 0 assert c.data == {x.key: 2} or c.data == {y.key: 3} - yield e.rebalance() + await e.rebalance() assert len(a.data) == 1 assert len(b.data) == 1 assert len(c.data) == 1 @@ -2868,9 +2863,9 @@ def test_rebalance_workers(e, s, a, b, c, d): @gen_cluster(client=True) -def test_rebalance_execution(c, s, a, b): +async 
def test_rebalance_execution(c, s, a, b): futures = c.map(inc, range(10), workers=a.address) - yield c.rebalance(futures) + await c.rebalance(futures) assert len(a.data) == len(b.data) == 5 s.validate_state() @@ -2885,10 +2880,10 @@ def test_rebalance_sync(c, s, a, b): @gen_cluster(client=True) -def test_rebalance_unprepared(c, s, a, b): +async def test_rebalance_unprepared(c, s, a, b): futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) - yield gen.sleep(0.1) - yield c.rebalance(futures) + await asyncio.sleep(0.1) + await c.rebalance(futures) s.validate_state() @@ -2902,66 +2897,63 @@ async def test_rebalance_raises_missing_data(c, s, a, b): @gen_cluster(client=True) -def test_receive_lost_key(c, s, a, b): +async def test_receive_lost_key(c, s, a, b): x = c.submit(inc, 1, workers=[a.address]) - result = yield x - yield a.close() + await x + await a.close() start = time() while x.status == "finished": assert time() < start + 5 - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) -def test_unrunnable_task_runs(c, s, a, b): +async def test_unrunnable_task_runs(c, s, a, b): x = c.submit(inc, 1, workers=[a.ip]) - result = yield x + await x - yield a.close() + await a.close() start = time() while x.status == "finished": assert time() < start + 5 - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.tasks[x.key] in s.unrunnable assert s.get_task_status(keys=[x.key]) == {x.key: "no-worker"} - w = yield Worker(s.address, loop=s.loop) + w = await Worker(s.address, loop=s.loop) start = time() while x.status != "finished": assert time() < start + 2 - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.tasks[x.key] not in s.unrunnable - result = yield x + result = await x assert result == 2 - yield w.close() + await w.close() @gen_cluster(client=True, nthreads=[]) -def test_add_worker_after_tasks(c, s): +async def test_add_worker_after_tasks(c, s): futures = c.map(inc, range(10)) - - n = yield Nanny(s.address, nthreads=2, loop=s.loop, port=0) - - result = yield c.gather(futures) - - yield n.close() + n = await Nanny(s.address, nthreads=2, loop=s.loop, port=0) + await c.gather(futures) + await n.close() @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) -def test_workers_register_indirect_data(c, s, a, b): - [x] = yield c.scatter([1], workers=a.address) +async def test_workers_register_indirect_data(c, s, a, b): + [x] = await c.scatter([1], workers=a.address) y = c.submit(inc, x, workers=b.ip) - yield y + await y assert b.data[x.key] == 1 assert s.tasks[x.key].who_has == {s.workers[a.address], s.workers[b.address]} assert s.workers[b.address].has_what == {s.tasks[x.key], s.tasks[y.key]} @@ -2969,20 +2961,20 @@ def test_workers_register_indirect_data(c, s, a, b): @gen_cluster(client=True) -def test_submit_on_cancelled_future(c, s, a, b): +async def test_submit_on_cancelled_future(c, s, a, b): x = c.submit(inc, 1) - yield x + await x - yield c.cancel(x) + await c.cancel(x) with pytest.raises(CancelledError): - y = c.submit(inc, x) + c.submit(inc, x) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) -def test_replicate(c, s, *workers): - [a, b] = yield c.scatter([1, 2]) - yield s.replicate(keys=[a.key, b.key], n=5) +async def test_replicate(c, s, *workers): + [a, b] = await c.scatter([1, 
2]) + await s.replicate(keys=[a.key, b.key], n=5) s.validate_state() assert len(s.tasks[a.key].who_has) == 5 @@ -2993,22 +2985,22 @@ def test_replicate(c, s, *workers): @gen_cluster(client=True) -def test_replicate_tuple_keys(c, s, a, b): +async def test_replicate_tuple_keys(c, s, a, b): x = delayed(inc)(1, dask_key_name=("x", 1)) f = c.persist(x) - yield c.replicate(f, n=5) + await c.replicate(f, n=5) s.validate_state() assert a.data and b.data - yield c.rebalance(f) + await c.rebalance(f) s.validate_state() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) -def test_replicate_workers(c, s, *workers): +async def test_replicate_workers(c, s, *workers): - [a, b] = yield c.scatter([1, 2], workers=[workers[0].address]) - yield s.replicate( + [a, b] = await c.scatter([1, 2], workers=[workers[0].address]) + await s.replicate( keys=[a.key, b.key], n=5, workers=[w.address for w in workers[:5]] ) @@ -3020,7 +3012,7 @@ def test_replicate_workers(c, s, *workers): assert sum(a.key in w.data for w in workers[5:]) == 0 assert sum(b.key in w.data for w in workers[5:]) == 0 - yield s.replicate(keys=[a.key, b.key], n=1) + await s.replicate(keys=[a.key, b.key], n=1) assert len(s.tasks[a.key].who_has) == 1 assert len(s.tasks[b.key].who_has) == 1 @@ -3029,12 +3021,12 @@ def test_replicate_workers(c, s, *workers): s.validate_state() - yield s.replicate(keys=[a.key, b.key], n=None) # all + await s.replicate(keys=[a.key, b.key], n=None) # all assert len(s.tasks[a.key].who_has) == 10 assert len(s.tasks[b.key].who_has) == 10 s.validate_state() - yield s.replicate( + await s.replicate( keys=[a.key, b.key], n=1, workers=[w.address for w in workers[:5]] ) assert sum(a.key in w.data for w in workers[:5]) == 1 @@ -3056,30 +3048,30 @@ def __getstate__(self): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) -def test_replicate_tree_branching(c, s, *workers): +async def test_replicate_tree_branching(c, s, *workers): obj = CountSerialization() - [future] = yield c.scatter([obj]) - yield s.replicate(keys=[future.key], n=10) + [future] = await c.scatter([obj]) + await s.replicate(keys=[future.key], n=10) max_count = max(w.data[future.key].n for w in workers) assert max_count > 1 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) -def test_client_replicate(c, s, *workers): +async def test_client_replicate(c, s, *workers): x = c.submit(inc, 1) y = c.submit(inc, 2) - yield c.replicate([x, y], n=5) + await c.replicate([x, y], n=5) assert len(s.tasks[x.key].who_has) == 5 assert len(s.tasks[y.key].who_has) == 5 - yield c.replicate([x, y], n=3) + await c.replicate([x, y], n=3) assert len(s.tasks[x.key].who_has) == 3 assert len(s.tasks[y.key].who_has) == 3 - yield c.replicate([x, y]) + await c.replicate([x, y]) s.validate_state() assert len(s.tasks[x.key].who_has) == 10 @@ -3094,19 +3086,19 @@ def test_client_replicate(c, s, *workers): nthreads=[("127.0.0.1", 1), ("127.0.0.2", 1), ("127.0.0.2", 1)], timeout=None, ) -def test_client_replicate_host(client, s, a, b, c): +async def test_client_replicate_host(client, s, a, b, c): aws = s.workers[a.address] bws = s.workers[b.address] cws = s.workers[c.address] x = client.submit(inc, 1, workers="127.0.0.2") - yield wait([x]) + await wait([x]) assert s.tasks[x.key].who_has == {bws} or s.tasks[x.key].who_has == {cws} - yield client.replicate([x], workers=["127.0.0.2"]) + await client.replicate([x], workers=["127.0.0.2"]) assert s.tasks[x.key].who_has == {bws, cws} - yield client.replicate([x], workers=["127.0.0.1"]) + await client.replicate([x], 
workers=["127.0.0.1"]) assert s.tasks[x.key].who_has == {aws, bws, cws} @@ -3126,25 +3118,25 @@ def test_client_replicate_sync(c): @pytest.mark.skipif(WINDOWS, reason="Windows timer too coarse-grained") @gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 1) -def test_task_load_adapts_quickly(c, s, a): +async def test_task_load_adapts_quickly(c, s, a): future = c.submit(slowinc, 1, delay=0.2) # slow - yield wait(future) + await wait(future) assert 0.15 < s.task_prefixes["slowinc"].duration_average < 0.4 futures = c.map(slowinc, range(10), delay=0) # very fast - yield wait(futures) + await wait(futures) assert 0 < s.task_prefixes["slowinc"].duration_average < 0.1 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_even_load_after_fast_functions(c, s, a, b): +async def test_even_load_after_fast_functions(c, s, a, b): x = c.submit(inc, 1, workers=a.address) # very fast y = c.submit(inc, 2, workers=b.address) # very fast - yield wait([x, y]) + await wait([x, y]) futures = c.map(inc, range(2, 11)) - yield wait(futures) + await wait(futures) assert any(f.key in a.data for f in futures) assert any(f.key in b.data for f in futures) @@ -3152,17 +3144,17 @@ def test_even_load_after_fast_functions(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_even_load_on_startup(c, s, a, b): +async def test_even_load_on_startup(c, s, a, b): x, y = c.map(inc, [1, 2]) - yield wait([x, y]) + await wait([x, y]) assert len(a.data) == len(b.data) == 1 @pytest.mark.skip @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 2) -def test_contiguous_load(c, s, a, b): +async def test_contiguous_load(c, s, a, b): w, x, y, z = c.map(inc, [1, 2, 3, 4]) - yield wait([w, x, y, z]) + await wait([w, x, y, z]) groups = [set(a.data), set(b.data)] assert {w.key, x.key} in groups @@ -3170,24 +3162,24 @@ def test_contiguous_load(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_balanced_with_submit(c, s, *workers): +async def test_balanced_with_submit(c, s, *workers): L = [c.submit(slowinc, i) for i in range(4)] - yield wait(L) + await wait(L) for w in workers: assert len(w.data) == 1 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_balanced_with_submit_and_resident_data(c, s, *workers): - [x] = yield c.scatter([10], broadcast=True) +async def test_balanced_with_submit_and_resident_data(c, s, *workers): + [x] = await c.scatter([10], broadcast=True) L = [c.submit(slowinc, x, pure=False) for i in range(4)] - yield wait(L) + await wait(L) for w in workers: assert len(w.data) == 2 @gen_cluster(client=True, nthreads=[("127.0.0.1", 20)] * 2) -def test_scheduler_saturates_cores(c, s, a, b): +async def test_scheduler_saturates_cores(c, s, a, b): for delay in [0, 0.01, 0.1]: futures = c.map(slowinc, range(100), delay=delay) futures = c.map(slowinc, futures, delay=delay / 10) @@ -3198,11 +3190,11 @@ def test_scheduler_saturates_cores(c, s, a, b): for w in s.workers.values() for p in w.processing.values() ) - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @gen_cluster(client=True, nthreads=[("127.0.0.1", 20)] * 2) -def test_scheduler_saturates_cores_random(c, s, a, b): +async def test_scheduler_saturates_cores_random(c, s, a, b): for delay in [0, 0.01, 0.1]: futures = c.map(randominc, range(100), scale=0.1) while not s.tasks: @@ -3212,22 +3204,22 @@ def test_scheduler_saturates_cores_random(c, s, a, b): for w in s.workers.values() for p in w.processing.values() ) - yield gen.sleep(0.01) + await asyncio.sleep(0.01) 
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_cancel_clears_processing(c, s, *workers): +async def test_cancel_clears_processing(c, s, *workers): da = pytest.importorskip("dask.array") x = c.submit(slowinc, 1, delay=0.2) while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - yield c.cancel(x) + await c.cancel(x) start = time() while any(v for w in s.workers.values() for v in w.processing): assert time() < start + 0.2 - yield gen.sleep(0.01) + await asyncio.sleep(0.01) s.validate_state() @@ -3270,50 +3262,50 @@ def test_default_get(): @gen_cluster(client=True) -def test_get_processing(c, s, a, b): - processing = yield c.processing() +async def test_get_processing(c, s, a, b): + processing = await c.processing() assert processing == valmap(tuple, s.processing) futures = c.map( slowinc, range(10), delay=0.1, workers=[a.address], allow_other_workers=True ) - yield gen.sleep(0.2) + await asyncio.sleep(0.2) - x = yield c.processing() + x = await c.processing() assert set(x) == {a.address, b.address} - x = yield c.processing(workers=[a.address]) + x = await c.processing(workers=[a.address]) assert isinstance(x[a.address], (list, tuple)) @gen_cluster(client=True) -def test_get_foo(c, s, a, b): +async def test_get_foo(c, s, a, b): futures = c.map(inc, range(10)) - yield wait(futures) + await wait(futures) - x = yield c.scheduler.ncores() + x = await c.scheduler.ncores() assert x == s.nthreads - x = yield c.scheduler.ncores(workers=[a.address]) + x = await c.scheduler.ncores(workers=[a.address]) assert x == {a.address: s.nthreads[a.address]} - x = yield c.scheduler.has_what() + x = await c.scheduler.has_what() assert valmap(sorted, x) == valmap(sorted, s.has_what) - x = yield c.scheduler.has_what(workers=[a.address]) + x = await c.scheduler.has_what(workers=[a.address]) assert valmap(sorted, x) == {a.address: sorted(s.has_what[a.address])} - x = yield c.scheduler.nbytes(summary=False) + x = await c.scheduler.nbytes(summary=False) assert x == s.get_nbytes(summary=False) - x = yield c.scheduler.nbytes(keys=[futures[0].key], summary=False) + x = await c.scheduler.nbytes(keys=[futures[0].key], summary=False) assert x == {futures[0].key: s.tasks[futures[0].key].nbytes} - x = yield c.scheduler.who_has() + x = await c.scheduler.who_has() assert valmap(sorted, x) == valmap(sorted, s.who_has) - x = yield c.scheduler.who_has(keys=[futures[0].key]) + x = await c.scheduler.who_has(keys=[futures[0].key]) assert valmap(sorted, x) == {futures[0].key: sorted(s.who_has[futures[0].key])} @@ -3326,34 +3318,34 @@ def assert_dict_key_equal(expected, actual): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_get_foo_lost_keys(c, s, u, v, w): +async def test_get_foo_lost_keys(c, s, u, v, w): x = c.submit(inc, 1, workers=[u.address]) - y = yield c.scatter(3, workers=[v.address]) - yield wait([x, y]) + y = await c.scatter(3, workers=[v.address]) + await wait([x, y]) ua, va, wa = u.address, v.address, w.address - d = yield c.scheduler.has_what() + d = await c.scheduler.has_what() assert_dict_key_equal(d, {ua: [x.key], va: [y.key], wa: []}) - d = yield c.scheduler.has_what(workers=[ua, va]) + d = await c.scheduler.has_what(workers=[ua, va]) assert_dict_key_equal(d, {ua: [x.key], va: [y.key]}) - d = yield c.scheduler.who_has() + d = await c.scheduler.who_has() assert_dict_key_equal(d, {x.key: [ua], y.key: [va]}) - d = yield c.scheduler.who_has(keys=[x.key, y.key]) + d = await c.scheduler.who_has(keys=[x.key, y.key]) assert_dict_key_equal(d, {x.key: [ua], y.key: [va]}) 
- yield u.close() - yield v.close() + await u.close() + await v.close() - d = yield c.scheduler.has_what() + d = await c.scheduler.has_what() assert_dict_key_equal(d, {wa: []}) - d = yield c.scheduler.has_what(workers=[ua, va]) + d = await c.scheduler.has_what(workers=[ua, va]) assert_dict_key_equal(d, {ua: [], va: []}) # The scattered key cannot be recomputed so it is forgotten - d = yield c.scheduler.who_has() + d = await c.scheduler.who_has() assert_dict_key_equal(d, {x.key: []}) # ... but when passed explicitly, it is included in the result - d = yield c.scheduler.who_has(keys=[x.key, y.key]) + d = await c.scheduler.who_has(keys=[x.key, y.key]) assert_dict_key_equal(d, {x.key: [], y.key: []}) @@ -3361,13 +3353,13 @@ def test_get_foo_lost_keys(c, s, u, v, w): @gen_cluster( client=True, Worker=Nanny, clean_kwargs={"threads": False, "processes": False} ) -def test_bad_tasks_fail(c, s, a, b): +async def test_bad_tasks_fail(c, s, a, b): f = c.submit(sys.exit, 0) with pytest.raises(KilledWorker) as info: - yield f + await f assert info.value.last_worker.nanny in {a.address, b.address} - yield [a.close(), b.close()] + await asyncio.gather(a.close(), b.close()) def test_get_processing_sync(c, s, a, b): @@ -3417,11 +3409,11 @@ def test_get_returns_early(c): @pytest.mark.slow @gen_cluster(Worker=Nanny, client=True) -def test_Client_clears_references_after_restart(c, s, a, b): +async def test_Client_clears_references_after_restart(c, s, a, b): x = c.submit(inc, 1) assert x.key in c.refcount - yield c.restart() + await c.restart() assert x.key not in c.refcount key = x.key @@ -3429,7 +3421,7 @@ def test_Client_clears_references_after_restart(c, s, a, b): import gc gc.collect() - yield gen.moment + await asyncio.sleep(0) assert key not in c.refcount @@ -3487,21 +3479,21 @@ def test_as_completed_next_batch(c): @gen_test() -def test_status(): - s = yield Scheduler(port=0) +async def test_status(): + s = await Scheduler(port=0) - c = yield Client(s.address, asynchronous=True) + c = await Client(s.address, asynchronous=True) assert c.status == "running" x = c.submit(inc, 1) - yield c.close() + await c.close() assert c.status == "closed" - yield s.close() + await s.close() @gen_cluster(client=True) -def test_persist_optimize_graph(c, s, a, b): +async def test_persist_optimize_graph(c, s, a, b): i = 10 for method in [c.persist, c.compute]: b = db.range(i, npartitions=2) @@ -3510,7 +3502,7 @@ def test_persist_optimize_graph(c, s, a, b): b3 = b2.map(inc) b4 = method(b3, optimize_graph=False) - yield wait(b4) + await wait(b4) assert set(map(tokey, b3.__dask_keys__())).issubset(s.tasks) @@ -3520,15 +3512,15 @@ def test_persist_optimize_graph(c, s, a, b): b3 = b2.map(inc) b4 = method(b3, optimize_graph=True) - yield wait(b4) + await wait(b4) assert not any(tokey(k) in s.tasks for k in b2.__dask_keys__()) @gen_cluster(client=True, nthreads=[]) -def test_scatter_raises_if_no_workers(c, s): +async def test_scatter_raises_if_no_workers(c, s): with pytest.raises(TimeoutError): - yield c.scatter(1, timeout=0.5) + await c.scatter(1, timeout=0.5) @pytest.mark.slow @@ -3593,13 +3585,13 @@ def test_reconnect(loop): @gen_cluster(client=True, nthreads=[], client_kwargs={"timeout": 0.5}) -def test_reconnect_timeout(c, s): +async def test_reconnect_timeout(c, s): with captured_logger(logging.getLogger("distributed.client")) as logger: - yield s.close() + await s.close() start = time() while c.status != "closed": - yield c._update_scheduler_info() - yield gen.sleep(0.05) + await c._update_scheduler_info() + await 
asyncio.sleep(0.05) assert time() < start + 5, "Timeout waiting for reconnect to fail" text = logger.getvalue() assert "Failed to reconnect" in text @@ -3621,22 +3613,21 @@ def test_open_close_many_workers(loop, worker, count, repeat): workers = set() status = True - @gen.coroutine - def start_worker(sleep, duration, repeat=1): + async def start_worker(sleep, duration, repeat=1): for i in range(repeat): - yield gen.sleep(sleep) + await asyncio.sleep(sleep) if not status: return w = worker(s["address"], loop=loop) running[w] = None workers.add(w) - yield w + await w addr = w.worker_address running[w] = addr - yield gen.sleep(duration) - yield w.close() + await asyncio.sleep(duration) + await w.close() del w - yield gen.moment + await asyncio.sleep(0) done.release() for i in range(count): @@ -3672,34 +3663,34 @@ def start_worker(sleep, duration, repeat=1): @gen_cluster(client=False, timeout=None) -def test_idempotence(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_idempotence(s, a, b): + c = await Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) # Submit x = c.submit(inc, 1) - yield x + await x log = list(s.transition_log) len_single_submit = len(log) # see last assert y = f.submit(inc, 1) assert x.key == y.key - yield y - yield gen.sleep(0.1) + await y + await asyncio.sleep(0.1) log2 = list(s.transition_log) assert log == log2 # Error a = c.submit(div, 1, 0) - yield wait(a) + await wait(a) assert a.status == "error" log = list(s.transition_log) b = f.submit(div, 1, 0) assert a.key == b.key - yield wait(b) - yield gen.sleep(0.1) + await wait(b) + await asyncio.sleep(0.1) log2 = list(s.transition_log) assert log == log2 @@ -3707,12 +3698,12 @@ def test_idempotence(s, a, b): # Simultaneous Submit d = c.submit(inc, 2) e = c.submit(inc, 2) - yield wait([d, e]) + await wait([d, e]) assert len(s.transition_log) == len_single_submit - yield c.close() - yield f.close() + await c.close() + await f.close() def test_scheduler_info(c): @@ -3771,40 +3762,40 @@ def f(): @gen_cluster(client=True) -def test_lose_scattered_data(c, s, a, b): - [x] = yield c.scatter([1], workers=a.address) +async def test_lose_scattered_data(c, s, a, b): + [x] = await c.scatter([1], workers=a.address) - yield a.close() - yield gen.sleep(0.1) + await a.close() + await asyncio.sleep(0.1) assert x.status == "cancelled" assert x.key not in s.tasks @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_partially_lose_scattered_data(e, s, a, b, c): - x = yield e.scatter(1, workers=a.address) - yield e.replicate(x, n=2) +async def test_partially_lose_scattered_data(e, s, a, b, c): + x = await e.scatter(1, workers=a.address) + await e.replicate(x, n=2) - yield a.close() - yield gen.sleep(0.1) + await a.close() + await asyncio.sleep(0.1) assert x.status == "finished" assert s.get_task_status(keys=[x.key]) == {x.key: "memory"} @gen_cluster(client=True) -def test_scatter_compute_lose(c, s, a, b): - [x] = yield c.scatter([[1, 2, 3, 4]], workers=a.address) +async def test_scatter_compute_lose(c, s, a, b): + [x] = await c.scatter([[1, 2, 3, 4]], workers=a.address) y = c.submit(inc, 1, workers=b.address) z = c.submit(slowadd, x, y, delay=0.2) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) - yield a.close() + await a.close() with pytest.raises(CancelledError): - yield wait(z) + await wait(z) assert x.status == "cancelled" assert y.status == "finished" @@ -3812,7 +3803,7 @@ def test_scatter_compute_lose(c, s, a, 
b): @gen_cluster(client=True) -def test_scatter_compute_store_lose(c, s, a, b): +async def test_scatter_compute_store_lose(c, s, a, b): """ Create irreplaceable data on one machine, cause a dependent computation to occur on another and complete @@ -3820,18 +3811,18 @@ def test_scatter_compute_store_lose(c, s, a, b): Kill the machine with the irreplaceable data. What happens to the complete result? How about after it GCs and tries to come back? """ - x = yield c.scatter(1, workers=a.address) + x = await c.scatter(1, workers=a.address) xx = c.submit(inc, x, workers=a.address) y = c.submit(inc, 1) z = c.submit(slowadd, xx, y, delay=0.2, workers=b.address) - yield wait(z) + await wait(z) - yield a.close() + await a.close() start = time() while x.status == "finished": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 # assert xx.status == 'finished' @@ -3839,14 +3830,14 @@ def test_scatter_compute_store_lose(c, s, a, b): assert z.status == "finished" zz = c.submit(inc, z) - yield wait(zz) + await wait(zz) zkey = z.key del z start = time() while s.get_task_status(keys=[zkey]) != {zkey: "released"}: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 xxkey = xx.key @@ -3854,12 +3845,12 @@ def test_scatter_compute_store_lose(c, s, a, b): start = time() while x.key in s.tasks and zkey not in s.tasks and xxkey not in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 @gen_cluster(client=True) -def test_scatter_compute_store_lose_processing(c, s, a, b): +async def test_scatter_compute_store_lose_processing(c, s, a, b): """ Create irreplaceable data on one machine, cause a dependent computation to occur on another and complete @@ -3867,16 +3858,16 @@ def test_scatter_compute_store_lose_processing(c, s, a, b): Kill the machine with the irreplaceable data. What happens to the complete result? How about after it GCs and tries to come back? 
""" - [x] = yield c.scatter([1], workers=a.address) + [x] = await c.scatter([1], workers=a.address) y = c.submit(slowinc, x, delay=0.2) z = c.submit(inc, y) - yield gen.sleep(0.1) - yield a.close() + await asyncio.sleep(0.1) + await a.close() start = time() while x.status == "finished": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 assert y.status == "cancelled" @@ -3884,28 +3875,28 @@ def test_scatter_compute_store_lose_processing(c, s, a, b): @gen_cluster(client=False) -def test_serialize_future(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_serialize_future(s, a, b): + c = await Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) future = c.submit(lambda: 1) - result = yield future + result = await future with temp_default_client(f): future2 = pickle.loads(pickle.dumps(future)) assert future2.client is f assert tokey(future2.key) in f.futures - result2 = yield future2 + result2 = await future2 assert result == result2 - yield c.close() - yield f.close() + await c.close() + await f.close() @gen_cluster(client=False) -def test_temp_client(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_temp_client(s, a, b): + c = await Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) with temp_default_client(c): assert default_client() is c @@ -3915,13 +3906,13 @@ def test_temp_client(s, a, b): assert default_client() is f assert default_client(c) is c - yield c.close() - yield f.close() + await c.close() + await f.close() @nodebug # test timing is fragile @gen_cluster(nthreads=[("127.0.0.1", 1)] * 3, client=True) -def test_persist_workers(e, s, a, b, c): +async def test_persist_workers(e, s, a, b, c): L1 = [delayed(inc)(i) for i in range(4)] total = delayed(sum)(L1) L2 = [delayed(add)(i, total) for i in L1] @@ -3938,7 +3929,7 @@ def test_persist_workers(e, s, a, b, c): allow_other_workers=L2 + [total2], ) - yield wait(out) + await wait(out) assert all(v.key in a.data for v in L1) assert total.key in b.data @@ -3946,7 +3937,7 @@ def test_persist_workers(e, s, a, b, c): @gen_cluster(nthreads=[("127.0.0.1", 1)] * 3, client=True) -def test_compute_workers(e, s, a, b, c): +async def test_compute_workers(e, s, a, b, c): L1 = [delayed(inc)(i) for i in range(4)] total = delayed(sum)(L1) L2 = [delayed(add)(i, total) for i in L1] @@ -3957,7 +3948,7 @@ def test_compute_workers(e, s, a, b, c): allow_other_workers=L1 + [total], ) - yield wait(out) + await wait(out) for v in L1: assert s.worker_restrictions[v.key] == {a.address} for v in L2: @@ -3968,13 +3959,13 @@ def test_compute_workers(e, s, a, b, c): @gen_cluster(client=True) -def test_compute_nested_containers(c, s, a, b): +async def test_compute_nested_containers(c, s, a, b): da = pytest.importorskip("dask.array") np = pytest.importorskip("numpy") x = da.ones(10, chunks=(5,)) + 1 future = c.compute({"x": [x], "y": 123}) - result = yield future + result = await future assert isinstance(result, dict) assert (result["x"][0] == np.ones(10) + 1).all() @@ -4004,19 +3995,19 @@ def test_get_restrictions(): @gen_cluster(client=True) -def test_scatter_type(c, s, a, b): - [future] = yield c.scatter([1]) +async def test_scatter_type(c, s, a, b): + [future] = await c.scatter([1]) assert future.type == int - d = yield c.scatter({"x": 1.0}) + d = await c.scatter({"x": 1.0}) assert d["x"].type == float 
@gen_cluster(client=True) -def test_retire_workers_2(c, s, a, b): - [x] = yield c.scatter([1], workers=a.address) +async def test_retire_workers_2(c, s, a, b): + [x] = await c.scatter([1], workers=a.address) - yield s.retire_workers(workers=[a.address]) + await s.retire_workers(workers=[a.address]) assert b.data == {x.key: 1} assert s.who_has == {x.key: {b.address}} assert s.has_what == {b.address: {x.key}} @@ -4025,16 +4016,16 @@ def test_retire_workers_2(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) -def test_retire_many_workers(c, s, *workers): - futures = yield c.scatter(list(range(100))) +async def test_retire_many_workers(c, s, *workers): + futures = await c.scatter(list(range(100))) - yield s.retire_workers(workers=[w.address for w in workers[:7]]) + await s.retire_workers(workers=[w.address for w in workers[:7]]) - results = yield c.gather(futures) + results = await c.gather(futures) assert results == list(range(100)) while len(s.workers) != 3: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(s.has_what) == len(s.nthreads) == 3 @@ -4049,19 +4040,19 @@ def test_retire_many_workers(c, s, *workers): nthreads=[("127.0.0.1", 3)] * 2, config={"distributed.scheduler.default-task-durations": {"f": "10ms"}}, ) -def test_weight_occupancy_against_data_movement(c, s, a, b): +async def test_weight_occupancy_against_data_movement(c, s, a, b): s.extensions["stealing"]._pc.callback_time = 1000000 def f(x, y=0, z=0): sleep(0.01) return x - y = yield c.scatter([[1, 2, 3, 4]], workers=[a.address]) - z = yield c.scatter([1], workers=[b.address]) + y = await c.scatter([[1, 2, 3, 4]], workers=[a.address]) + z = await c.scatter([1], workers=[b.address]) futures = c.map(f, [1, 2, 3, 4], y=y, z=z) - yield wait(futures) + await wait(futures) assert sum(f.key in a.data for f in futures) >= 2 assert sum(f.key in b.data for f in futures) >= 1 @@ -4072,24 +4063,24 @@ def f(x, y=0, z=0): nthreads=[("127.0.0.1", 1), ("127.0.0.1", 10)], config={"distributed.scheduler.default-task-durations": {"f": "10ms"}}, ) -def test_distribute_tasks_by_nthreads(c, s, a, b): +async def test_distribute_tasks_by_nthreads(c, s, a, b): s.extensions["stealing"]._pc.callback_time = 1000000 def f(x, y=0): sleep(0.01) return x - y = yield c.scatter([1], broadcast=True) + y = await c.scatter([1], broadcast=True) futures = c.map(f, range(20), y=y) - yield wait(futures) + await wait(futures) assert len(b.data) > 2 * len(a.data) @gen_cluster(client=True, clean_kwargs={"threads": False}) -def test_add_done_callback(c, s, a, b): +async def test_add_done_callback(c, s, a, b): S = set() def f(future): @@ -4106,19 +4097,19 @@ def g(future): v.add_done_callback(f) w.add_done_callback(f) - yield wait((u, v, w, x)) + await wait((u, v, w, x)) x.add_done_callback(f) t = time() while len(S) < 4 and time() - t < 2.0: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert S == {(f.key, f.status) for f in (u, v, w, x)} @gen_cluster(client=True) -def test_normalize_collection(c, s, a, b): +async def test_normalize_collection(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) z = delayed(inc)(y) @@ -4133,7 +4124,7 @@ def test_normalize_collection(c, s, a, b): @gen_cluster(client=True) -def test_normalize_collection_dask_array(c, s, a, b): +async def test_normalize_collection_dask_array(c, s, a, b): da = pytest.importorskip("dask.array") x = da.ones(10, chunks=(5,)) @@ -4151,8 +4142,8 @@ def test_normalize_collection_dask_array(c, s, a, b): for k, v in yy.dask.items(): assert zz.dask[k].key == v.key - result1 = 
yield c.compute(z) - result2 = yield c.compute(zz) + result1 = await c.compute(z) + result2 = await c.compute(zz) assert result1 == result2 @@ -4175,7 +4166,7 @@ def test_normalize_collection_with_released_futures(c): @gen_cluster(client=True) -def test_auto_normalize_collection(c, s, a, b): +async def test_auto_normalize_collection(c, s, a, b): da = pytest.importorskip("dask.array") x = da.ones(10, chunks=5) @@ -4185,17 +4176,17 @@ def test_auto_normalize_collection(c, s, a, b): y = x.map_blocks(slowinc, delay=1, dtype=x.dtype) yy = c.persist(y) - yield wait(yy) + await wait(yy) start = time() future = c.compute(y.sum()) - yield future + await future end = time() assert end - start < 1 start = time() z = c.persist(y + 1) - yield wait(z) + await wait(z) end = time() assert end - start < 1 @@ -4224,7 +4215,7 @@ def assert_no_data_loss(scheduler): @gen_cluster(client=True, timeout=None) -def test_interleave_computations(c, s, a, b): +async def test_interleave_computations(c, s, a, b): import distributed distributed.g = s @@ -4238,14 +4229,14 @@ def test_interleave_computations(c, s, a, b): done = ("memory", "released") - yield gen.sleep(0.1) + await asyncio.sleep(0.1) x_keys = [x.key for x in xs] y_keys = [y.key for y in ys] z_keys = [z.key for z in zs] while not s.tasks or any(w.processing for w in s.workers.values()): - yield gen.sleep(0.05) + await asyncio.sleep(0.05) x_done = sum(state in done for state in s.get_task_status(keys=x_keys).values()) y_done = sum(state in done for state in s.get_task_status(keys=y_keys).values()) z_done = sum(state in done for state in s.get_task_status(keys=z_keys).values()) @@ -4259,7 +4250,7 @@ def test_interleave_computations(c, s, a, b): @pytest.mark.skip(reason="Now prefer first-in-first-out") @gen_cluster(client=True, timeout=None) -def test_interleave_computations_map(c, s, a, b): +async def test_interleave_computations_map(c, s, a, b): xs = c.map(slowinc, range(30), delay=0.02) ys = c.map(slowdec, xs, delay=0.02) zs = c.map(slowadd, xs, ys, delay=0.02) @@ -4271,7 +4262,7 @@ def test_interleave_computations_map(c, s, a, b): z_keys = [z.key for z in zs] while not s.tasks or any(w.processing for w in s.workers.values()): - yield gen.sleep(0.05) + await asyncio.sleep(0.05) x_done = sum(state in done for state in s.get_task_status(keys=x_keys).values()) y_done = sum(state in done for state in s.get_task_status(keys=y_keys).values()) z_done = sum(state in done for state in s.get_task_status(keys=z_keys).values()) @@ -4282,78 +4273,78 @@ def test_interleave_computations_map(c, s, a, b): @gen_cluster(client=True) -def test_scatter_dict_workers(c, s, a, b): - yield c.scatter({"a": 10}, workers=[a.address, b.address]) +async def test_scatter_dict_workers(c, s, a, b): + await c.scatter({"a": 10}, workers=[a.address, b.address]) assert "a" in a.data or "a" in b.data @pytest.mark.slow @gen_test() -def test_client_timeout(): +async def test_client_timeout(): c = Client("127.0.0.1:57484", asynchronous=True) s = Scheduler(loop=c.loop, port=57484) - yield gen.sleep(4) + await asyncio.sleep(4) try: - yield s + await s except EnvironmentError: # port in use - yield c.close() + await c.close() return start = time() - yield c + await c try: assert time() < start + 2 finally: - yield c.close() - yield s.close() + await c.close() + await s.close() @gen_cluster(client=True) -def test_submit_list_kwargs(c, s, a, b): - futures = yield c.scatter([1, 2, 3]) +async def test_submit_list_kwargs(c, s, a, b): + futures = await c.scatter([1, 2, 3]) def f(L=None): return sum(L) 
future = c.submit(f, L=futures) - result = yield future + result = await future assert result == 1 + 2 + 3 @gen_cluster(client=True) -def test_map_list_kwargs(c, s, a, b): - futures = yield c.scatter([1, 2, 3]) +async def test_map_list_kwargs(c, s, a, b): + futures = await c.scatter([1, 2, 3]) def f(i, L=None): return i + sum(L) futures = c.map(f, range(10), L=futures) - results = yield c.gather(futures) + results = await c.gather(futures) assert results == [i + 6 for i in range(10)] @gen_cluster(client=True) -def test_dont_clear_waiting_data(c, s, a, b): +async def test_dont_clear_waiting_data(c, s, a, b): start = time() - x = yield c.scatter(1) + x = await c.scatter(1) y = c.submit(slowinc, x, delay=0.5) while y.key not in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) key = x.key del x for i in range(5): assert s.waiting_data[key] - yield gen.moment + await asyncio.sleep(0) @gen_cluster(client=True) -def test_get_future_error_simple(c, s, a, b): +async def test_get_future_error_simple(c, s, a, b): f = c.submit(div, 1, 0) - yield wait(f) + await wait(f) assert f.status == "error" - function, args, kwargs, deps = yield c._get_futures_error(f) + function, args, kwargs, deps = await c._get_futures_error(f) # args contains only solid values, not keys assert function.__name__ == "div" with pytest.raises(ZeroDivisionError): @@ -4361,7 +4352,7 @@ def test_get_future_error_simple(c, s, a, b): @gen_cluster(client=True) -def test_get_futures_error(c, s, a, b): +async def test_get_futures_error(c, s, a, b): x0 = delayed(dec)(2, dask_key_name="x0") y0 = delayed(dec)(1, dask_key_name="y0") x = delayed(div)(1, x0, dask_key_name="x") @@ -4369,16 +4360,16 @@ def test_get_futures_error(c, s, a, b): tot = delayed(sum)(x, y, dask_key_name="tot") f = c.compute(tot) - yield wait(f) + await wait(f) assert f.status == "error" - function, args, kwargs, deps = yield c._get_futures_error(f) + function, args, kwargs, deps = await c._get_futures_error(f) assert function.__name__ == "div" assert args == (1, y0.key) @gen_cluster(client=True) -def test_recreate_error_delayed(c, s, a, b): +async def test_recreate_error_delayed(c, s, a, b): x0 = delayed(dec)(2) y0 = delayed(dec)(1) x = delayed(div)(1, x0) @@ -4389,7 +4380,7 @@ def test_recreate_error_delayed(c, s, a, b): assert f.status == "pending" - function, args, kwargs = yield c._recreate_error_locally(f) + function, args, kwargs = await c._recreate_error_locally(f) assert f.status == "error" assert function.__name__ == "div" assert args == (1, 0) @@ -4398,7 +4389,7 @@ def test_recreate_error_delayed(c, s, a, b): @gen_cluster(client=True) -def test_recreate_error_futures(c, s, a, b): +async def test_recreate_error_futures(c, s, a, b): x0 = c.submit(dec, 2) y0 = c.submit(dec, 1) x = c.submit(div, 1, x0) @@ -4408,7 +4399,7 @@ def test_recreate_error_futures(c, s, a, b): assert f.status == "pending" - function, args, kwargs = yield c._recreate_error_locally(f) + function, args, kwargs = await c._recreate_error_locally(f) assert f.status == "error" assert function.__name__ == "div" assert args == (1, 0) @@ -4417,13 +4408,13 @@ def test_recreate_error_futures(c, s, a, b): @gen_cluster(client=True) -def test_recreate_error_collection(c, s, a, b): +async def test_recreate_error_collection(c, s, a, b): b = db.range(10, npartitions=4) b = b.map(lambda x: 1 / x) b = b.persist() f = c.compute(b) - function, args, kwargs = yield c._recreate_error_locally(f) + function, args, kwargs = await c._recreate_error_locally(f) with pytest.raises(ZeroDivisionError): 
function(*args, **kwargs) @@ -4440,24 +4431,24 @@ def make_err(x): df2 = df.a.map(make_err) f = c.compute(df2) - function, args, kwargs = yield c._recreate_error_locally(f) + function, args, kwargs = await c._recreate_error_locally(f) with pytest.raises(ValueError): function(*args, **kwargs) # with persist df3 = c.persist(df2) - function, args, kwargs = yield c._recreate_error_locally(df3) + function, args, kwargs = await c._recreate_error_locally(df3) with pytest.raises(ValueError): function(*args, **kwargs) @gen_cluster(client=True) -def test_recreate_error_array(c, s, a, b): +async def test_recreate_error_array(c, s, a, b): da = pytest.importorskip("dask.array") pytest.importorskip("scipy") z = (da.linalg.inv(da.zeros((10, 10), chunks=10)) + 1).sum() zz = z.persist() - func, args, kwargs = yield c._recreate_error_locally(zz) + func, args, kwargs = await c._recreate_error_locally(zz) assert "0.,0.,0." in str(args).replace(" ", "") # args contain actual arrays @@ -4481,14 +4472,14 @@ def test_recreate_error_not_error(c): @gen_cluster(client=True) -def test_retire_workers(c, s, a, b): +async def test_retire_workers(c, s, a, b): assert set(s.workers) == {a.address, b.address} - yield c.retire_workers(workers=[a.address], close_workers=True) + await c.retire_workers(workers=[a.address], close_workers=True) assert set(s.workers) == {b.address} start = time() while a.status != "closed": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @@ -4497,7 +4488,7 @@ class MyException(Exception): @gen_cluster(client=True) -def test_robust_unserializable(c, s, a, b): +async def test_robust_unserializable(c, s, a, b): class Foo: def __getstate__(self): raise MyException() @@ -4506,14 +4497,14 @@ def __getstate__(self): future = c.submit(identity, Foo()) futures = c.map(inc, range(10)) - results = yield c.gather(futures) + results = await c.gather(futures) assert results == list(map(inc, range(10))) assert a.data and b.data @gen_cluster(client=True) -def test_robust_undeserializable(c, s, a, b): +async def test_robust_undeserializable(c, s, a, b): class Foo: def __getstate__(self): return 1 @@ -4523,17 +4514,17 @@ def __setstate__(self, state): future = c.submit(identity, Foo()) with pytest.raises(MyException): - yield future + await future futures = c.map(inc, range(10)) - results = yield c.gather(futures) + results = await c.gather(futures) assert results == list(map(inc, range(10))) assert a.data and b.data @gen_cluster(client=True) -def test_robust_undeserializable_function(c, s, a, b): +async def test_robust_undeserializable_function(c, s, a, b): class Foo: def __getstate__(self): return 1 @@ -4546,17 +4537,17 @@ def __call__(self, *args): future = c.submit(Foo(), 1) with pytest.raises(MyException): - yield future + await future futures = c.map(inc, range(10)) - results = yield c.gather(futures) + results = await c.gather(futures) assert results == list(map(inc, range(10))) assert a.data and b.data @gen_cluster(client=True) -def test_fire_and_forget(c, s, a, b): +async def test_fire_and_forget(c, s, a, b): future = c.submit(slowinc, 1, delay=0.1) import distributed @@ -4568,7 +4559,7 @@ def f(x): start = time() while not hasattr(distributed, "foo"): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 assert distributed.foo == 123 finally: @@ -4576,7 +4567,7 @@ def f(x): start = time() while len(s.tasks) > 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 assert set(s.who_wants) == {future.key} @@ -4584,14 +4575,14 @@ 
def f(x): @gen_cluster(client=True) -def test_fire_and_forget_err(c, s, a, b): +async def test_fire_and_forget_err(c, s, a, b): fire_and_forget(c.submit(div, 1, 0)) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) # erred task should clear out quickly start = time() while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 @@ -4627,16 +4618,16 @@ def test_quiet_client_close_when_cluster_is_closed_before_client(loop): @gen_cluster() -def test_close(s, a, b): - c = yield Client(s.address, asynchronous=True) +async def test_close(s, a, b): + c = await Client(s.address, asynchronous=True) future = c.submit(inc, 1) - yield wait(future) + await wait(future) assert c.id in s.wants_what - yield c.close() + await c.close() start = time() while c.id in s.wants_what or s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @@ -4699,7 +4690,7 @@ def f(_): @gen_cluster(client=True) -def test_identity(c, s, a, b): +async def test_identity(c, s, a, b): assert c.id.lower().startswith("client") assert a.id.lower().startswith("worker") assert b.id.lower().startswith("worker") @@ -4707,7 +4698,7 @@ def test_identity(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 2) -def test_get_client(c, s, a, b): +async def test_get_client(c, s, a, b): assert get_client() is c assert c.asynchronous @@ -4725,7 +4716,7 @@ def f(x): distributed.tmp_client = c try: futures = c.map(f, range(5)) - results = yield c.gather(futures) + results = await c.gather(futures) assert results == list(map(inc, range(5))) finally: del distributed.tmp_client @@ -4742,7 +4733,7 @@ def test_get_client_no_cluster(): @gen_cluster(client=True) -def test_serialize_collections(c, s, a, b): +async def test_serialize_collections(c, s, a, b): da = pytest.importorskip("dask.array") x = da.arange(10, chunks=(5,)).persist() @@ -4751,24 +4742,24 @@ def f(x): return x.sum().compute() future = c.submit(f, x) - result = yield future + result = await future assert result == sum(range(10)) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 1, timeout=100) -def test_secede_simple(c, s, a): +async def test_secede_simple(c, s, a): def f(): client = get_client() secede() return client.submit(inc, 1).result() - result = yield c.submit(f) + result = await c.submit(f) assert result == 2 @pytest.mark.slow @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, timeout=60) -def test_secede_balances(c, s, a, b): +async def test_secede_balances(c, s, a, b): count = threading.active_count() def f(x): @@ -4782,24 +4773,24 @@ def f(x): futures = c.map(f, range(100)) start = time() while not all(f.status == "finished" for f in futures): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert threading.active_count() < count + 50 assert len(a.log) < 2 * len(b.log) assert len(b.log) < 2 * len(a.log) - results = yield c.gather(futures) + results = await c.gather(futures) assert results == [sum(map(inc, range(10)))] * 100 @gen_cluster(client=True) -def test_sub_submit_priority(c, s, a, b): +async def test_sub_submit_priority(c, s, a, b): def f(): client = get_client() client.submit(slowinc, 1, delay=0.2, key="slowinc") future = c.submit(f, key="f") - yield gen.sleep(0.1) + await asyncio.sleep(0.1) if len(s.tasks) == 2: assert ( s.priorities["f"] > s.priorities["slowinc"] @@ -4815,17 +4806,17 @@ def test_get_client_sync(c, s, a, b): @gen_cluster(client=True) -def test_serialize_collections_of_futures(c, s, a, b): +async def test_serialize_collections_of_futures(c, s, a, b): pd = 
pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") from dask.dataframe.utils import assert_eq df = pd.DataFrame({"x": [1, 2, 3]}) ddf = dd.from_pandas(df, npartitions=2).persist() - future = yield c.scatter(ddf) + future = await c.scatter(ddf) - ddf2 = yield future - df2 = yield c.compute(ddf2) + ddf2 = await future + df2 = await c.compute(ddf2) assert_eq(df, df2) @@ -4877,10 +4868,10 @@ def test_dynamic_workloads_sync_random(c): @gen_cluster(client=True) -def test_bytes_keys(c, s, a, b): +async def test_bytes_keys(c, s, a, b): key = b"inc-123" future = c.submit(inc, 1, key=key) - result = yield future + result = await future assert type(future.key) is bytes assert set(s.tasks) == {key} assert key in a.data or key in b.data @@ -4888,11 +4879,11 @@ def test_bytes_keys(c, s, a, b): @gen_cluster(client=True) -def test_unicode_ascii_keys(c, s, a, b): +async def test_unicode_ascii_keys(c, s, a, b): uni_type = type("") key = "inc-123" future = c.submit(inc, 1, key=key) - result = yield future + result = await future assert type(future.key) is uni_type assert set(s.tasks) == {key} assert key in a.data or key in b.data @@ -4900,32 +4891,31 @@ def test_unicode_ascii_keys(c, s, a, b): @gen_cluster(client=True) -def test_unicode_keys(c, s, a, b): +async def test_unicode_keys(c, s, a, b): uni_type = type("") key = "inc-123\u03bc" future = c.submit(inc, 1, key=key) - result = yield future + result = await future assert type(future.key) is uni_type assert set(s.tasks) == {key} assert key in a.data or key in b.data assert result == 2 future2 = c.submit(inc, future) - result2 = yield future2 + result2 = await future2 assert result2 == 3 - future3 = yield c.scatter({"data-123": 123}) - result3 = yield future3["data-123"] + future3 = await c.scatter({"data-123": 123}) + result3 = await future3["data-123"] assert result3 == 123 def test_use_synchronous_client_in_async_context(loop, c): - @gen.coroutine - def f(): - x = yield c.scatter(123) + async def f(): + x = await c.scatter(123) y = c.submit(inc, x) - z = yield c.gather(y) - raise gen.Return(z) + z = await c.gather(y) + return z z = sync(loop, f) assert z == 124 @@ -4956,11 +4946,13 @@ def test_warn_executor(loop, s, a, b): @gen_cluster([("127.0.0.1", 4)] * 2, client=True) -def test_call_stack_future(c, s, a, b): +async def test_call_stack_future(c, s, a, b): x = c.submit(slowdec, 1, delay=0.5) future = c.submit(slowinc, 1, delay=0.5) - yield gen.sleep(0.1) - results = yield [c.call_stack(future), c.call_stack(keys=[future.key])] + await asyncio.sleep(0.1) + results = await asyncio.gather( + c.call_stack(future), c.call_stack(keys=[future.key]) + ) assert all(list(first(result.values())) == [future.key] for result in results) assert results[0] == results[1] result = results[0] @@ -4972,11 +4964,11 @@ def test_call_stack_future(c, s, a, b): @gen_cluster([("127.0.0.1", 4)] * 2, client=True) -def test_call_stack_all(c, s, a, b): +async def test_call_stack_all(c, s, a, b): future = c.submit(slowinc, 1, delay=0.8) while not a.executing and not b.executing: - yield gen.sleep(0.01) - result = yield c.call_stack() + await asyncio.sleep(0.01) + result = await c.call_stack() w = a if a.executing else b assert list(result) == [w.address] assert list(result[w.address]) == [future.key] @@ -4984,100 +4976,100 @@ def test_call_stack_all(c, s, a, b): @gen_cluster([("127.0.0.1", 4)] * 2, client=True) -def test_call_stack_collections(c, s, a, b): +async def test_call_stack_collections(c, s, a, b): da = pytest.importorskip("dask.array") x = 
da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5).persist() while not a.executing and not b.executing: - yield gen.sleep(0.001) - result = yield c.call_stack(x) + await asyncio.sleep(0.001) + result = await c.call_stack(x) assert result @gen_cluster([("127.0.0.1", 4)] * 2, client=True) -def test_call_stack_collections_all(c, s, a, b): +async def test_call_stack_collections_all(c, s, a, b): da = pytest.importorskip("dask.array") x = da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5).persist() while not a.executing and not b.executing: - yield gen.sleep(0.001) - result = yield c.call_stack() + await asyncio.sleep(0.001) + result = await c.call_stack() assert result @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) -def test_profile(c, s, a, b): +async def test_profile(c, s, a, b): futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) - yield wait(futures) + await wait(futures) - x = yield c.profile(start=time() + 10, stop=time() + 20) + x = await c.profile(start=time() + 10, stop=time() + 20) assert not x["count"] - x = yield c.profile(start=0, stop=time()) + x = await c.profile(start=0, stop=time()) assert ( x["count"] == sum(p["count"] for _, p in a.profile_history) + a.profile_recent["count"] ) - y = yield c.profile(start=time() - 0.300, stop=time()) + y = await c.profile(start=time() - 0.300, stop=time()) assert 0 < y["count"] < x["count"] assert not any(p["count"] for _, p in b.profile_history) - result = yield c.profile(workers=b.address) + result = await c.profile(workers=b.address) assert not result["count"] @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) -def test_profile_keys(c, s, a, b): +async def test_profile_keys(c, s, a, b): x = c.map(slowinc, range(10), delay=0.05, workers=a.address) y = c.map(slowdec, range(10), delay=0.05, workers=a.address) - yield wait(x + y) + await wait(x + y) - xp = yield c.profile("slowinc") - yp = yield c.profile("slowdec") - p = yield c.profile() + xp = await c.profile("slowinc") + yp = await c.profile("slowdec") + p = await c.profile() assert p["count"] == xp["count"] + yp["count"] with captured_logger(logging.getLogger("distributed")) as logger: - prof = yield c.profile("does-not-exist") + prof = await c.profile("does-not-exist") assert prof == profile.create() out = logger.getvalue() assert not out @gen_cluster() -def test_client_with_name(s, a, b): +async def test_client_with_name(s, a, b): with captured_logger("distributed.scheduler") as sio: - client = yield Client(s.address, asynchronous=True, name="foo") + client = await Client(s.address, asynchronous=True, name="foo") assert "foo" in client.id - yield client.close() + await client.close() text = sio.getvalue() assert "foo" in text @gen_cluster(client=True) -def test_future_defaults_to_default_client(c, s, a, b): +async def test_future_defaults_to_default_client(c, s, a, b): x = c.submit(inc, 1) - yield wait(x) + await wait(x) future = Future(x.key) assert future.client is c @gen_cluster(client=True) -def test_future_auto_inform(c, s, a, b): +async def test_future_auto_inform(c, s, a, b): x = c.submit(inc, 1) - yield wait(x) + await wait(x) - client = yield Client(s.address, asynchronous=True) + client = await Client(s.address, asynchronous=True) future = Future(x.key, client) start = time() while future.status != "finished": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 - yield client.close() + await client.close() def test_client_async_before_loop_starts(): @@ 
-5089,7 +5081,7 @@ def test_client_async_before_loop_starts(): @pytest.mark.slow @gen_cluster(client=True, Worker=Nanny, timeout=60, nthreads=[("127.0.0.1", 3)] * 2) -def test_nested_compute(c, s, a, b): +async def test_nested_compute(c, s, a, b): def fib(x): assert get_worker().get_current_task() if x < 2: @@ -5100,71 +5092,71 @@ def fib(x): return c.compute() future = c.submit(fib, 8) - result = yield future + result = await future assert result == 21 assert len(s.transition_log) > 50 @gen_cluster(client=True) -def test_task_metadata(c, s, a, b): - yield c.set_metadata("x", 1) - result = yield c.get_metadata("x") +async def test_task_metadata(c, s, a, b): + await c.set_metadata("x", 1) + result = await c.get_metadata("x") assert result == 1 future = c.submit(inc, 1) key = future.key - yield wait(future) - yield c.set_metadata(key, 123) - result = yield c.get_metadata(key) + await wait(future) + await c.set_metadata(key, 123) + result = await c.get_metadata(key) assert result == 123 del future while key in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) with pytest.raises(KeyError): - yield c.get_metadata(key) + await c.get_metadata(key) - result = yield c.get_metadata(key, None) + result = await c.get_metadata(key, None) assert result is None - yield c.set_metadata(["x", "a"], 1) - result = yield c.get_metadata("x") + await c.set_metadata(["x", "a"], 1) + result = await c.get_metadata("x") assert result == {"a": 1} - yield c.set_metadata(["x", "b"], 2) - result = yield c.get_metadata("x") + await c.set_metadata(["x", "b"], 2) + result = await c.get_metadata("x") assert result == {"a": 1, "b": 2} - result = yield c.get_metadata(["x", "a"]) + result = await c.get_metadata(["x", "a"]) assert result == 1 - yield c.set_metadata(["x", "a", "c", "d"], 1) - result = yield c.get_metadata("x") + await c.set_metadata(["x", "a", "c", "d"], 1) + result = await c.get_metadata("x") assert result == {"a": {"c": {"d": 1}}, "b": 2} @gen_cluster(client=True, Worker=Nanny) -def test_logs(c, s, a, b): - yield wait(c.map(inc, range(5))) - logs = yield c.get_scheduler_logs(n=5) +async def test_logs(c, s, a, b): + await wait(c.map(inc, range(5))) + logs = await c.get_scheduler_logs(n=5) assert logs for _, msg in logs: assert "distributed.scheduler" in msg - w_logs = yield c.get_worker_logs(n=5) + w_logs = await c.get_worker_logs(n=5) assert set(w_logs.keys()) == {a.worker_address, b.worker_address} for log in w_logs.values(): for _, msg in log: assert "distributed.worker" in msg - n_logs = yield c.get_worker_logs(nanny=True) + n_logs = await c.get_worker_logs(nanny=True) assert set(n_logs.keys()) == {a.worker_address, b.worker_address} for log in n_logs.values(): for _, msg in log: assert "distributed.nanny" in msg - n_logs = yield c.get_worker_logs(nanny=True, workers=[a.worker_address]) + n_logs = await c.get_worker_logs(nanny=True, workers=[a.worker_address]) assert set(n_logs.keys()) == {a.worker_address} for log in n_logs.values(): for _, msg in log: @@ -5172,29 +5164,29 @@ def test_logs(c, s, a, b): @gen_cluster(client=True) -def test_avoid_delayed_finalize(c, s, a, b): +async def test_avoid_delayed_finalize(c, s, a, b): x = delayed(inc)(1) future = c.compute(x) - result = yield future + result = await future assert result == 2 assert list(s.tasks) == [future.key] == [x.key] @gen_cluster() -def test_config_scheduler_address(s, a, b): +async def test_config_scheduler_address(s, a, b): with dask.config.set({"scheduler-address": s.address}): with captured_logger("distributed.client") as sio: - c 
= yield Client(asynchronous=True) + c = await Client(asynchronous=True) assert c.scheduler.address == s.address text = sio.getvalue() assert s.address in text - yield c.close() + await c.close() @gen_cluster(client=True) -def test_warn_when_submitting_large_values(c, s, a, b): +async def test_warn_when_submitting_large_values(c, s, a, b): with warnings.catch_warnings(record=True) as record: future = c.submit(lambda x: x + 1, b"0" * 2000000) @@ -5215,34 +5207,33 @@ def test_warn_when_submitting_large_values(c, s, a, b): @gen_cluster() -def test_scatter_direct(s, a, b): - c = yield Client(s.address, asynchronous=True, heartbeat_interval=10) +async def test_scatter_direct(s, a, b): + c = await Client(s.address, asynchronous=True, heartbeat_interval=10) last = s.clients[c.id].last_seen start = time() while s.clients[c.id].last_seen == last: - yield gen.sleep(0.10) + await asyncio.sleep(0.10) assert time() < start + 5 - yield c.close() + await c.close() -@pytest.mark.skipif(sys.version_info[0] < 3, reason="cloudpickle Py27 issue") @gen_cluster(client=True) -def test_unhashable_function(c, s, a, b): +async def test_unhashable_function(c, s, a, b): d = {"a": 1} - result = yield c.submit(d.get, "a") + result = await c.submit(d.get, "a") assert result == 1 @gen_cluster() -def test_client_name(s, a, b): +async def test_client_name(s, a, b): with dask.config.set({"client-name": "hello-world"}): - c = yield Client(s.address, asynchronous=True) + c = await Client(s.address, asynchronous=True) assert any("hello-world" in name for name in list(s.clients)) - yield c.close() + await c.close() def test_client_doesnt_close_given_loop(loop, s, a, b): @@ -5253,11 +5244,11 @@ def test_client_doesnt_close_given_loop(loop, s, a, b): @gen_cluster(client=True, nthreads=[]) -def test_quiet_scheduler_loss(c, s): +async def test_quiet_scheduler_loss(c, s): c._periodic_callbacks["scheduler-info"].interval = 10 with captured_logger(logging.getLogger("distributed.client")) as logger: - yield s.close() - yield c._update_scheduler_info() + await s.close() + await c._update_scheduler_info() text = logger.getvalue() assert "BrokenPipeError" not in text @@ -5284,22 +5275,22 @@ async def test_dashboard_link_inproc(cleanup): @gen_test() -def test_client_timeout_2(): +async def test_client_timeout_2(): with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): start = time() c = Client("127.0.0.1:3755", asynchronous=True) with pytest.raises((TimeoutError, IOError)): - yield c + await c stop = time() assert c.status == "closed" - yield c.close() + await c.close() assert stop - start < 1 @gen_test() -def test_client_active_bad_port(): +async def test_client_active_bad_port(): import tornado.web import tornado.httpserver @@ -5309,8 +5300,8 @@ def test_client_active_bad_port(): with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): c = Client("127.0.0.1:8080", asynchronous=True) with pytest.raises((TimeoutError, IOError)): - yield c - yield c._close(fast=True) + await c + await c._close(fast=True) http_server.stop() @@ -5355,10 +5346,10 @@ async def test(s, a, b): @gen_cluster() -def test_de_serialization(s, a, b): +async def test_de_serialization(s, a, b): import numpy as np - c = yield Client( + c = await Client( s.address, asynchronous=True, serializers=["msgpack", "pickle"], @@ -5366,35 +5357,35 @@ def test_de_serialization(s, a, b): ) try: # Can send complex data - future = yield c.scatter(np.ones(5)) + future = await c.scatter(np.ones(5)) # But can not retrieve it with pytest.raises(TypeError): - result 
= yield future + result = await future finally: - yield c.close() + await c.close() @gen_cluster() -def test_de_serialization_none(s, a, b): +async def test_de_serialization_none(s, a, b): import numpy as np - c = yield Client(s.address, asynchronous=True, deserializers=["msgpack"]) + c = await Client(s.address, asynchronous=True, deserializers=["msgpack"]) try: # Can send complex data - future = yield c.scatter(np.ones(5)) + future = await c.scatter(np.ones(5)) # But can not retrieve it with pytest.raises(TypeError): - result = yield future + result = await future finally: - yield c.close() + await c.close() @gen_cluster() -def test_client_repr_closed(s, a, b): - c = yield Client(s.address, asynchronous=True) - yield c.close() +async def test_client_repr_closed(s, a, b): + c = await Client(s.address, asynchronous=True) + await c.close() c._repr_html_() @@ -5405,7 +5396,7 @@ def test_client_repr_closed_sync(loop): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_nested_prioritization(c, s, w): +async def test_nested_prioritization(c, s, w): x = delayed(inc)(1, dask_key_name=("a", 2)) y = delayed(inc)(2, dask_key_name=("a", 10)) @@ -5413,7 +5404,7 @@ def test_nested_prioritization(c, s, w): fx, fy = c.compute([x, y]) - yield wait([fx, fy]) + await wait([fx, fy]) assert (o[x.key] < o[y.key]) == ( s.tasks[tokey(fx.key)].priority < s.tasks[tokey(fy.key)].priority @@ -5421,18 +5412,18 @@ def test_nested_prioritization(c, s, w): @gen_cluster(client=True) -def test_scatter_error_cancel(c, s, a, b): +async def test_scatter_error_cancel(c, s, a, b): # https://github.com/dask/distributed/issues/2038 def bad_fn(x): raise Exception("lol") - x = yield c.scatter(1) + x = await c.scatter(1) y = c.submit(bad_fn, x) del x - yield wait(y) + await wait(y) assert y.status == "error" - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert y.status == "error" # not cancelled @@ -5442,14 +5433,14 @@ def test_no_threads_lingering(): @gen_cluster() -def test_direct_async(s, a, b): - c = yield Client(s.address, asynchronous=True, direct_to_workers=True) +async def test_direct_async(s, a, b): + c = await Client(s.address, asynchronous=True, direct_to_workers=True) assert c.direct_to_workers - yield c.close() + await c.close() - c = yield Client(s.address, asynchronous=True, direct_to_workers=False) + c = await Client(s.address, asynchronous=True, direct_to_workers=False) assert not c.direct_to_workers - yield c.close() + await c.close() def test_direct_sync(c): @@ -5462,9 +5453,9 @@ def f(): @gen_cluster() -def test_mixing_clients(s, a, b): - c1 = yield Client(s.address, asynchronous=True) - c2 = yield Client(s.address, asynchronous=True) +async def test_mixing_clients(s, a, b): + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) future = c1.submit(inc, 1) with pytest.raises(ValueError): @@ -5472,16 +5463,16 @@ def test_mixing_clients(s, a, b): assert not c2.futures # Don't create Futures on second Client - yield c1.close() - yield c2.close() + await c1.close() + await c2.close() @gen_cluster(client=True) -def test_tuple_keys(c, s, a, b): +async def test_tuple_keys(c, s, a, b): x = dask.delayed(inc)(1, dask_key_name=("x", 1)) y = dask.delayed(inc)(x, dask_key_name=("y", 1)) future = c.compute(y) - assert (yield future) == 3 + assert (await future) == 3 @gen_cluster(client=True) @@ -5493,34 +5484,34 @@ async def test_multiple_scatter(c, s, a, b): @gen_cluster(client=True) -def test_map_large_kwargs_in_graph(c, s, a, b): +async def 
test_map_large_kwargs_in_graph(c, s, a, b): np = pytest.importorskip("numpy") x = np.random.random(100000) futures = c.map(lambda a, b: a + b, range(100), b=x) while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(s.tasks) == 101 assert any(k.startswith("ndarray") for k in s.tasks) @gen_cluster(client=True) -def test_retry(c, s, a, b): +async def test_retry(c, s, a, b): def f(): assert dask.config.get("foo") with dask.config.set(foo=False): future = c.submit(f) with pytest.raises(AssertionError): - yield future + await future with dask.config.set(foo=True): - yield future.retry() - yield future + await future.retry() + await future @gen_cluster(client=True) -def test_retry_dependencies(c, s, a, b): +async def test_retry_dependencies(c, s, a, b): def f(): return dask.config.get("foo") @@ -5528,21 +5519,21 @@ def f(): y = c.submit(inc, x) with pytest.raises(KeyError): - yield y + await y with dask.config.set(foo=100): - yield y.retry() - result = yield y + await y.retry() + result = await y assert result == 101 - yield y.retry() - yield x.retry() - result = yield y + await y.retry() + await x.retry() + result = await y assert result == 101 @gen_cluster(client=True) -def test_released_dependencies(c, s, a, b): +async def test_released_dependencies(c, s, a, b): def f(x): return dask.config.get("foo") + 1 @@ -5551,26 +5542,26 @@ def f(x): del x with pytest.raises(KeyError): - yield y + await y with dask.config.set(foo=100): - yield y.retry() - result = yield y + await y.retry() + result = await y assert result == 101 @gen_cluster(client=True, clean_kwargs={"threads": False}) -def test_profile_bokeh(c, s, a, b): +async def test_profile_bokeh(c, s, a, b): pytest.importorskip("bokeh.plotting") from bokeh.model import Model - yield c.map(slowinc, range(10), delay=0.2) - state, figure = yield c.profile(plot=True) + await c.gather(c.map(slowinc, range(10), delay=0.2)) + state, figure = await c.profile(plot=True) assert isinstance(figure, Model) with tmpfile("html") as fn: try: - yield c.profile(filename=fn) + await c.profile(filename=fn) except PermissionError: if WINDOWS: pytest.xfail() @@ -5578,7 +5569,7 @@ def test_profile_bokeh(c, s, a, b): @gen_cluster(client=True) -def test_get_mix_futures_and_SubgraphCallable(c, s, a, b): +async def test_get_mix_futures_and_SubgraphCallable(c, s, a, b): future = c.submit(add, 1, 2) subgraph = SubgraphCallable( @@ -5587,7 +5578,7 @@ def test_get_mix_futures_and_SubgraphCallable(c, s, a, b): dsk = {"a": 1, "b": 2, "c": (subgraph, "a", "b"), "d": (subgraph, "c", "b")} future2 = c.get(dsk, "d", sync=False) - result = yield future2 + result = await future2 assert result == 11 # Nested subgraphs @@ -5603,12 +5594,12 @@ def test_get_mix_futures_and_SubgraphCallable(c, s, a, b): dsk2 = {"e": 1, "f": 2, "g": (subgraph2, "e", "f")} - result = yield c.get(dsk2, "g", sync=False) + result = await c.get(dsk2, "g", sync=False) assert result == 22 @gen_cluster(client=True) -def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b): +async def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b): dd = pytest.importorskip("dask.dataframe") import pandas as pd @@ -5618,7 +5609,7 @@ def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b): ddf["x"] = ddf["x"].astype("f8") ddf = ddf.map_partitions(lambda x: x) ddf["x"] = ddf["x"].astype("f8") - result = yield c.compute(ddf) + result = await c.compute(ddf) assert result.equals(df.astype("f8")) @@ -5631,23 +5622,23 @@ def test_direct_to_workers(s, loop): 
@gen_cluster(client=True) -def test_instances(c, s, a, b): +async def test_instances(c, s, a, b): assert list(Client._instances) == [c] assert list(Scheduler._instances) == [s] assert set(Worker._instances) == {a, b} @gen_cluster(client=True) -def test_wait_for_workers(c, s, a, b): +async def test_wait_for_workers(c, s, a, b): future = asyncio.ensure_future(c.wait_for_workers(n_workers=3)) - yield gen.sleep(0.22) # 2 chances + await asyncio.sleep(0.22) # 2 chances assert not future.done() - w = yield Worker(s.address) + w = await Worker(s.address) start = time() - yield future + await future assert time() < start + 1 - yield w.close() + await w.close() @pytest.mark.skipif(WINDOWS, reason="num_fds not supported on windows") @@ -5750,14 +5741,14 @@ async def test_profile_server(c, s, a, b): @gen_cluster(client=True) -def test_await_future(c, s, a, b): +async def test_await_future(c, s, a, b): future = c.submit(inc, 1) async def f(): # flake8: noqa result = await future assert result == 2 - yield f() + await f() future = c.submit(div, 1, 0) @@ -5765,11 +5756,11 @@ async def f(): with pytest.raises(ZeroDivisionError): await future - yield f() + await f() @gen_cluster(client=True) -def test_as_completed_async_for(c, s, a, b): +async def test_as_completed_async_for(c, s, a, b): futures = c.map(inc, range(10)) ac = as_completed(futures) results = [] @@ -5779,13 +5770,13 @@ async def f(): result = await future results.append(result) - yield f() + await f() assert set(results) == set(range(1, 11)) @gen_cluster(client=True) -def test_as_completed_async_for_results(c, s, a, b): +async def test_as_completed_async_for_results(c, s, a, b): futures = c.map(inc, range(10)) ac = as_completed(futures, with_results=True) results = [] @@ -5794,20 +5785,20 @@ async def f(): async for future, result in ac: results.append(result) - yield f() + await f() assert set(results) == set(range(1, 11)) assert not s.counters["op"].components[0]["gather"] @gen_cluster(client=True) -def test_as_completed_async_for_cancel(c, s, a, b): +async def test_as_completed_async_for_cancel(c, s, a, b): x = c.submit(inc, 1) y = c.submit(sleep, 0.3) ac = as_completed([x, y]) async def _(): - await gen.sleep(0.1) + await asyncio.sleep(0.1) await y.cancel(asynchronous=True) c.loop.add_callback(_) @@ -5818,7 +5809,7 @@ async def f(): async for future in ac: L.append(future) - yield f() + await f() assert L == [x, y] @@ -5845,7 +5836,7 @@ async def f(): def test_client_sync_with_async_def(loop): async def ff(): - await gen.sleep(0.01) + await asyncio.sleep(0.01) return 1 with cluster() as (s, [a, b]): @@ -5888,13 +5879,13 @@ async def test_dont_hold_on_to_large_messages(c, s, a, b): ) pytest.fail("array should have been destroyed") - await gen.sleep(0.200) + await asyncio.sleep(0.200) @gen_cluster(client=True) async def test_run_scheduler_async_def(c, s, a, b): async def f(dask_scheduler): - await gen.sleep(0.01) + await asyncio.sleep(0.01) dask_scheduler.foo = "bar" await c.run_on_scheduler(f) @@ -5902,7 +5893,7 @@ async def f(dask_scheduler): assert s.foo == "bar" async def f(dask_worker): - await gen.sleep(0.01) + await asyncio.sleep(0.01) dask_worker.foo = "bar" await c.run(f) @@ -5913,23 +5904,23 @@ async def f(dask_worker): @gen_cluster(client=True) async def test_run_scheduler_async_def_wait(c, s, a, b): async def f(dask_scheduler): - await gen.sleep(0.01) + await asyncio.sleep(0.01) dask_scheduler.foo = "bar" await c.run_on_scheduler(f, wait=False) while not hasattr(s, "foo"): - await gen.sleep(0.01) + await asyncio.sleep(0.01) 
assert s.foo == "bar" async def f(dask_worker): - await gen.sleep(0.01) + await asyncio.sleep(0.01) dask_worker.foo = "bar" await c.run(f, wait=False) while not hasattr(a, "foo") or not hasattr(b, "foo"): - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert a.foo == "bar" assert b.foo == "bar" @@ -5973,7 +5964,7 @@ async def test_client_gather_semaphor_loop(cleanup): @gen_cluster(client=True) -def test_as_completed_condition_loop(c, s, a, b): +async def test_as_completed_condition_loop(c, s, a, b): seq = c.map(inc, range(5)) ac = as_completed(seq) assert ac.condition._loop == c.loop.asyncio_loop diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 61424c68f38..b9af1ef0222 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -34,7 +34,7 @@ def assert_equal(a, b): @gen_cluster(timeout=240, client=True) -def test_dataframes(c, s, a, b): +async def test_dataframes(c, s, a, b): df = pd.DataFrame( {"x": np.random.random(1000), "y": np.random.random(1000)}, index=np.arange(1000), @@ -46,7 +46,7 @@ def test_dataframes(c, s, a, b): assert rdf.divisions == ldf.divisions remote = c.compute(rdf) - result = yield remote + result = await remote tm.assert_frame_equal(result, ldf.compute(scheduler="sync")) @@ -63,19 +63,19 @@ def test_dataframes(c, s, a, b): for f in exprs: local = f(ldf).compute(scheduler="sync") remote = c.compute(f(rdf)) - remote = yield remote + remote = await remote assert_equal(local, remote) @gen_cluster(client=True) -def test__dask_array_collections(c, s, a, b): +async def test_dask_array_collections(c, s, a, b): import dask.array as da s.validate = False x_dsk = {("x", i, j): np.random.random((3, 3)) for i in range(3) for j in range(2)} y_dsk = {("y", i, j): np.random.random((3, 3)) for i in range(2) for j in range(3)} - x_futures = yield c.scatter(x_dsk) - y_futures = yield c.scatter(y_dsk) + x_futures = await c.scatter(x_dsk) + y_futures = await c.scatter(y_dsk) dt = np.random.random(0).dtype x_local = da.Array(x_dsk, "x", ((3, 3, 3), (3, 3)), dt) @@ -95,13 +95,13 @@ def test__dask_array_collections(c, s, a, b): local = expr(x_local, y_local).compute(scheduler="sync") remote = c.compute(expr(x_remote, y_remote)) - remote = yield remote + remote = await remote assert np.all(local == remote) @gen_cluster(client=True) -def test_bag_groupby_tasks_default(c, s, a, b): +async def test_bag_groupby_tasks_default(c, s, a, b): b = db.range(100, npartitions=10) b2 = b.groupby(lambda x: x % 13) assert not any("partd" in k[0] for k in b2.dask) @@ -147,11 +147,11 @@ def test_rolling_sync(client): @gen_cluster(client=True) -def test_loc(c, s, a, b): +async def test_loc(c, s, a, b): df = make_time_dataframe() ddf = dd.from_pandas(df, npartitions=10) future = c.compute(ddf.loc["2000-01-17":"2000-01-24"]) - yield future + await future def test_dataframe_groupby_tasks(client): @@ -182,7 +182,7 @@ def test_dataframe_groupby_tasks(client): @gen_cluster(client=True) -def test_sparse_arrays(c, s, a, b): +async def test_sparse_arrays(c, s, a, b): sparse = pytest.importorskip("sparse") da = pytest.importorskip("dask.array") @@ -191,13 +191,13 @@ def test_sparse_arrays(c, s, a, b): s = x.map_blocks(sparse.COO) future = c.compute(s.sum(axis=0)[:10]) - yield future + await future @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_delayed_none(c, s, w): +async def test_delayed_none(c, s, w): x = dask.delayed(None) y = dask.delayed(123) [xx, yy] = c.compute([x, y]) - assert (yield xx) is None - 
assert (yield yy) == 123 + assert await xx is None + assert await yy == 123 diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 49033a6a11e..c75f9c48cf6 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -732,7 +732,7 @@ async def f(): @gen_cluster() -def test_thread_id(s, a, b): +async def test_thread_id(s, a, b): assert s.thread_id == a.thread_id == b.thread_id == threading.get_ident() diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 99b1b4a42a7..e1556494fe2 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -1,10 +1,10 @@ +import asyncio import os import random from time import sleep import pytest from tlz import partition_all, first -from tornado import gen from dask import delayed from distributed import Client, Nanny, wait @@ -35,30 +35,30 @@ def test_submit_after_failed_worker_sync(loop): @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) -def test_submit_after_failed_worker_async(c, s, a, b): - n = yield Nanny(s.address, nthreads=2, loop=s.loop) +async def test_submit_after_failed_worker_async(c, s, a, b): + n = await Nanny(s.address, nthreads=2, loop=s.loop) while len(s.workers) < 3: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) L = c.map(inc, range(10)) - yield wait(L) + await wait(L) s.loop.add_callback(n.kill) total = c.submit(sum, L) - result = yield total + result = await total assert result == sum(map(inc, range(10))) - yield n.close() + await n.close() @gen_cluster(client=True, timeout=60) -def test_submit_after_failed_worker(c, s, a, b): +async def test_submit_after_failed_worker(c, s, a, b): L = c.map(inc, range(10)) - yield wait(L) - yield a.close() + await wait(L) + await a.close() total = c.submit(sum, L) - result = yield total + result = await total assert result == sum(map(inc, range(10))) @@ -78,73 +78,73 @@ def test_gather_after_failed_worker(loop): nthreads=[("127.0.0.1", 1)] * 4, config={"distributed.comm.timeouts.connect": "1s"}, ) -def test_gather_then_submit_after_failed_workers(c, s, w, x, y, z): +async def test_gather_then_submit_after_failed_workers(c, s, w, x, y, z): L = c.map(inc, range(20)) - yield wait(L) + await wait(L) w.process.process._process.terminate() total = c.submit(sum, L) for i in range(3): - yield wait(total) + await wait(total) addr = first(s.tasks[total.key].who_has).address for worker in [x, y, z]: if worker.worker_address == addr: worker.process.process._process.terminate() break - result = yield c.gather([total]) + result = await c.gather([total]) assert result == [sum(map(inc, range(20)))] @gen_cluster(Worker=Nanny, timeout=60, client=True) -def test_failed_worker_without_warning(c, s, a, b): +async def test_failed_worker_without_warning(c, s, a, b): L = c.map(inc, range(10)) - yield wait(L) + await wait(L) original_pid = a.pid with ignoring(CommClosedError): - yield c._run(os._exit, 1, workers=[a.worker_address]) + await c._run(os._exit, 1, workers=[a.worker_address]) start = time() while a.pid == original_pid: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 10 - yield gen.sleep(0.5) + await asyncio.sleep(0.5) start = time() while len(s.nthreads) < 2: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 10 - yield wait(L) + await wait(L) L2 = c.map(inc, range(10, 20)) - yield wait(L2) + await wait(L2) assert all(len(keys) > 0 for keys in s.has_what.values()) nthreads2 = dict(s.nthreads) - yield 
c.restart() + await c.restart() L = c.map(inc, range(10)) - yield wait(L) + await wait(L) assert all(len(keys) > 0 for keys in s.has_what.values()) assert not (set(nthreads2) & set(s.nthreads)) # no overlap @gen_cluster(Worker=Nanny, client=True, timeout=60) -def test_restart(c, s, a, b): +async def test_restart(c, s, a, b): assert s.nthreads == {a.worker_address: 1, b.worker_address: 2} x = c.submit(inc, 1) y = c.submit(inc, x) z = c.submit(div, 1, 0) - yield y + await y assert set(s.who_has) == {x.key, y.key} - f = yield c.restart() + f = await c.restart() assert f is c assert len(s.workers) == 2 @@ -162,12 +162,12 @@ def test_restart(c, s, a, b): @gen_cluster(Worker=Nanny, client=True, timeout=60) -def test_restart_cleared(c, s, a, b): +async def test_restart_cleared(c, s, a, b): x = 2 * delayed(1) + 1 f = c.compute(x) - yield wait([f]) + await wait([f]) - yield c.restart() + await c.restart() for coll in [s.tasks, s.unrunnable]: assert not coll @@ -204,18 +204,18 @@ def test_restart_sync(loop): @gen_cluster(Worker=Nanny, client=True, timeout=60) -def test_restart_fast(c, s, a, b): +async def test_restart_fast(c, s, a, b): L = c.map(sleep, range(10)) start = time() - yield c.restart() + await c.restart() assert time() - start < 10 assert len(s.nthreads) == 2 assert all(x.status == "cancelled" for x in L) x = c.submit(inc, 1) - result = yield x + result = await x assert result == 2 @@ -247,51 +247,51 @@ def test_restart_fast_sync(loop): @gen_cluster(Worker=Nanny, client=True, timeout=60) -def test_fast_kill(c, s, a, b): +async def test_fast_kill(c, s, a, b): L = c.map(sleep, range(10)) start = time() - yield c.restart() + await c.restart() assert time() - start < 10 assert all(x.status == "cancelled" for x in L) x = c.submit(inc, 1) - result = yield x + result = await x assert result == 2 @gen_cluster(Worker=Nanny, timeout=60) -def test_multiple_clients_restart(s, a, b): - c1 = yield Client(s.address, asynchronous=True) - c2 = yield Client(s.address, asynchronous=True) +async def test_multiple_clients_restart(s, a, b): + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) x = c1.submit(inc, 1) y = c2.submit(inc, 2) - xx = yield x - yy = yield y + xx = await x + yy = await y assert xx == 2 assert yy == 3 - yield c1.restart() + await c1.restart() assert x.cancelled() start = time() while not y.cancelled(): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 - yield c1.close() - yield c2.close() + await c1.close() + await c2.close() @gen_cluster(Worker=Nanny, timeout=60) -def test_restart_scheduler(s, a, b): +async def test_restart_scheduler(s, a, b): import gc gc.collect() addrs = (a.worker_address, b.worker_address) - yield s.restart() + await s.restart() assert len(s.nthreads) == 2 addrs2 = (a.worker_address, b.worker_address) @@ -299,26 +299,26 @@ def test_restart_scheduler(s, a, b): @gen_cluster(Worker=Nanny, client=True, timeout=60) -def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): +async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): x = c.submit(inc, 1) - yield c.restart() + await c.restart() y = c.submit(inc, 1) del x import gc gc.collect() - yield gen.sleep(0.1) - yield y + await asyncio.sleep(0.1) + await y @gen_cluster(client=True, timeout=60, active_rpc_timeout=10) -def test_broken_worker_during_computation(c, s, a, b): +async def test_broken_worker_during_computation(c, s, a, b): s.allowed_failures = 100 - n = yield Nanny(s.address, nthreads=2, loop=s.loop) + n = await 
Nanny(s.address, nthreads=2, loop=s.loop) start = time() while len(s.nthreads) < 3: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 N = 256 @@ -333,37 +333,37 @@ def test_broken_worker_during_computation(c, s, a, b): key=["add-%d-%d" % (i, j) for j in range(len(L) // 2)] ) - yield gen.sleep(random.random() / 20) + await asyncio.sleep(random.random() / 20) with ignoring(CommClosedError): # comm will be closed abrupty - yield c._run(os._exit, 1, workers=[n.worker_address]) + await c._run(os._exit, 1, workers=[n.worker_address]) - yield gen.sleep(random.random() / 20) + await asyncio.sleep(random.random() / 20) while len(s.workers) < 3: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) with ignoring( CommClosedError, EnvironmentError ): # perhaps new worker can't be contacted yet - yield c._run(os._exit, 1, workers=[n.worker_address]) + await c._run(os._exit, 1, workers=[n.worker_address]) - [result] = yield c.gather(L) + [result] = await c.gather(L) assert isinstance(result, int) assert result == expected_result - yield n.close() + await n.close() @gen_cluster(client=True, Worker=Nanny, timeout=60) -def test_restart_during_computation(c, s, a, b): +async def test_restart_during_computation(c, s, a, b): xs = [delayed(slowinc)(i, delay=0.01) for i in range(50)] ys = [delayed(slowinc)(i, delay=0.01) for i in xs] zs = [delayed(slowadd)(x, y, delay=0.01) for x, y in zip(xs, ys)] total = delayed(sum)(zs) result = c.compute(total) - yield gen.sleep(0.5) + await asyncio.sleep(0.5) assert s.rprocessing - yield c.restart() + await c.restart() assert not s.rprocessing assert len(s.nthreads) == 2 @@ -371,59 +371,59 @@ def test_restart_during_computation(c, s, a, b): @gen_cluster(client=True, timeout=60) -def test_worker_who_has_clears_after_failed_connection(c, s, a, b): - n = yield Nanny(s.address, nthreads=2, loop=s.loop) +async def test_worker_who_has_clears_after_failed_connection(c, s, a, b): + n = await Nanny(s.address, nthreads=2, loop=s.loop) start = time() while len(s.nthreads) < 3: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 futures = c.map(slowinc, range(20), delay=0.01, key=["f%d" % i for i in range(20)]) - yield wait(futures) + await wait(futures) - result = yield c.submit(sum, futures, workers=a.address) + result = await c.submit(sum, futures, workers=a.address) for dep in set(a.dep_state) - set(a.task_state): a.release_dep(dep, report=True) n_worker_address = n.worker_address with ignoring(CommClosedError): - yield c._run(os._exit, 1, workers=[n_worker_address]) + await c._run(os._exit, 1, workers=[n_worker_address]) while len(s.workers) > 2: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) total = c.submit(sum, futures, workers=a.address) - yield total + await total assert not a.has_what.get(n_worker_address) assert not any(n_worker_address in s for s in a.who_has.values()) - yield n.close() + await n.close() @pytest.mark.slow @gen_cluster(client=True, timeout=60, Worker=Nanny, nthreads=[("127.0.0.1", 1)]) -def test_restart_timeout_on_long_running_task(c, s, a): +async def test_restart_timeout_on_long_running_task(c, s, a): with captured_logger("distributed.scheduler") as sio: future = c.submit(sleep, 3600) - yield gen.sleep(0.1) - yield c.restart(timeout=20) + await asyncio.sleep(0.1) + await c.restart(timeout=20) text = sio.getvalue() assert "timeout" not in text.lower() @gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "500ms"}) -def test_worker_time_to_live(c, s, a, b): +async def test_worker_time_to_live(c, 
s, a, b): assert set(s.workers) == {a.address, b.address} a.periodic_callbacks["heartbeat"].stop() - yield gen.sleep(0.010) + await asyncio.sleep(0.010) assert set(s.workers) == {a.address, b.address} start = time() while set(s.workers) == {a.address, b.address}: - yield gen.sleep(0.050) + await asyncio.sleep(0.050) assert time() < start + 2 set(s.workers) == {b.address} diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index 4cf756ef178..0d22fc6cee9 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -10,8 +10,8 @@ @gen_cluster(client=True, nthreads=[("127.0.0.1", 8)] * 2) -def test_lock(c, s, a, b): - yield c.set_metadata("locked", False) +async def test_lock(c, s, a, b): + await c.set_metadata("locked", False) def f(x): client = get_client() @@ -23,16 +23,16 @@ def f(x): client.set_metadata("locked", False) futures = c.map(f, range(20)) - results = yield futures + await c.gather(futures) assert not s.extensions["locks"].events assert not s.extensions["locks"].ids @gen_cluster(client=True) -def test_timeout(c, s, a, b): +async def test_timeout(c, s, a, b): locks = s.extensions["locks"] lock = Lock("x") - result = yield lock.acquire() + result = await lock.acquire() assert result is True assert locks.ids["x"] == lock.id @@ -40,35 +40,35 @@ def test_timeout(c, s, a, b): assert lock.id != lock2.id start = time() - result = yield lock2.acquire(timeout=0.1) + result = await lock2.acquire(timeout=0.1) stop = time() assert stop - start < 0.3 assert result is False assert locks.ids["x"] == lock.id assert not locks.events["x"] - yield lock.release() + await lock.release() @gen_cluster(client=True) -def test_acquires_with_zero_timeout(c, s, a, b): +async def test_acquires_with_zero_timeout(c, s, a, b): lock = Lock("x") - yield lock.acquire(timeout=0) + await lock.acquire(timeout=0) assert lock.locked() - yield lock.release() + await lock.release() - yield lock.acquire(timeout=1) - yield lock.release() - yield lock.acquire(timeout=1) - yield lock.release() + await lock.acquire(timeout=1) + await lock.release() + await lock.acquire(timeout=1) + await lock.release() @gen_cluster(client=True) -def test_acquires_blocking(c, s, a, b): +async def test_acquires_blocking(c, s, a, b): lock = Lock("x") - yield lock.acquire(blocking=False) + await lock.acquire(blocking=False) assert lock.locked() - yield lock.release() + await lock.release() assert not lock.locked() with pytest.raises(ValueError): @@ -81,10 +81,10 @@ def test_timeout_sync(client): @gen_cluster(client=True) -def test_errors(c, s, a, b): +async def test_errors(c, s, a, b): lock = Lock("x") with pytest.raises(ValueError): - yield lock.release() + await lock.release() def test_lock_sync(client): @@ -103,19 +103,19 @@ def f(x): @gen_cluster(client=True) -def test_lock_types(c, s, a, b): +async def test_lock_types(c, s, a, b): for name in [1, ("a", 1), ["a", 1], b"123", "123"]: lock = Lock(name) assert lock.name == name - yield lock.acquire() - yield lock.release() + await lock.acquire() + await lock.release() assert not s.extensions["locks"].events @gen_cluster(client=True) -def test_serializable(c, s, a, b): +async def test_serializable(c, s, a, b): def f(x, lock=None): with lock: assert lock.name == "x" @@ -123,7 +123,7 @@ def f(x, lock=None): lock = Lock("x") futures = c.map(f, range(10), lock=lock) - yield c.gather(futures) + await c.gather(futures) lock2 = pickle.loads(pickle.dumps(lock)) assert lock2.name == lock.name diff --git a/distributed/tests/test_nanny.py 
b/distributed/tests/test_nanny.py index 2a19bdf8742..2c7a6f83671 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -10,7 +10,6 @@ import pytest from tlz import valmap, first -from tornado import gen from tornado.ioloop import IOLoop import dask @@ -29,7 +28,8 @@ ) -@gen_cluster(nthreads=[]) +# FIXME why does this leave behind unclosed Comm objects? +@gen_cluster(nthreads=[], allow_unclosed=True) async def test_nanny(s): async with Nanny(s.address, nthreads=2, loop=s.loop) as n: async with rpc(n.address) as nn: @@ -60,16 +60,16 @@ async def test_nanny(s): @gen_cluster(nthreads=[]) -def test_many_kills(s): - n = yield Nanny(s.address, nthreads=2, loop=s.loop) +async def test_many_kills(s): + n = await Nanny(s.address, nthreads=2, loop=s.loop) assert n.is_alive() - yield [n.kill() for i in range(5)] - yield [n.kill() for i in range(5)] - yield n.close() + await asyncio.gather(*(n.kill() for _ in range(5))) + await asyncio.gather(*(n.kill() for _ in range(5))) + await n.close() @gen_cluster(Worker=Nanny) -def test_str(s, a, b): +async def test_str(s, a, b): assert a.worker_address in str(a) assert a.worker_address in repr(a) assert str(a.nthreads) in str(a) @@ -77,59 +77,59 @@ def test_str(s, a, b): @gen_cluster(nthreads=[], timeout=20, client=True) -def test_nanny_process_failure(c, s): - n = yield Nanny(s.address, nthreads=2, loop=s.loop) +async def test_nanny_process_failure(c, s): + n = await Nanny(s.address, nthreads=2, loop=s.loop) first_dir = n.worker_dir assert os.path.exists(first_dir) original_address = n.worker_address ww = rpc(n.worker_address) - yield ww.update_data(data=valmap(dumps, {"x": 1, "y": 2})) + await ww.update_data(data=valmap(dumps, {"x": 1, "y": 2})) pid = n.pid assert pid is not None with ignoring(CommClosedError): - yield c.run(os._exit, 0, workers=[n.worker_address]) + await c.run(os._exit, 0, workers=[n.worker_address]) start = time() while n.pid == pid: # wait while process dies and comes back - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 5 start = time() - yield gen.sleep(1) + await asyncio.sleep(1) while not n.is_alive(): # wait while process comes back - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 5 # assert n.worker_address != original_address # most likely start = time() while n.worker_address not in s.nthreads or n.worker_dir is None: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 5 second_dir = n.worker_dir - yield n.close() + await n.close() assert not os.path.exists(second_dir) assert not os.path.exists(first_dir) assert first_dir != n.worker_dir - yield ww.close_rpc() + await ww.close_rpc() s.stop() @gen_cluster(nthreads=[]) -def test_run(s): +async def test_run(s): pytest.importorskip("psutil") - n = yield Nanny(s.address, nthreads=2, loop=s.loop) + n = await Nanny(s.address, nthreads=2, loop=s.loop) with rpc(n.address) as nn: - response = yield nn.run(function=dumps(lambda: 1)) + response = await nn.run(function=dumps(lambda: 1)) assert response["status"] == "OK" assert response["result"] == 1 - yield n.close() + await n.close() @pytest.mark.slow @@ -150,12 +150,12 @@ async def test_no_hang_when_scheduler_closes(s, a, b): @gen_cluster( Worker=Nanny, nthreads=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False} ) -def test_close_on_disconnect(s, w): - yield s.close() +async def test_close_on_disconnect(s, w): + await s.close() start = time() while w.status != "closed": - yield gen.sleep(0.05) + await asyncio.sleep(0.05) assert 
time() < start + 9 @@ -165,70 +165,69 @@ class Something(Worker): @gen_cluster(client=True, Worker=Nanny) -def test_nanny_worker_class(c, s, w1, w2): - out = yield c._run(lambda dask_worker=None: str(dask_worker.__class__)) +async def test_nanny_worker_class(c, s, w1, w2): + out = await c._run(lambda dask_worker=None: str(dask_worker.__class__)) assert "Worker" in list(out.values())[0] assert w1.Worker is Worker @gen_cluster(client=True, Worker=Nanny, worker_kwargs={"worker_class": Something}) -def test_nanny_alt_worker_class(c, s, w1, w2): - out = yield c._run(lambda dask_worker=None: str(dask_worker.__class__)) +async def test_nanny_alt_worker_class(c, s, w1, w2): + out = await c._run(lambda dask_worker=None: str(dask_worker.__class__)) assert "Something" in list(out.values())[0] assert w1.Worker is Something @pytest.mark.slow @gen_cluster(client=False, nthreads=[]) -def test_nanny_death_timeout(s): - yield s.close() +async def test_nanny_death_timeout(s): + await s.close() w = Nanny(s.address, death_timeout=1) with pytest.raises(TimeoutError): - yield w + await w assert w.status == "closed" @gen_cluster(client=True, Worker=Nanny) -def test_random_seed(c, s, a, b): - @gen.coroutine - def check_func(func): +async def test_random_seed(c, s, a, b): + async def check_func(func): x = c.submit(func, 0, 2 ** 31, pure=False, workers=a.worker_address) y = c.submit(func, 0, 2 ** 31, pure=False, workers=b.worker_address) assert x.key != y.key - x = yield x - y = yield y + x = await x + y = await y assert x != y - yield check_func(lambda a, b: random.randint(a, b)) - yield check_func(lambda a, b: np.random.randint(a, b)) + await check_func(lambda a, b: random.randint(a, b)) + await check_func(lambda a, b: np.random.randint(a, b)) @pytest.mark.skipif( sys.platform.startswith("win"), reason="num_fds not supported on windows" ) @gen_cluster(client=False, nthreads=[]) -def test_num_fds(s): +async def test_num_fds(s): psutil = pytest.importorskip("psutil") proc = psutil.Process() # Warm up - w = yield Nanny(s.address) - yield w.close() + w = await Nanny(s.address) + await w.close() del w gc.collect() before = proc.num_fds() for i in range(3): - w = yield Nanny(s.address) - yield gen.sleep(0.1) - yield w.close() + w = await Nanny(s.address) + await asyncio.sleep(0.1) + await w.close() start = time() while proc.num_fds() > before: print("fds:", before, proc.num_fds()) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 10 @@ -236,42 +235,42 @@ def test_num_fds(s): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster(client=True, nthreads=[]) -def test_worker_uses_same_host_as_nanny(c, s): +async def test_worker_uses_same_host_as_nanny(c, s): for host in ["tcp://0.0.0.0", "tcp://127.0.0.2"]: - n = yield Nanny(s.address, host=host) + n = await Nanny(s.address, host=host) def func(dask_worker): return dask_worker.listener.listen_address - result = yield c.run(func) + result = await c.run(func) assert host in first(result.values()) - yield n.close() + await n.close() @gen_test() -def test_scheduler_file(): +async def test_scheduler_file(): with tmpfile() as fn: - s = yield Scheduler(scheduler_file=fn, port=8008) - w = yield Nanny(scheduler_file=fn) + s = await Scheduler(scheduler_file=fn, port=8008) + w = await Nanny(scheduler_file=fn) assert set(s.workers) == {w.worker_address} - yield w.close() + await w.close() s.stop() @gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)]) -def test_nanny_timeout(c, s, a): - x = yield 
c.scatter(123) +async def test_nanny_timeout(c, s, a): + x = await c.scatter(123) with captured_logger( logging.getLogger("distributed.nanny"), level=logging.ERROR ) as logger: - response = yield a.restart(timeout=0.1) + response = await a.restart(timeout=0.1) out = logger.getvalue() assert "timed out" in out.lower() start = time() while x.status != "cancelled": - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 7 @@ -283,7 +282,7 @@ def test_nanny_timeout(c, s, a): timeout=20, clean_kwargs={"threads": False}, ) -def test_nanny_terminate(c, s, a): +async def test_nanny_terminate(c, s, a): from time import sleep def leak(): @@ -297,7 +296,7 @@ def leak(): future = c.submit(leak) start = time() while a.process.pid == proc: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 10 out = logger.getvalue() assert "restart" in out.lower() @@ -339,45 +338,45 @@ def pause(dask_worker): @gen_cluster(nthreads=[], client=True) -def test_avoid_memory_monitor_if_zero_limit(c, s): - nanny = yield Nanny(s.address, loop=s.loop, memory_limit=0) - typ = yield c.run(lambda dask_worker: type(dask_worker.data)) +async def test_avoid_memory_monitor_if_zero_limit(c, s): + nanny = await Nanny(s.address, loop=s.loop, memory_limit=0) + typ = await c.run(lambda dask_worker: type(dask_worker.data)) assert typ == {nanny.worker_address: dict} - pcs = yield c.run(lambda dask_worker: list(dask_worker.periodic_callbacks)) + pcs = await c.run(lambda dask_worker: list(dask_worker.periodic_callbacks)) assert "memory" not in pcs assert "memory" not in nanny.periodic_callbacks future = c.submit(inc, 1) - assert (yield future) == 2 - yield gen.sleep(0.02) + assert await future == 2 + await asyncio.sleep(0.02) - yield c.submit(inc, 2) # worker doesn't pause + await c.submit(inc, 2) # worker doesn't pause - yield nanny.close() + await nanny.close() @gen_cluster(nthreads=[], client=True) -def test_scheduler_address_config(c, s): +async def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): - nanny = yield Nanny(loop=s.loop) + nanny = await Nanny(loop=s.loop) assert nanny.scheduler.address == s.address start = time() while not s.workers: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 10 - yield nanny.close() + await nanny.close() @pytest.mark.slow @gen_test(timeout=20) -def test_wait_for_scheduler(): +async def test_wait_for_scheduler(): with captured_logger("distributed") as log: w = Nanny("127.0.0.1:44737") IOLoop.current().add_callback(w.start) - yield gen.sleep(6) - yield w.close() + await asyncio.sleep(6) + await w.close() log = log.getvalue() assert "error" not in log.lower(), log @@ -385,31 +384,31 @@ def test_wait_for_scheduler(): @gen_cluster(nthreads=[], client=True) -def test_environment_variable(c, s): +async def test_environment_variable(c, s): a = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "123"}) b = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "456"}) - yield [a, b] - results = yield c.run(lambda: os.environ["FOO"]) + await asyncio.gather(a, b) + results = await c.run(lambda: os.environ["FOO"]) assert results == {a.worker_address: "123", b.worker_address: "456"} - yield [a.close(), b.close()] + await asyncio.gather(a.close(), b.close()) @gen_cluster(nthreads=[], client=True) -def test_data_types(c, s): - w = yield Nanny(s.address, data=dict) - r = yield c.run(lambda dask_worker: type(dask_worker.data)) +async def test_data_types(c, s): + w = await Nanny(s.address, data=dict) + r 
= await c.run(lambda dask_worker: type(dask_worker.data)) assert r[w.worker_address] == dict - yield w.close() + await w.close() @gen_cluster(nthreads=[]) -def test_local_directory(s): +async def test_local_directory(s): with tmpfile() as fn: with dask.config.set(temporary_directory=fn): - w = yield Nanny(s.address) + w = await Nanny(s.address) assert w.local_directory.startswith(fn) assert "dask-worker-space" in w.local_directory - yield w.close() + await w.close() def _noop(x): @@ -423,13 +422,13 @@ def _noop(x): Worker=Nanny, config={"distributed.worker.daemon": False}, ) -def test_mp_process_worker_no_daemon(c, s, a): +async def test_mp_process_worker_no_daemon(c, s, a): def multiprocessing_worker(): p = mp.Process(target=_noop, args=(None,)) p.start() p.join() - yield c.submit(multiprocessing_worker) + await c.submit(multiprocessing_worker) @gen_cluster( @@ -438,12 +437,12 @@ def multiprocessing_worker(): Worker=Nanny, config={"distributed.worker.daemon": False}, ) -def test_mp_pool_worker_no_daemon(c, s, a): +async def test_mp_pool_worker_no_daemon(c, s, a): def pool_worker(world_size): with mp.Pool(processes=world_size) as p: p.map(_noop, range(world_size)) - yield c.submit(pool_worker, 4) + await c.submit(pool_worker, 4) @pytest.mark.asyncio @@ -490,7 +489,7 @@ async def test_nanny_closes_cleanly(cleanup): IOLoop.current().add_callback(w.terminate) start = time() while n.status != "closed": - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 assert n.status == "closed" diff --git a/distributed/tests/test_priorities.py b/distributed/tests/test_priorities.py index ae96517f1ac..cd4344da840 100644 --- a/distributed/tests/test_priorities.py +++ b/distributed/tests/test_priorities.py @@ -1,5 +1,6 @@ +import asyncio + import pytest -from tornado import gen from dask.core import flatten import dask @@ -66,29 +67,29 @@ async def test_persist(c, s): @gen_cluster(client=True) -def test_expand_compute(c, s, a, b): +async def test_expand_compute(c, s, a, b): low = delayed(inc)(1) many = [delayed(slowinc)(i, delay=0.1) for i in range(10)] high = delayed(inc)(2) low, many, high = c.compute([low, many, high], priority={low: -1, high: 1}) - yield wait(high) + await wait(high) assert s.tasks[low.key].state == "processing" @gen_cluster(client=True) -def test_expand_persist(c, s, a, b): +async def test_expand_persist(c, s, a, b): low = delayed(inc)(1, dask_key_name="low") many = [delayed(slowinc)(i, delay=0.1) for i in range(4)] high = delayed(inc)(2, dask_key_name="high") low, high, x, y, z, w = persist(low, high, *many, priority={low: -1, high: 1}) - yield wait(high) + await wait(high) assert s.tasks[low.key].state == "processing" @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_repeated_persists_same_priority(c, s, w): +async def test_repeated_persists_same_priority(c, s, w): xs = [delayed(slowinc)(i, delay=0.05, dask_key_name="x-%d" % i) for i in range(10)] ys = [ delayed(slowinc)(x, delay=0.05, dask_key_name="y-%d" % i) @@ -105,19 +106,19 @@ def test_repeated_persists_same_priority(c, s, w): while ( sum(t.state == "memory" for t in s.tasks.values()) < 5 ): # TODO: reduce this number - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert any(s.tasks[y.key].state == "memory" for y in ys) assert any(s.tasks[z.key].state == "memory" for z in zs) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_last_in_first_out(c, s, w): +async def test_last_in_first_out(c, s, w): xs = [c.submit(slowinc, i, delay=0.05) for i in range(5)] ys = 
[c.submit(slowinc, x, delay=0.05) for x in xs] zs = [c.submit(slowinc, y, delay=0.05) for y in ys] while len(s.tasks) < 15 or not any(s.tasks[z.key].state == "memory" for z in zs): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert not all(s.tasks[x.key].state == "memory" for x in xs) diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index dde10b11cf1..ab32d52a112 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -1,3 +1,4 @@ +import asyncio import pytest from dask import delayed @@ -7,101 +8,93 @@ from distributed.utils_test import gen_cluster, inc from distributed.utils_test import client, cluster_fixture, loop # noqa F401 from distributed.protocol import Serialized -from tornado import gen @gen_cluster(client=False) -def test_publish_simple(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_publish_simple(s, a, b): + c = Client(s.address, asynchronous=True) + f = Client(s.address, asynchronous=True) + await asyncio.gather(c, f) - data = yield c.scatter(range(3)) - out = yield c.publish_dataset(data=data) + data = await c.scatter(range(3)) + await c.publish_dataset(data=data) assert "data" in s.extensions["publish"].datasets assert isinstance(s.extensions["publish"].datasets["data"]["data"], Serialized) with pytest.raises(KeyError) as exc_info: - out = yield c.publish_dataset(data=data) + await c.publish_dataset(data=data) assert "exists" in str(exc_info.value) assert "data" in str(exc_info.value) - result = yield c.scheduler.publish_list() + result = await c.scheduler.publish_list() assert result == ("data",) - result = yield f.scheduler.publish_list() + result = await f.scheduler.publish_list() assert result == ("data",) - yield c.close() - yield f.close() + await asyncio.gather(c.close(), f.close()) @gen_cluster(client=False) -def test_publish_non_string_key(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) - - try: +async def test_publish_non_string_key(s, a, b): + async with Client(s.address, asynchronous=True) as c: for name in [("a", "b"), 9.0, 8]: - data = yield c.scatter(range(3)) - out = yield c.publish_dataset(data, name=name) + data = await c.scatter(range(3)) + await c.publish_dataset(data, name=name) assert name in s.extensions["publish"].datasets assert isinstance( s.extensions["publish"].datasets[name]["data"], Serialized ) - datasets = yield c.scheduler.publish_list() + datasets = await c.scheduler.publish_list() assert name in datasets - finally: - yield c.close() - yield f.close() - @gen_cluster(client=False) -def test_publish_roundtrip(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_publish_roundtrip(s, a, b): + c = await Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) - data = yield c.scatter([0, 1, 2]) - yield c.publish_dataset(data=data) + data = await c.scatter([0, 1, 2]) + await c.publish_dataset(data=data) assert "published-data" in s.who_wants[data[0].key] - result = yield f.get_dataset(name="data") + result = await f.get_dataset(name="data") assert len(result) == len(data) - out = yield f.gather(result) + out = await f.gather(result) assert out == [0, 1, 2] with pytest.raises(KeyError) as exc_info: - result = yield f.get_dataset(name="nonexistent") + await f.get_dataset(name="nonexistent") assert "not found" in str(exc_info.value) assert 
"nonexistent" in str(exc_info.value) - yield c.close() - yield f.close() + await c.close() + await f.close() @gen_cluster(client=True) -def test_unpublish(c, s, a, b): - data = yield c.scatter([0, 1, 2]) - yield c.publish_dataset(data=data) +async def test_unpublish(c, s, a, b): + data = await c.scatter([0, 1, 2]) + await c.publish_dataset(data=data) key = data[0].key del data - yield c.scheduler.publish_delete(name="data") + await c.scheduler.publish_delete(name="data") assert "data" not in s.extensions["publish"].datasets start = time() while key in s.who_wants: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 with pytest.raises(KeyError) as exc_info: - result = yield c.get_dataset(name="data") + await c.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @@ -113,19 +106,19 @@ def test_unpublish_sync(client): client.unpublish_dataset(name="data") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name="data") + client.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @gen_cluster(client=True) -def test_publish_multiple_datasets(c, s, a, b): +async def test_publish_multiple_datasets(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(2) - yield c.publish_dataset(x=x, y=y) - datasets = yield c.scheduler.publish_list() + await c.publish_dataset(x=x, y=y) + datasets = await c.scheduler.publish_list() assert set(datasets) == {"x", "y"} @@ -136,7 +129,7 @@ def test_unpublish_multiple_datasets_sync(client): client.unpublish_dataset(name="x") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name="x") + client.get_dataset(name="x") datasets = client.list_datasets() assert set(datasets) == {"y"} @@ -147,17 +140,17 @@ def test_unpublish_multiple_datasets_sync(client): client.unpublish_dataset(name="y") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name="y") + client.get_dataset(name="y") assert "not found" in str(exc_info.value) assert "y" in str(exc_info.value) @gen_cluster(client=False) -def test_publish_bag(s, a, b): +async def test_publish_bag(s, a, b): db = pytest.importorskip("dask.bag") - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) + c = await Client(s.address, asynchronous=True) + f = await Client(s.address, asynchronous=True) bag = db.from_sequence([0, 1, 2]) bagp = c.persist(bag) @@ -166,19 +159,19 @@ def test_publish_bag(s, a, b): keys = {f.key for f in futures_of(bagp)} assert keys == set(bag.dask) - yield c.publish_dataset(data=bagp) + await c.publish_dataset(data=bagp) # check that serialization didn't affect original bag's dask assert len(futures_of(bagp)) == 3 - result = yield f.get_dataset("data") + result = await f.get_dataset("data") assert set(result.dask.keys()) == set(bagp.dask.keys()) assert {f.key for f in result.dask.values()} == {f.key for f in bagp.dask.values()} - out = yield f.compute(result) + out = await f.compute(result) assert out == [0, 1, 2] - yield c.close() - yield f.close() + await c.close() + await f.close() def test_datasets_setitem(client): @@ -223,19 +216,16 @@ def test_datasets_iter(client): @gen_cluster(client=True) -def test_pickle_safe(c, s, a, b): - c2 = yield Client(s.address, asynchronous=True, serializers=["msgpack"]) - try: - yield c2.publish_dataset(x=[1, 2, 3]) - result = yield c2.get_dataset("x") +async def test_pickle_safe(c, s, a, b): + async with Client(s.address, asynchronous=True, 
serializers=["msgpack"]) as c2: + await c2.publish_dataset(x=[1, 2, 3]) + result = await c2.get_dataset("x") assert result == [1, 2, 3] with pytest.raises(TypeError): - yield c2.publish_dataset(y=lambda x: x) + await c2.publish_dataset(y=lambda x: x) - yield c.publish_dataset(z=lambda x: x) # this can use pickle + await c.publish_dataset(z=lambda x: x) # this can use pickle with pytest.raises(TypeError): - yield c2.get_dataset("z") - finally: - yield c2.close() + await c2.get_dataset("z") diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 639542df5ca..212d29d4802 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -2,7 +2,6 @@ from time import sleep import pytest -from tornado import gen import tlz as toolz from distributed import Pub, Sub, wait, get_worker, TimeoutError @@ -11,7 +10,7 @@ @gen_cluster(client=True, timeout=None) -def test_speed(c, s, a, b): +async def test_speed(c, s, a, b): """ This tests how quickly we can move messages back and forth @@ -45,13 +44,13 @@ def pingpong(a, b, start=False, n=1000, msg=1): y = c.submit(pingpong, "b", "a", n=100) start = time() - yield c.gather([x, y]) + await c.gather([x, y]) stop = time() # print('duration', stop - start) # I get around 3ms/roundtrip on my laptop @gen_cluster(client=True, nthreads=[]) -def test_client(c, s): +async def test_client(c, s): with pytest.raises(Exception): get_worker() sub = Sub("a") @@ -62,17 +61,17 @@ def test_client(c, s): start = time() while not set(sps.client_subscribers["a"]) == {c.id}: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 pub.put(123) - result = yield sub.__anext__() + result = await sub.__anext__() assert result == 123 @gen_cluster(client=True) -def test_client_worker(c, s, a, b): +async def test_client_worker(c, s, a, b): sub = Sub("a", client=c, worker=None) def f(x): @@ -80,11 +79,11 @@ def f(x): pub.put(x) futures = c.map(f, range(10)) - yield wait(futures) + await wait(futures) L = [] for i in range(10): - result = yield sub.get() + result = await sub.get() L.append(result) assert set(L) == set(range(10)) @@ -101,7 +100,7 @@ def f(x): or bps.publishers["a"] or len(sps.client_subscribers["a"]) != 1 ): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 del sub @@ -112,20 +111,20 @@ def f(x): or any(aps.publish_to_scheduler.values()) or any(bps.publish_to_scheduler.values()) ): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 @gen_cluster(client=True) -def test_timeouts(c, s, a, b): +async def test_timeouts(c, s, a, b): sub = Sub("a", client=c, worker=None) start = time() with pytest.raises(TimeoutError): - yield sub.get(timeout=0.1) + await sub.get(timeout=0.1) stop = time() assert stop - start < 1 with pytest.raises(TimeoutError): - yield sub.get(timeout=0.01) + await sub.get(timeout=0.01) @gen_cluster(client=True) @@ -140,13 +139,13 @@ async def test_repr(c, s, a, b): @pytest.mark.xfail(reason="out of order execution") @gen_cluster(client=True) -def test_basic(c, s, a, b): +async def test_basic(c, s, a, b): async def publish(): pub = Pub("a") i = 0 while True: - await gen.sleep(0.01) + await asyncio.sleep(0.01) pub._put(i) i += 1 @@ -157,7 +156,7 @@ def f(_): asyncio.ensure_future(c.run(publish, workers=[a.address])) tasks = [c.submit(f, i) for i in range(4)] - results = yield c.gather(tasks) + results = await c.gather(tasks) for r in results: x = r[0] diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py 
index d797433d6b4..34009602a15 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -1,8 +1,7 @@ -from time import sleep import asyncio +from time import sleep import pytest -from tornado import gen from distributed import Client, Queue, Nanny, worker_client, wait, TimeoutError from distributed.metrics import time @@ -11,47 +10,47 @@ @gen_cluster(client=True) -def test_queue(c, s, a, b): - x = yield Queue("x") - y = yield Queue("y") - xx = yield Queue("x") +async def test_queue(c, s, a, b): + x = await Queue("x") + y = await Queue("y") + xx = await Queue("x") assert x.client is c future = c.submit(inc, 1) - yield x.put(future) - yield y.put(future) - future2 = yield xx.get() + await x.put(future) + await y.put(future) + future2 = await xx.get() assert future.key == future2.key with pytest.raises(TimeoutError): - yield x.get(timeout=0.1) + await x.get(timeout=0.1) del future, future2 - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert s.tasks # future still present in y's queue - yield y.get() # burn future + await y.get() # burn future start = time() while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @gen_cluster(client=True) -def test_queue_with_data(c, s, a, b): - x = yield Queue("x") - xx = yield Queue("x") +async def test_queue_with_data(c, s, a, b): + x = await Queue("x") + xx = await Queue("x") assert x.client is c - yield x.put((1, "hello")) - data = yield xx.get() + await x.put((1, "hello")) + data = await xx.get() assert data == (1, "hello") with pytest.raises(TimeoutError): - yield x.get(timeout=0.1) + await x.get(timeout=0.1) def test_sync(client): @@ -67,35 +66,35 @@ def test_sync(client): @gen_cluster() -def test_hold_futures(s, a, b): - c1 = yield Client(s.address, asynchronous=True) +async def test_hold_futures(s, a, b): + c1 = await Client(s.address, asynchronous=True) future = c1.submit(lambda x: x + 1, 10) - q1 = yield Queue("q") - yield q1.put(future) + q1 = await Queue("q") + await q1.put(future) del q1 - yield c1.close() + await c1.close() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) - c2 = yield Client(s.address, asynchronous=True) - q2 = yield Queue("q") - future2 = yield q2.get() - result = yield future2 + c2 = await Client(s.address, asynchronous=True) + q2 = await Queue("q") + future2 = await q2.get() + result = await future2 assert result == 11 - yield c2.close() + await c2.close() @pytest.mark.skip(reason="getting same client from main thread") @gen_cluster(client=True) -def test_picklability(c, s, a, b): +async def test_picklability(c, s, a, b): q = Queue() def f(x): q.put(x + 1) - yield c.submit(f, 10) - result = yield q.get() + await c.submit(f, 10) + result = await q.get() assert result == 11 @@ -112,7 +111,7 @@ def f(x): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) -def test_race(c, s, *workers): +async def test_race(c, s, *workers): def f(i): with worker_client() as c: q = Queue("x", client=c) @@ -126,144 +125,144 @@ def f(i): return result q = Queue("x", client=c) - L = yield c.scatter(range(5)) + L = await c.scatter(range(5)) for future in L: - yield q.put(future) + await q.put(future) futures = c.map(f, range(5)) - results = yield c.gather(futures) + results = await c.gather(futures) assert all(r > 50 for r in results) assert sum(results) == 510 - qsize = yield q.qsize() + qsize = await q.qsize() assert not qsize @gen_cluster(client=True) -def test_same_futures(c, s, a, b): +async def test_same_futures(c, s, 
a, b): q = Queue("x") - future = yield c.scatter(123) + future = await c.scatter(123) for i in range(5): - yield q.put(future) + await q.put(future) assert s.wants_what["queue-x"] == {future.key} for i in range(4): - future2 = yield q.get() + future2 = await q.get() assert s.wants_what["queue-x"] == {future.key} - yield gen.sleep(0.05) + await asyncio.sleep(0.05) assert s.wants_what["queue-x"] == {future.key} - yield q.get() + await q.get() start = time() while s.wants_what["queue-x"]: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 2 @gen_cluster(client=True) -def test_get_many(c, s, a, b): - x = yield Queue("x") - xx = yield Queue("x") +async def test_get_many(c, s, a, b): + x = await Queue("x") + xx = await Queue("x") - yield x.put(1) - yield x.put(2) - yield x.put(3) + await x.put(1) + await x.put(2) + await x.put(3) - data = yield xx.get(batch=True) + data = await xx.get(batch=True) assert data == [1, 2, 3] - yield x.put(1) - yield x.put(2) - yield x.put(3) + await x.put(1) + await x.put(2) + await x.put(3) - data = yield xx.get(batch=2) + data = await xx.get(batch=2) assert data == [1, 2] with pytest.raises(TimeoutError): - data = yield asyncio.wait_for(xx.get(batch=2), 0.1) + await asyncio.wait_for(xx.get(batch=2), 0.1) @gen_cluster(client=True) -def test_Future_knows_status_immediately(c, s, a, b): - x = yield c.scatter(123) - q = yield Queue("q") - yield q.put(x) - - c2 = yield Client(s.address, asynchronous=True) - q2 = yield Queue("q", client=c2) - future = yield q2.get() +async def test_Future_knows_status_immediately(c, s, a, b): + x = await c.scatter(123) + q = await Queue("q") + await q.put(x) + + c2 = await Client(s.address, asynchronous=True) + q2 = await Queue("q", client=c2) + future = await q2.get() assert future.status == "finished" x = c.submit(div, 1, 0) - yield wait(x) - yield q.put(x) + await wait(x) + await q.put(x) - future2 = yield q2.get() + future2 = await q2.get() assert future2.status == "error" with pytest.raises(Exception): - yield future2 + await future2 start = time() while True: # we learn about the true error eventually try: - yield future2 + await future2 except ZeroDivisionError: break except Exception: assert time() < start + 5 - yield gen.sleep(0.05) + await asyncio.sleep(0.05) - yield c2.close() + await c2.close() @gen_cluster(client=True) -def test_erred_future(c, s, a, b): +async def test_erred_future(c, s, a, b): future = c.submit(div, 1, 0) - q = yield Queue() - yield q.put(future) - yield gen.sleep(0.1) - future2 = yield q.get() + q = await Queue() + await q.put(future) + await asyncio.sleep(0.1) + future2 = await q.get() with pytest.raises(ZeroDivisionError): - yield future2.result() + await future2.result() - exc = yield future2.exception() + exc = await future2.exception() assert isinstance(exc, ZeroDivisionError) @gen_cluster(client=True) -def test_close(c, s, a, b): - q = yield Queue() +async def test_close(c, s, a, b): + q = await Queue() q.close() q.close() while q.name in s.extensions["queues"].queues: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @gen_cluster(client=True) -def test_timeout(c, s, a, b): - q = yield Queue("v", maxsize=1) +async def test_timeout(c, s, a, b): + q = await Queue("v", maxsize=1) start = time() with pytest.raises(TimeoutError): - yield q.get(timeout=0.3) + await q.get(timeout=0.3) stop = time() assert 0.2 < stop - start < 2.0 - yield q.put(1) + await q.put(1) start = time() with pytest.raises(TimeoutError): - yield q.put(2, timeout=0.3) + await q.put(2, timeout=0.3) stop = 
time() assert 0.1 < stop - start < 2.0 @gen_cluster(client=True) -def test_2220(c, s, a, b): +async def test_2220(c, s, a, b): q = Queue() def put(): @@ -275,4 +274,4 @@ def get(): fut = c.submit(put) res = c.submit(get) - yield [res, fut] + await c.gather([res, fut]) diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 648a191224e..870b930fdae 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -1,8 +1,8 @@ +import asyncio from time import time from dask import delayed import pytest -from tornado import gen from distributed import Worker from distributed.client import wait @@ -13,24 +13,23 @@ @gen_cluster(client=True, nthreads=[]) -def test_resources(c, s): +async def test_resources(c, s): assert not s.worker_resources assert not s.resources a = Worker(s.address, loop=s.loop, resources={"GPU": 2}) b = Worker(s.address, loop=s.loop, resources={"GPU": 1, "DB": 1}) - - yield [a, b] + await asyncio.gather(a, b) assert s.resources == {"GPU": {a.address: 2, b.address: 1}, "DB": {b.address: 1}} assert s.worker_resources == {a.address: {"GPU": 2}, b.address: {"GPU": 1, "DB": 1}} - yield b.close() + await b.close() assert s.resources == {"GPU": {a.address: 2}, "DB": {}} assert s.worker_resources == {a.address: {"GPU": 2}} - yield a.close() + await a.close() @gen_cluster( @@ -40,25 +39,25 @@ def test_resources(c, s): ("127.0.0.1", 1, {"resources": {"A": 1, "B": 1}}), ], ) -def test_resource_submit(c, s, a, b): +async def test_resource_submit(c, s, a, b): x = c.submit(inc, 1, resources={"A": 3}) y = c.submit(inc, 2, resources={"B": 1}) z = c.submit(inc, 3, resources={"C": 2}) - yield wait(x) + await wait(x) assert x.key in a.data - yield wait(y) + await wait(y) assert y.key in b.data assert s.get_task_status(keys=[z.key]) == {z.key: "no-worker"} - d = yield Worker(s.address, loop=s.loop, resources={"C": 10}) + d = await Worker(s.address, loop=s.loop, resources={"C": 10}) - yield wait(z) + await wait(z) assert z.key in d.data - yield d.close() + await d.close() @gen_cluster( @@ -68,9 +67,9 @@ def test_resource_submit(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_submit_many_non_overlapping(c, s, a, b): +async def test_submit_many_non_overlapping(c, s, a, b): futures = [c.submit(inc, i, resources={"A": 1}) for i in range(5)] - yield wait(futures) + await wait(futures) assert len(a.data) == 5 assert len(b.data) == 0 @@ -83,12 +82,12 @@ def test_submit_many_non_overlapping(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_move(c, s, a, b): - [x] = yield c._scatter([1], workers=b.address) +async def test_move(c, s, a, b): + [x] = await c._scatter([1], workers=b.address) future = c.submit(inc, x, resources={"A": 1}) - yield wait(future) + await wait(future) assert a.data[future.key] == 2 @@ -99,14 +98,14 @@ def test_move(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_dont_work_steal(c, s, a, b): - [x] = yield c._scatter([1], workers=a.address) +async def test_dont_work_steal(c, s, a, b): + [x] = await c._scatter([1], workers=a.address) futures = [ c.submit(slowadd, x, i, resources={"A": 1}, delay=0.05) for i in range(10) ] - yield wait(futures) + await wait(futures) assert all(f.key in a.data for f in futures) @@ -117,9 +116,9 @@ def test_dont_work_steal(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_map(c, s, a, b): +async def test_map(c, s, a, b): futures = c.map(inc, range(10), resources={"B": 1}) - yield wait(futures) + await 
wait(futures) assert set(b.data) == {f.key for f in futures} assert not a.data @@ -131,13 +130,13 @@ def test_map(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_persist(c, s, a, b): +async def test_persist(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) xx, yy = c.persist([x, y], resources={x: {"A": 1}, y: {"B": 1}}) - yield wait([xx, yy]) + await wait([xx, yy]) assert x.key in a.data assert y.key in b.data @@ -150,18 +149,18 @@ def test_persist(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 11}}), ], ) -def test_compute(c, s, a, b): +async def test_compute(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) yy = c.compute(y, resources={x: {"A": 1}, y: {"B": 1}}) - yield wait(yy) + await wait(yy) assert b.data xs = [delayed(inc)(i) for i in range(10, 20)] xxs = c.compute(xs, resources={"B": 1}) - yield wait(xxs) + await wait(xxs) assert len(b.data) > 10 @@ -173,10 +172,10 @@ def test_compute(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_get(c, s, a, b): +async def test_get(c, s, a, b): dsk = {"x": (inc, 1), "y": (inc, "x")} - result = yield c.get(dsk, "y", resources={"y": {"A": 1}}, sync=False) + result = await c.get(dsk, "y", resources={"y": {"A": 1}}, sync=False) assert result == 3 @@ -187,13 +186,13 @@ def test_get(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_persist_tuple(c, s, a, b): +async def test_persist_tuple(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) xx, yy = c.persist([x, y], resources={(x, y): {"A": 1}}) - yield wait([xx, yy]) + await wait([xx, yy]) assert x.key in a.data assert y.key in a.data @@ -201,16 +200,16 @@ def test_persist_tuple(c, s, a, b): @gen_cluster(client=True) -def test_resources_str(c, s, a, b): +async def test_resources_str(c, s, a, b): pd = pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") - yield a.set_resources(MyRes=1) + await a.set_resources(MyRes=1) x = dd.from_pandas(pd.DataFrame({"A": [1, 2], "B": [3, 4]}), npartitions=1) y = x.apply(lambda row: row.sum(), axis=1, meta=(None, "int64")) yy = y.persist(resources={"MyRes": 1}) - yield wait(yy) + await wait(yy) ts_first = s.tasks[tokey(y.__dask_keys__()[0])] assert ts_first.resource_restrictions == {"MyRes": 1} @@ -225,38 +224,38 @@ def test_resources_str(c, s, a, b): ("127.0.0.1", 4, {"resources": {"A": 1}}), ], ) -def test_submit_many_non_overlapping(c, s, a, b): +async def test_submit_many_non_overlapping(c, s, a, b): futures = c.map(slowinc, range(100), resources={"A": 1}, delay=0.02) while len(a.data) + len(b.data) < 100: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(a.executing) <= 2 assert len(b.executing) <= 1 - yield wait(futures) + await wait(futures) assert a.total_resources == a.available_resources assert b.total_resources == b.available_resources @gen_cluster(client=True, nthreads=[("127.0.0.1", 4, {"resources": {"A": 2, "B": 1}})]) -def test_minimum_resource(c, s, a): +async def test_minimum_resource(c, s, a): futures = c.map(slowinc, range(30), resources={"A": 1, "B": 1}, delay=0.02) while len(a.data) < 30: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(a.executing) <= 1 - yield wait(futures) + await wait(futures) assert a.total_resources == a.available_resources @gen_cluster(client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})]) -def test_prefer_constrained(c, s, a): +async def test_prefer_constrained(c, s, a): futures = c.map(slowinc, range(1000), delay=0.1) constrained = c.map(inc, range(10), resources={"A": 1}) start = time() - 
yield wait(constrained) + await wait(constrained) end = time() assert end - start < 4 has_what = dict(s.has_what) @@ -273,27 +272,27 @@ def test_prefer_constrained(c, s, a): ("127.0.0.1", 2, {"resources": {"A": 1}}), ], ) -def test_balance_resources(c, s, a, b): +async def test_balance_resources(c, s, a, b): futures = c.map(slowinc, range(100), delay=0.1, workers=a.address) constrained = c.map(inc, range(2), resources={"A": 1}) - yield wait(constrained) + await wait(constrained) assert any(f.key in a.data for f in constrained) # share assert any(f.key in b.data for f in constrained) @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)]) -def test_set_resources(c, s, a): - yield a.set_resources(A=2) +async def test_set_resources(c, s, a): + await a.set_resources(A=2) assert a.total_resources["A"] == 2 assert a.available_resources["A"] == 2 assert s.worker_resources[a.address] == {"A": 2} future = c.submit(slowinc, 1, delay=1, resources={"A": 1}) while a.available_resources["A"] == 2: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - yield a.set_resources(A=3) + await a.set_resources(A=3) assert a.total_resources["A"] == 3 assert a.available_resources["A"] == 2 assert s.worker_resources[a.address] == {"A": 3} @@ -306,7 +305,7 @@ def test_set_resources(c, s, a): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_persist_collections(c, s, a, b): +async def test_persist_collections(c, s, a, b): da = pytest.importorskip("dask.array") x = da.arange(10, chunks=(5,)) y = x.map_blocks(lambda x: x + 1) @@ -315,7 +314,7 @@ def test_persist_collections(c, s, a, b): ww, yy = c.persist([w, y], resources={tuple(y.__dask_keys__()): {"A": 1}}) - yield wait([ww, yy]) + await wait([ww, yy]) assert all(tokey(key) in a.data for key in y.__dask_keys__()) @@ -328,14 +327,14 @@ def test_persist_collections(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_dont_optimize_out(c, s, a, b): +async def test_dont_optimize_out(c, s, a, b): da = pytest.importorskip("dask.array") x = da.arange(10, chunks=(5,)) y = x.map_blocks(lambda x: x + 1) z = y.map_blocks(lambda x: 2 * x) w = z.sum() - yield c.compute(w, resources={tuple(y.__dask_keys__()): {"A": 1}}) + await c.compute(w, resources={tuple(y.__dask_keys__()): {"A": 1}}) for key in map(tokey, y.__dask_keys__()): assert "executing" in str(a.story(key)) @@ -349,14 +348,14 @@ def test_dont_optimize_out(c, s, a, b): ("127.0.0.1", 1, {"resources": {"B": 1}}), ], ) -def test_full_collections(c, s, a, b): +async def test_full_collections(c, s, a, b): dd = pytest.importorskip("dask.dataframe") df = dd.demo.make_timeseries( freq="60s", partition_freq="1d", start="2000-01-01", end="2000-01-31" ) z = df.x + df.y # some extra nodes in the graph - yield c.compute(z, resources={tuple(z.dask): {"A": 1}}) + await c.compute(z, resources={tuple(z.dask): {"A": 1}}) assert a.log assert not b.log diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index d64c88fceea..5ed8e4e542d 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1,18 +1,17 @@ import asyncio -import cloudpickle -import pickle -from collections import defaultdict import json +import logging +import pickle import operator import re import sys +from collections import defaultdict from time import sleep -import logging +import cloudpickle import dask from dask import delayed from tlz import merge, concat, valmap, first, frequencies -from tornado import gen import pytest @@ -50,7 +49,7 @@ @gen_cluster() -def 
test_administration(s, a, b): +async def test_administration(s, a, b): assert isinstance(s.address, str) assert s.address in str(s) assert str(sum(s.nthreads.values())) in repr(s) @@ -58,11 +57,11 @@ def test_administration(s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_respect_data_in_memory(c, s, a): +async def test_respect_data_in_memory(c, s, a): x = delayed(inc)(1) y = delayed(inc)(x) f = c.persist(y) - yield wait([f]) + await wait([f]) assert s.tasks[y.key].who_has == {s.workers[a.address]} @@ -70,37 +69,37 @@ def test_respect_data_in_memory(c, s, a): f2 = c.persist(z) while f2.key not in s.tasks or not s.tasks[f2.key]: assert s.tasks[y.key].who_has - yield gen.sleep(0.0001) + await asyncio.sleep(0.0001) @gen_cluster(client=True) -def test_recompute_released_results(c, s, a, b): +async def test_recompute_released_results(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) yy = c.persist(y) - yield wait(yy) + await wait(yy) while s.tasks[x.key].who_has or x.key in a.data or x.key in b.data: # let x go away - yield gen.sleep(0.01) + await asyncio.sleep(0.01) z = delayed(dec)(x) zz = c.compute(z) - result = yield zz + result = await zz assert result == 1 @gen_cluster(client=True) -def test_decide_worker_with_many_independent_leaves(c, s, a, b): - xs = yield [ +async def test_decide_worker_with_many_independent_leaves(c, s, a, b): + xs = await asyncio.gather( c.scatter(list(range(0, 100, 2)), workers=a.address), c.scatter(list(range(1, 100, 2)), workers=b.address), - ] + ) xs = list(concat(zip(*xs))) ys = [delayed(inc)(x) for x in xs] y2s = c.persist(ys) - yield wait(y2s) + await wait(y2s) nhits = sum(y.key in a.data for y in y2s[::2]) + sum( y.key in b.data for y in y2s[1::2] @@ -110,71 +109,70 @@ def test_decide_worker_with_many_independent_leaves(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_decide_worker_with_restrictions(client, s, a, b, c): +async def test_decide_worker_with_restrictions(client, s, a, b, c): x = client.submit(inc, 1, workers=[a.address, b.address]) - yield x + await x assert x.key in a.data or x.key in b.data @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_move_data_over_break_restrictions(client, s, a, b, c): - [x] = yield client.scatter([1], workers=b.address) +async def test_move_data_over_break_restrictions(client, s, a, b, c): + [x] = await client.scatter([1], workers=b.address) y = client.submit(inc, x, workers=[a.address, b.address]) - yield wait(y) + await wait(y) assert y.key in a.data or y.key in b.data @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_balance_with_restrictions(client, s, a, b, c): - [x], [y] = yield [ +async def test_balance_with_restrictions(client, s, a, b, c): + [x], [y] = await asyncio.gather( client.scatter([[1, 2, 3]], workers=a.address), client.scatter([1], workers=c.address), - ] + ) z = client.submit(inc, 1, workers=[a.address, c.address]) - yield wait(z) + await wait(z) assert s.tasks[z.key].who_has == {s.workers[c.address]} @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_no_valid_workers(client, s, a, b, c): +async def test_no_valid_workers(client, s, a, b, c): x = client.submit(inc, 1, workers="127.0.0.5:9999") while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.tasks[x.key] in s.unrunnable with pytest.raises(TimeoutError): - yield asyncio.wait_for(x, 0.05) + await asyncio.wait_for(x, 0.05) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def 
test_no_valid_workers_loose_restrictions(client, s, a, b, c): +async def test_no_valid_workers_loose_restrictions(client, s, a, b, c): x = client.submit(inc, 1, workers="127.0.0.5:9999", allow_other_workers=True) - - result = yield x + result = await x assert result == 2 @gen_cluster(client=True, nthreads=[]) -def test_no_workers(client, s): +async def test_no_workers(client, s): x = client.submit(inc, 1) while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.tasks[x.key] in s.unrunnable with pytest.raises(TimeoutError): - yield asyncio.wait_for(x, 0.05) + await asyncio.wait_for(x, 0.05) @gen_cluster(nthreads=[]) -def test_retire_workers_empty(s): - yield s.retire_workers(workers=[]) +async def test_retire_workers_empty(s): + await s.retire_workers(workers=[]) @gen_cluster() -def test_remove_client(s, a, b): +async def test_remove_client(s, a, b): s.update_graph( tasks={"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, dependencies={"x": [], "y": ["x"]}, @@ -192,15 +190,15 @@ def test_remove_client(s, a, b): @gen_cluster() -def test_server_listens_to_other_ops(s, a, b): +async def test_server_listens_to_other_ops(s, a, b): with rpc(s.address) as r: - ident = yield r.identity() + ident = await r.identity() assert ident["type"] == "Scheduler" assert ident["id"].lower().startswith("scheduler") @gen_cluster() -def test_remove_worker_from_scheduler(s, a, b): +async def test_remove_worker_from_scheduler(s, a, b): dsk = {("x-%d" % i): (inc, i) for i in range(20)} s.update_graph( tasks=valmap(dumps_task, dsk), @@ -216,7 +214,7 @@ def test_remove_worker_from_scheduler(s, a, b): @gen_cluster() -def test_remove_worker_by_name_from_scheduler(s, a, b): +async def test_remove_worker_by_name_from_scheduler(s, a, b): assert a.address in s.stream_comms assert s.remove_worker(address=a.name) == "OK" assert a.address not in s.nthreads @@ -225,7 +223,7 @@ def test_remove_worker_by_name_from_scheduler(s, a, b): @gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "10 ms"}) -def test_clear_events_worker_removal(s, a, b): +async def test_clear_events_worker_removal(s, a, b): assert a.address in s.events assert a.address in s.nthreads assert b.address in s.events @@ -239,7 +237,7 @@ def test_clear_events_worker_removal(s, a, b): start = time() while a.address in s.events: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 assert b.address in s.events @@ -247,7 +245,7 @@ def test_clear_events_worker_removal(s, a, b): @gen_cluster( config={"distributed.scheduler.events-cleanup-delay": "10 ms"}, client=True ) -def test_clear_events_client_removal(c, s, a, b): +async def test_clear_events_client_removal(c, s, a, b): assert c.id in s.events s.remove_client(c.id) @@ -259,12 +257,12 @@ def test_clear_events_client_removal(c, s, a, b): # If it doesn't reconnect after a given time, the events log should be cleared start = time() while c.id in s.events: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 @gen_cluster() -def test_add_worker(s, a, b): +async def test_add_worker(s, a, b): w = Worker(s.address, nthreads=3) w.data["x-5"] = 6 w.data["y"] = 1 @@ -277,23 +275,23 @@ def test_add_worker(s, a, b): dependencies={k: set() for k in dsk}, ) s.validate_state() - yield w + await w s.validate_state() assert w.ip in s.host_info assert s.host_info[w.ip]["addresses"] == {a.address, b.address, w.address} - yield w.close() + await w.close() @gen_cluster(scheduler_kwargs={"blocked_handlers": ["feed"]}) -def 
test_blocked_handlers_are_respected(s, a, b): +async def test_blocked_handlers_are_respected(s, a, b): def func(scheduler): return dumps(dict(scheduler.worker_info)) - comm = yield connect(s.address) - yield comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) + comm = await connect(s.address) + await comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) - response = yield comm.read() + response = await comm.read() assert "exception" in response assert isinstance(response["exception"], ValueError) @@ -301,7 +299,7 @@ def func(scheduler): response["exception"] ) - yield comm.close() + await comm.close() def test_scheduler_init_pulls_blocked_handlers_from_config(): @@ -311,23 +309,23 @@ def test_scheduler_init_pulls_blocked_handlers_from_config(): @gen_cluster() -def test_feed(s, a, b): +async def test_feed(s, a, b): def func(scheduler): return dumps(dict(scheduler.worker_info)) - comm = yield connect(s.address) - yield comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) + comm = await connect(s.address) + await comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) for i in range(5): - response = yield comm.read() + response = await comm.read() expected = dict(s.worker_info) assert cloudpickle.loads(response) == expected - yield comm.close() + await comm.close() @gen_cluster() -def test_feed_setup_teardown(s, a, b): +async def test_feed_setup_teardown(s, a, b): def setup(scheduler): return 1 @@ -338,8 +336,8 @@ def func(scheduler, state): def teardown(scheduler, state): scheduler.flag = "done" - comm = yield connect(s.address) - yield comm.write( + comm = await connect(s.address) + await comm.write( { "op": "feed", "function": dumps(func), @@ -350,18 +348,18 @@ def teardown(scheduler, state): ) for i in range(5): - response = yield comm.read() + response = await comm.read() assert response == "OK" - yield comm.close() + await comm.close() start = time() while not hasattr(s, "flag"): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 5 @gen_cluster() -def test_feed_large_bytestring(s, a, b): +async def test_feed_large_bytestring(s, a, b): np = pytest.importorskip("numpy") x = np.ones(10000000) @@ -370,19 +368,19 @@ def func(scheduler): y = x return True - comm = yield connect(s.address) - yield comm.write({"op": "feed", "function": dumps(func), "interval": 0.05}) + comm = await connect(s.address) + await comm.write({"op": "feed", "function": dumps(func), "interval": 0.05}) for i in range(5): - response = yield comm.read() + response = await comm.read() assert response is True - yield comm.close() + await comm.close() @gen_cluster(client=True) -def test_delete_data(c, s, a, b): - d = yield c.scatter({"x": 1, "y": 2, "z": 3}) +async def test_delete_data(c, s, a, b): + d = await c.scatter({"x": 1, "y": 2, "z": 3}) assert {ts.key for ts in s.tasks.values() if ts.who_has} == {"x", "y", "z"} assert set(a.data) | set(b.data) == {"x", "y", "z"} @@ -393,36 +391,36 @@ def test_delete_data(c, s, a, b): start = time() while set(a.data) | set(b.data) != {"z"}: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_delete(c, s, a): +async def test_delete(c, s, a): x = c.submit(inc, 1) - yield x + await x assert x.key in a.data - yield c._cancel(x) + await c._cancel(x) start = time() while x.key in a.data: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @gen_cluster() -def 
test_filtered_communication(s, a, b): - c = yield connect(s.address) - f = yield connect(s.address) - yield c.write({"op": "register-client", "client": "c", "versions": {}}) - yield f.write({"op": "register-client", "client": "f", "versions": {}}) - yield c.read() - yield f.read() +async def test_filtered_communication(s, a, b): + c = await connect(s.address) + f = await connect(s.address) + await c.write({"op": "register-client", "client": "c", "versions": {}}) + await f.write({"op": "register-client", "client": "f", "versions": {}}) + await c.read() + await f.read() assert set(s.client_comms) == {"c", "f"} - yield c.write( + await c.write( { "op": "update-graph", "tasks": {"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, @@ -432,7 +430,7 @@ def test_filtered_communication(s, a, b): } ) - yield f.write( + await f.write( { "op": "update-graph", "tasks": { @@ -444,10 +442,10 @@ def test_filtered_communication(s, a, b): "keys": ["z"], } ) - (msg,) = yield c.read() + (msg,) = await c.read() assert msg["op"] == "key-in-memory" assert msg["key"] == "y" - (msg,) = yield f.read() + (msg,) = await f.read() assert msg["op"] == "key-in-memory" assert msg["key"] == "z" @@ -480,7 +478,7 @@ def test_dumps_task(): @gen_cluster() -def test_ready_remove_worker(s, a, b): +async def test_ready_remove_worker(s, a, b): s.update_graph( tasks={"x-%d" % i: dumps_task((inc, i)) for i in range(20)}, keys=["x-%d" % i for i in range(20)], @@ -497,11 +495,11 @@ def test_ready_remove_worker(s, a, b): @gen_cluster(client=True, Worker=Nanny) -def test_restart(c, s, a, b): +async def test_restart(c, s, a, b): futures = c.map(inc, range(20)) - yield wait(futures) + await wait(futures) - yield s.restart() + await s.restart() assert len(s.workers) == 2 @@ -514,56 +512,56 @@ def test_restart(c, s, a, b): @gen_cluster() -def test_broadcast(s, a, b): - result = yield s.broadcast(msg={"op": "ping"}) +async def test_broadcast(s, a, b): + result = await s.broadcast(msg={"op": "ping"}) assert result == {a.address: b"pong", b.address: b"pong"} - result = yield s.broadcast(msg={"op": "ping"}, workers=[a.address]) + result = await s.broadcast(msg={"op": "ping"}, workers=[a.address]) assert result == {a.address: b"pong"} - result = yield s.broadcast(msg={"op": "ping"}, hosts=[a.ip]) + result = await s.broadcast(msg={"op": "ping"}, hosts=[a.ip]) assert result == {a.address: b"pong", b.address: b"pong"} @gen_cluster(Worker=Nanny) -def test_broadcast_nanny(s, a, b): - result1 = yield s.broadcast(msg={"op": "identity"}, nanny=True) +async def test_broadcast_nanny(s, a, b): + result1 = await s.broadcast(msg={"op": "identity"}, nanny=True) assert all(d["type"] == "Nanny" for d in result1.values()) - result2 = yield s.broadcast( + result2 = await s.broadcast( msg={"op": "identity"}, workers=[a.worker_address], nanny=True ) assert len(result2) == 1 assert first(result2.values())["id"] == a.id - result3 = yield s.broadcast(msg={"op": "identity"}, hosts=[a.ip], nanny=True) + result3 = await s.broadcast(msg={"op": "identity"}, hosts=[a.ip], nanny=True) assert result1 == result3 @gen_test() -def test_worker_name(): - s = yield Scheduler(validate=True, port=0) - w = yield Worker(s.address, name="alice") +async def test_worker_name(): + s = await Scheduler(validate=True, port=0) + w = await Worker(s.address, name="alice") assert s.workers[w.address].name == "alice" assert s.aliases["alice"] == w.address with pytest.raises(ValueError): - w2 = yield Worker(s.address, name="alice") - yield w2.close() + w2 = await Worker(s.address, 
name="alice") + await w2.close() - yield w.close() - yield s.close() + await w.close() + await s.close() @gen_test() -def test_coerce_address(): +async def test_coerce_address(): with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - s = yield Scheduler(validate=True, port=0) + s = await Scheduler(validate=True, port=0) print("scheduler:", s.address, s.listen_address) a = Worker(s.address, name="alice") b = Worker(s.address, name=123) c = Worker("127.0.0.1", s.port, name="charlie") - yield [a, b, c] + await asyncio.gather(a, b, c) assert s.coerce_address("127.0.0.1:8000") == "tcp://127.0.0.1:8000" assert s.coerce_address("[::1]:8000") == "tcp://[::1]:8000" @@ -591,8 +589,8 @@ def test_coerce_address(): assert s.coerce_address("zzzt:8000", resolve=False) == "tcp://zzzt:8000" - yield s.close() - yield [w.close() for w in [a, b, c]] + await s.close() + await asyncio.gather(a.close(), b.close(), c.close()) @pytest.mark.asyncio @@ -612,24 +610,24 @@ async def test_config_stealing(cleanup): sys.platform.startswith("win"), reason="file descriptors not really a thing" ) @gen_cluster(nthreads=[]) -def test_file_descriptors_dont_leak(s): +async def test_file_descriptors_dont_leak(s): psutil = pytest.importorskip("psutil") proc = psutil.Process() before = proc.num_fds() - w = yield Worker(s.address) - yield w.close() + w = await Worker(s.address) + await w.close() during = proc.num_fds() start = time() while proc.num_fds() > before: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @gen_cluster() -def test_update_graph_culls(s, a, b): +async def test_update_graph_culls(s, a, b): s.update_graph( tasks={ "x": dumps_task((inc, 1)), @@ -650,11 +648,11 @@ def test_io_loop(loop): @gen_cluster(client=True) -def test_story(c, s, a, b): +async def test_story(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(x) f = c.persist(y) - yield wait([f]) + await wait([f]) assert s.transition_log @@ -667,38 +665,38 @@ def test_story(c, s, a, b): @gen_cluster(nthreads=[], client=True) -def test_scatter_no_workers(c, s): +async def test_scatter_no_workers(c, s): with pytest.raises(TimeoutError): - yield s.scatter(data={"x": 1}, client="alice", timeout=0.1) + await s.scatter(data={"x": 1}, client="alice", timeout=0.1) start = time() with pytest.raises(TimeoutError): - yield c.scatter(123, timeout=0.1) + await c.scatter(123, timeout=0.1) assert time() < start + 1.5 w = Worker(s.address, nthreads=3) - yield [c.scatter(data={"y": 2}, timeout=5), w] + await asyncio.gather(c.scatter(data={"y": 2}, timeout=5), w) assert w.data["y"] == 2 - yield w.close() + await w.close() @gen_cluster(nthreads=[]) -def test_scheduler_sees_memory_limits(s): - w = yield Worker(s.address, nthreads=3, memory_limit=12345) +async def test_scheduler_sees_memory_limits(s): + w = await Worker(s.address, nthreads=3, memory_limit=12345) assert s.workers[w.address].memory_limit == 12345 - yield w.close() + await w.close() @gen_cluster(client=True, timeout=1000) -def test_retire_workers(c, s, a, b): - [x] = yield c.scatter([1], workers=a.address) - [y] = yield c.scatter([list(range(1000))], workers=b.address) +async def test_retire_workers(c, s, a, b): + [x] = await c.scatter([1], workers=a.address) + [y] = await c.scatter([list(range(1000))], workers=b.address) assert s.workers_to_close() == [a.address] - workers = yield s.retire_workers() + workers = await s.retire_workers() assert list(workers) == [a.address] assert workers[a.address]["nthreads"] == a.nthreads assert list(s.nthreads) == [b.address] @@ -707,26 
+705,26 @@ def test_retire_workers(c, s, a, b): assert s.workers[b.address].has_what == {s.tasks[x.key], s.tasks[y.key]} - workers = yield s.retire_workers() + workers = await s.retire_workers() assert not workers @gen_cluster(client=True) -def test_retire_workers_n(c, s, a, b): - yield s.retire_workers(n=1, close_workers=True) +async def test_retire_workers_n(c, s, a, b): + await s.retire_workers(n=1, close_workers=True) assert len(s.workers) == 1 - yield s.retire_workers(n=0, close_workers=True) + await s.retire_workers(n=0, close_workers=True) assert len(s.workers) == 1 - yield s.retire_workers(n=1, close_workers=True) + await s.retire_workers(n=1, close_workers=True) assert len(s.workers) == 0 - yield s.retire_workers(n=0, close_workers=True) + await s.retire_workers(n=0, close_workers=True) assert len(s.workers) == 0 while not (a.status.startswith("clos") and b.status.startswith("clos")): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) @@ -736,7 +734,7 @@ async def test_workers_to_close(cl, s, *workers): ): futures = cl.map(slowinc, [1, 1, 1], key=["a-4", "b-4", "c-1"]) while sum(len(w.processing) for w in s.workers.values()) < 3: - await gen.sleep(0.001) + await asyncio.sleep(0.001) wtc = s.workers_to_close() assert all(not s.workers[w].processing for w in wtc) @@ -744,7 +742,7 @@ async def test_workers_to_close(cl, s, *workers): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) -def test_workers_to_close_grouped(c, s, *workers): +async def test_workers_to_close_grouped(c, s, *workers): groups = { workers[0].address: "a", workers[1].address: "a", @@ -760,30 +758,30 @@ def key(ws): # Assert that job in one worker blocks closure of group future = c.submit(slowinc, 1, delay=0.2, workers=workers[0].address) while len(s.rprocessing) < 1: - yield gen.sleep(0.001) + await asyncio.sleep(0.001) assert set(s.workers_to_close(key=key)) == {workers[2].address, workers[3].address} del future while len(s.rprocessing) > 0: - yield gen.sleep(0.001) + await asyncio.sleep(0.001) # Assert that *total* byte count in group determines group priority - av = yield c.scatter("a" * 100, workers=workers[0].address) - bv = yield c.scatter("b" * 75, workers=workers[2].address) - bv2 = yield c.scatter("b" * 75, workers=workers[3].address) + av = await c.scatter("a" * 100, workers=workers[0].address) + bv = await c.scatter("b" * 75, workers=workers[2].address) + bv2 = await c.scatter("b" * 75, workers=workers[3].address) assert set(s.workers_to_close(key=key)) == {workers[0].address, workers[1].address} @gen_cluster(client=True) -def test_retire_workers_no_suspicious_tasks(c, s, a, b): +async def test_retire_workers_no_suspicious_tasks(c, s, a, b): future = c.submit( slowinc, 100, delay=0.5, workers=a.address, allow_other_workers=True ) - yield gen.sleep(0.2) - yield s.retire_workers(workers=[a.address]) + await asyncio.sleep(0.2) + await s.retire_workers(workers=[a.address]) assert all(ts.suspicious == 0 for ts in s.tasks.values()) assert all(tp.suspicious == 0 for tp in s.task_prefixes.values()) @@ -793,48 +791,47 @@ def test_retire_workers_no_suspicious_tasks(c, s, a, b): @pytest.mark.skipif( sys.platform.startswith("win"), reason="file descriptors not really a thing" ) -@pytest.mark.skipif(sys.version_info < (3, 6), reason="intermittent failure") @gen_cluster(client=True, nthreads=[], timeout=240) -def test_file_descriptors(c, s): - yield gen.sleep(0.1) +async def test_file_descriptors(c, s): + await asyncio.sleep(0.1) psutil = 
pytest.importorskip("psutil") da = pytest.importorskip("dask.array") proc = psutil.Process() num_fds_1 = proc.num_fds() N = 20 - nannies = yield [Nanny(s.address, loop=s.loop) for i in range(N)] + nannies = await asyncio.gather(*[Nanny(s.address, loop=s.loop) for _ in range(N)]) while len(s.nthreads) < N: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) num_fds_2 = proc.num_fds() - yield gen.sleep(0.2) + await asyncio.sleep(0.2) num_fds_3 = proc.num_fds() assert num_fds_3 <= num_fds_2 + N # add some heartbeats x = da.random.random(size=(1000, 1000), chunks=(25, 25)) x = c.persist(x) - yield wait(x) + await wait(x) num_fds_4 = proc.num_fds() assert num_fds_4 <= num_fds_2 + 2 * N y = c.persist(x + x.T) - yield wait(y) + await wait(y) num_fds_5 = proc.num_fds() assert num_fds_5 < num_fds_4 + N - yield gen.sleep(1) + await asyncio.sleep(1) num_fds_6 = proc.num_fds() assert num_fds_6 < num_fds_5 + N - yield [n.close() for n in nannies] - yield c.close() + await asyncio.gather(*[n.close() for n in nannies]) + await c.close() assert not s.rpc.open for addr, occ in c.rpc.occupied.items(): @@ -844,17 +841,17 @@ def test_file_descriptors(c, s): start = time() while proc.num_fds() > num_fds_1 + N: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 @pytest.mark.slow @nodebug @gen_cluster(client=True) -def test_learn_occupancy(c, s, a, b): +async def test_learn_occupancy(c, s, a, b): futures = c.map(slowinc, range(1000), delay=0.2) while sum(len(ts.who_has) for ts in s.tasks.values()) < 10: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert 100 < s.total_occupancy < 1000 for w in [a, b]: @@ -864,23 +861,23 @@ def test_learn_occupancy(c, s, a, b): @pytest.mark.slow @nodebug @gen_cluster(client=True) -def test_learn_occupancy_2(c, s, a, b): +async def test_learn_occupancy_2(c, s, a, b): future = c.map(slowinc, range(1000), delay=0.2) while not any(ts.who_has for ts in s.tasks.values()): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert 100 < s.total_occupancy < 1000 @gen_cluster(client=True) -def test_occupancy_cleardown(c, s, a, b): +async def test_occupancy_cleardown(c, s, a, b): s.validate = False # Inject excess values in s.occupancy s.workers[a.address].occupancy = 2 s.total_occupancy += 2 futures = c.map(slowinc, range(100), delay=0.01) - yield wait(futures) + await wait(futures) # Verify that occupancy values have been zeroed out assert abs(s.total_occupancy) < 0.01 @@ -889,28 +886,28 @@ def test_occupancy_cleardown(c, s, a, b): @nodebug @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30) -def test_balance_many_workers(c, s, *workers): +async def test_balance_many_workers(c, s, *workers): futures = c.map(slowinc, range(20), delay=0.2) - yield wait(futures) + await wait(futures) assert {len(w.has_what) for w in s.workers.values()} == {0, 1} @nodebug @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30) -def test_balance_many_workers_2(c, s, *workers): +async def test_balance_many_workers_2(c, s, *workers): s.extensions["stealing"]._pc.callback_time = 100000000 futures = c.map(slowinc, range(90), delay=0.2) - yield wait(futures) + await wait(futures) assert {len(w.has_what) for w in s.workers.values()} == {3} @gen_cluster(client=True) -def test_learn_occupancy_multiple_workers(c, s, a, b): +async def test_learn_occupancy_multiple_workers(c, s, a, b): x = c.submit(slowinc, 1, delay=0.2, workers=a.address) - yield gen.sleep(0.05) + await asyncio.sleep(0.05) futures = c.map(slowinc, range(100), delay=0.2) - yield wait(x) + await wait(x) assert 
not any(v == 0.5 for w in s.workers.values() for v in w.processing.values()) s.validate_state() @@ -934,7 +931,7 @@ async def test_include_communication_in_occupancy(c, s, a, b): @gen_cluster(client=True) -def test_worker_arrives_with_processing_data(c, s, a, b): +async def test_worker_arrives_with_processing_data(c, s, a, b): x = delayed(slowinc)(1, delay=0.4) y = delayed(slowinc)(x, delay=0.4) z = delayed(slowinc)(y, delay=0.4) @@ -942,17 +939,17 @@ def test_worker_arrives_with_processing_data(c, s, a, b): yy, zz = c.persist([y, z]) while not any(w.processing for w in s.workers.values()): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) w = Worker(s.address, nthreads=1) w.put_key_in_memory(y.key, 3) - yield w + await w start = time() while len(s.workers) < 3: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.get_task_status(keys={x.key, y.key, z.key}) == { x.key: "released", @@ -960,23 +957,23 @@ def test_worker_arrives_with_processing_data(c, s, a, b): z.key: "processing", } - yield w.close() + await w.close() @pytest.mark.slow @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_worker_breaks_and_returns(c, s, a): +async def test_worker_breaks_and_returns(c, s, a): future = c.submit(slowinc, 1, delay=0.1) for i in range(20): future = c.submit(slowinc, future, delay=0.1) - yield wait(future) + await wait(future) - yield a.batched_stream.comm.close() + await a.batched_stream.comm.close() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) start = time() - yield wait(future, timeout=10) + await wait(future, timeout=10) end = time() assert end - start < 2 @@ -986,7 +983,7 @@ def test_worker_breaks_and_returns(c, s, a): @gen_cluster(client=True, nthreads=[]) -def test_no_workers_to_memory(c, s): +async def test_no_workers_to_memory(c, s): x = delayed(slowinc)(1, delay=0.4) y = delayed(slowinc)(x, delay=0.4) z = delayed(slowinc)(y, delay=0.4) @@ -994,17 +991,17 @@ def test_no_workers_to_memory(c, s): yy, zz = c.persist([y, z]) while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) w = Worker(s.address, nthreads=1) w.put_key_in_memory(y.key, 3) - yield w + await w start = time() while not s.workers: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.get_task_status(keys={x.key, y.key, z.key}) == { x.key: "released", @@ -1012,11 +1009,11 @@ def test_no_workers_to_memory(c, s): z.key: "processing", } - yield w.close() + await w.close() @gen_cluster(client=True) -def test_no_worker_to_memory_restrictions(c, s, a, b): +async def test_no_worker_to_memory_restrictions(c, s, a, b): x = delayed(slowinc)(1, delay=0.4) y = delayed(slowinc)(x, delay=0.4) z = delayed(slowinc)(y, delay=0.4) @@ -1024,16 +1021,16 @@ def test_no_worker_to_memory_restrictions(c, s, a, b): yy, zz = c.persist([y, z], workers={(x, y, z): "alice"}) while not s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) w = Worker(s.address, nthreads=1, name="alice") w.put_key_in_memory(y.key, 3) - yield w + await w while len(s.workers) < 3: - yield gen.sleep(0.01) - yield gen.sleep(0.3) + await asyncio.sleep(0.01) + await asyncio.sleep(0.3) assert s.get_task_status(keys={x.key, y.key, z.key}) == { x.key: "released", @@ -1041,7 +1038,7 @@ def test_no_worker_to_memory_restrictions(c, s, a, b): z.key: "processing", } - yield w.close() + await w.close() def test_run_on_scheduler_sync(loop): @@ -1058,78 +1055,78 @@ def f(dask_scheduler=None): @gen_cluster(client=True) -def test_run_on_scheduler(c, s, a, b): +async def test_run_on_scheduler(c, s, a, b): def f(dask_scheduler=None): return 
dask_scheduler.address - response = yield c._run_on_scheduler(f) + response = await c._run_on_scheduler(f) assert response == s.address @gen_cluster(client=True) -def test_close_worker(c, s, a, b): +async def test_close_worker(c, s, a, b): assert len(s.workers) == 2 - yield s.close_worker(worker=a.address) + await s.close_worker(worker=a.address) assert len(s.workers) == 1 assert a.address not in s.workers - yield gen.sleep(0.5) + await asyncio.sleep(0.5) assert len(s.workers) == 1 @pytest.mark.slow @gen_cluster(client=True, Worker=Nanny, timeout=20) -def test_close_nanny(c, s, a, b): +async def test_close_nanny(c, s, a, b): assert len(s.workers) == 2 assert a.process.is_alive() a_worker_address = a.worker_address start = time() - yield s.close_worker(worker=a_worker_address) + await s.close_worker(worker=a_worker_address) assert len(s.workers) == 1 assert a_worker_address not in s.workers start = time() while a.is_alive(): - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 5 assert not a.is_alive() assert a.pid is None for i in range(10): - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert len(s.workers) == 1 assert not a.is_alive() assert a.pid is None while a.status != "closed": - yield gen.sleep(0.05) + await asyncio.sleep(0.05) assert time() < start + 10 @gen_cluster(client=True, timeout=20) -def test_retire_workers_close(c, s, a, b): - yield s.retire_workers(close_workers=True) +async def test_retire_workers_close(c, s, a, b): + await s.retire_workers(close_workers=True) assert not s.workers while a.status != "closed" and b.status != "closed": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) @gen_cluster(client=True, timeout=20, Worker=Nanny) -def test_retire_nannies_close(c, s, a, b): +async def test_retire_nannies_close(c, s, a, b): nannies = [a, b] - yield s.retire_workers(close_workers=True, remove=True) + await s.retire_workers(close_workers=True, remove=True) assert not s.workers start = time() while any(n.status != "closed" for n in nannies): - yield gen.sleep(0.05) + await asyncio.sleep(0.05) assert time() < start + 10 assert not any(n.is_alive() for n in nannies) @@ -1137,27 +1134,27 @@ def test_retire_nannies_close(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)]) -def test_fifo_submission(c, s, w): +async def test_fifo_submission(c, s, w): futures = [] for i in range(20): future = c.submit(slowinc, i, delay=0.1, key="inc-%02d" % i, fifo_timeout=0.01) futures.append(future) - yield gen.sleep(0.02) - yield wait(futures[-1]) + await asyncio.sleep(0.02) + await wait(futures[-1]) assert futures[10].status == "finished" @gen_test() -def test_scheduler_file(): +async def test_scheduler_file(): with tmpfile() as fn: - s = yield Scheduler(scheduler_file=fn, port=0) + s = await Scheduler(scheduler_file=fn, port=0) with open(fn) as f: data = json.load(f) assert data["address"] == s.address - c = yield Client(scheduler_file=fn, loop=s.loop, asynchronous=True) - yield c.close() - yield s.close() + c = await Client(scheduler_file=fn, loop=s.loop, asynchronous=True) + await c.close() + await s.close() @pytest.mark.xfail(reason="") @@ -1168,21 +1165,21 @@ async def test_non_existent_worker(c, s): address="127.0.0.1:5738", nthreads=2, nbytes={}, host_info={} ) futures = c.map(inc, range(10)) - await gen.sleep(0.300) + await asyncio.sleep(0.300) assert not s.workers assert all(ts.state == "no-worker" for ts in s.tasks.values()) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_correct_bad_time_estimate(c, s, *workers): 
+async def test_correct_bad_time_estimate(c, s, *workers): future = c.submit(slowinc, 1, delay=0) - yield wait(future) + await wait(future) futures = [c.submit(slowinc, future, delay=0.1, pure=False) for i in range(20)] - yield gen.sleep(0.5) + await asyncio.sleep(0.5) - yield wait(futures) + await wait(futures) assert all(w.data for w in workers), [sorted(w.data) for w in workers] @@ -1210,13 +1207,13 @@ async def test_service_hosts(): @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) -def test_profile_metadata(c, s, a, b): +async def test_profile_metadata(c, s, a, b): start = time() - 1 futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) - yield wait(futures) - yield gen.sleep(0.200) + await wait(futures) + await asyncio.sleep(0.200) - meta = yield s.get_profile_metadata(profile_cycle_interval=0.100) + meta = await s.get_profile_metadata(profile_cycle_interval=0.100) now = time() + 1 assert meta assert all(start < t < now for t, count in meta["counts"]) @@ -1225,12 +1222,12 @@ def test_profile_metadata(c, s, a, b): @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": 100}) -def test_profile_metadata_keys(c, s, a, b): +async def test_profile_metadata_keys(c, s, a, b): x = c.map(slowinc, range(10), delay=0.05) y = c.map(slowdec, range(10), delay=0.05) - yield wait(x + y) + await wait(x + y) - meta = yield s.get_profile_metadata(profile_cycle_interval=0.100) + meta = await s.get_profile_metadata(profile_cycle_interval=0.100) assert set(meta["keys"]) == {"slowinc", "slowdec"} assert ( len(meta["counts"]) - 3 <= len(meta["keys"]["slowinc"]) <= len(meta["counts"]) @@ -1238,7 +1235,7 @@ def test_profile_metadata_keys(c, s, a, b): @gen_cluster(client=True) -def test_cancel_fire_and_forget(c, s, a, b): +async def test_cancel_fire_and_forget(c, s, a, b): x = delayed(slowinc)(1, delay=0.05) y = delayed(slowinc)(x, delay=0.05) z = delayed(slowinc)(y, delay=0.05) @@ -1246,8 +1243,8 @@ def test_cancel_fire_and_forget(c, s, a, b): future = c.compute(w) fire_and_forget(future) - yield gen.sleep(0.05) - yield future.cancel(force=True) + await asyncio.sleep(0.05) + await future.cancel(force=True) assert future.status == "cancelled" assert not s.tasks @@ -1255,34 +1252,34 @@ def test_cancel_fire_and_forget(c, s, a, b): @gen_cluster( client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False} ) -def test_log_tasks_during_restart(c, s, a, b): +async def test_log_tasks_during_restart(c, s, a, b): future = c.submit(sys.exit, 0) - yield wait(future) + await wait(future) assert "exit" in str(s.events) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_reschedule(c, s, a, b): - yield c.submit(slowinc, -1, delay=0.1) # learn cost +async def test_reschedule(c, s, a, b): + await c.submit(slowinc, -1, delay=0.1) # learn cost x = c.map(slowinc, range(4), delay=0.1) # add much more work onto worker a futures = c.map(slowinc, range(10, 20), delay=0.1, workers=a.address) while len(s.tasks) < len(x) + len(futures): - yield gen.sleep(0.001) + await asyncio.sleep(0.001) for future in x: s.reschedule(key=future.key) # Worker b gets more of the original tasks - yield wait(x) + await wait(x) assert sum(future.key in b.data for future in x) >= 3 assert sum(future.key in a.data for future in x) <= 1 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_reschedule_warns(c, s, a, b): +async def test_reschedule_warns(c, s, a, b): with captured_logger(logging.getLogger("distributed.scheduler")) as sched: 
s.reschedule(key="__this-key-does-not-exist__") @@ -1291,11 +1288,11 @@ def test_reschedule_warns(c, s, a, b): @gen_cluster(client=True) -def test_get_task_status(c, s, a, b): +async def test_get_task_status(c, s, a, b): future = c.submit(inc, 1) - yield wait(future) + await wait(future) - result = yield a.scheduler.get_task_status(keys=[future.key]) + result = await a.scheduler.get_task_status(keys=[future.key]) assert result == {future.key: "memory"} @@ -1312,29 +1309,29 @@ def test_deque_handler(): @gen_cluster(client=True) -def test_retries(c, s, a, b): +async def test_retries(c, s, a, b): args = [ZeroDivisionError("one"), ZeroDivisionError("two"), 42] future = c.submit(varying(args), retries=3) - result = yield future + result = await future assert result == 42 assert s.tasks[future.key].retries == 1 assert future.key not in s.exceptions future = c.submit(varying(args), retries=2, pure=False) - result = yield future + result = await future assert result == 42 assert s.tasks[future.key].retries == 0 assert future.key not in s.exceptions future = c.submit(varying(args), retries=1, pure=False) with pytest.raises(ZeroDivisionError) as exc_info: - res = yield future + await future exc_info.match("two") future = c.submit(varying(args), retries=0, pure=False) with pytest.raises(ZeroDivisionError) as exc_info: - res = yield future + await future exc_info.match("one") @@ -1350,149 +1347,149 @@ async def test_mising_data_errant_worker(c, s, w1, w2, w3): y = c.submit(len, x, workers=w3.address) while not w3.tasks: - await gen.sleep(0.001) + await asyncio.sleep(0.001) await w1.close() await wait(y) @gen_cluster(client=True) -def test_dont_recompute_if_persisted(c, s, a, b): +async def test_dont_recompute_if_persisted(c, s, a, b): x = delayed(inc)(1, dask_key_name="x") y = delayed(inc)(x, dask_key_name="y") yy = y.persist() - yield wait(yy) + await wait(yy) old = list(s.transition_log) yyy = y.persist() - yield wait(yyy) + await wait(yyy) - yield gen.sleep(0.100) + await asyncio.sleep(0.100) assert list(s.transition_log) == old @gen_cluster(client=True) -def test_dont_recompute_if_persisted_2(c, s, a, b): +async def test_dont_recompute_if_persisted_2(c, s, a, b): x = delayed(inc)(1, dask_key_name="x") y = delayed(inc)(x, dask_key_name="y") z = delayed(inc)(y, dask_key_name="z") yy = y.persist() - yield wait(yy) + await wait(yy) old = s.story("x", "y") zz = z.persist() - yield wait(zz) + await wait(zz) - yield gen.sleep(0.100) + await asyncio.sleep(0.100) assert s.story("x", "y") == old @gen_cluster(client=True) -def test_dont_recompute_if_persisted_3(c, s, a, b): +async def test_dont_recompute_if_persisted_3(c, s, a, b): x = delayed(inc)(1, dask_key_name="x") y = delayed(inc)(2, dask_key_name="y") z = delayed(inc)(y, dask_key_name="z") w = delayed(operator.add)(x, z, dask_key_name="w") ww = w.persist() - yield wait(ww) + await wait(ww) old = list(s.transition_log) www = w.persist() - yield wait(www) - yield gen.sleep(0.100) + await wait(www) + await asyncio.sleep(0.100) assert list(s.transition_log) == old @gen_cluster(client=True) -def test_dont_recompute_if_persisted_4(c, s, a, b): +async def test_dont_recompute_if_persisted_4(c, s, a, b): x = delayed(inc)(1, dask_key_name="x") y = delayed(inc)(x, dask_key_name="y") z = delayed(inc)(x, dask_key_name="z") yy = y.persist() - yield wait(yy) + await wait(yy) old = s.story("x") while s.tasks["x"].state == "memory": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) yyy, zzz = dask.persist(y, z) - yield wait([yyy, zzz]) + await wait([yyy, zzz]) new = 
s.story("x") assert len(new) > len(old) @gen_cluster(client=True) -def test_dont_forget_released_keys(c, s, a, b): +async def test_dont_forget_released_keys(c, s, a, b): x = c.submit(inc, 1, key="x") y = c.submit(inc, x, key="y") z = c.submit(dec, x, key="z") del x - yield wait([y, z]) + await wait([y, z]) del z while "z" in s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert "x" in s.tasks @gen_cluster(client=True) -def test_dont_recompute_if_erred(c, s, a, b): +async def test_dont_recompute_if_erred(c, s, a, b): x = delayed(inc)(1, dask_key_name="x") y = delayed(div)(x, 0, dask_key_name="y") yy = y.persist() - yield wait(yy) + await wait(yy) old = list(s.transition_log) yyy = y.persist() - yield wait(yyy) + await wait(yyy) - yield gen.sleep(0.100) + await asyncio.sleep(0.100) assert list(s.transition_log) == old @gen_cluster() -def test_closing_scheduler_closes_workers(s, a, b): - yield s.close() +async def test_closing_scheduler_closes_workers(s, a, b): + await s.close() start = time() while a.status != "closed" or b.status != "closed": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 @gen_cluster( client=True, nthreads=[("127.0.0.1", 1)], worker_kwargs={"resources": {"A": 1}} ) -def test_resources_reset_after_cancelled_task(c, s, w): +async def test_resources_reset_after_cancelled_task(c, s, w): future = c.submit(sleep, 0.2, resources={"A": 1}) while not w.executing: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - yield future.cancel() + await future.cancel() while w.executing: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert not s.workers[w.address].used_resources["A"] assert w.available_resources == {"A": 1} - yield c.submit(inc, 1, resources={"A": 1}) + await c.submit(inc, 1, resources={"A": 1}) @gen_cluster(client=True) -def test_gh2187(c, s, a, b): +async def test_gh2187(c, s, a, b): def foo(): return "foo" @@ -1509,16 +1506,16 @@ def qux(x): w = c.submit(foo, key="w") x = c.submit(bar, w, key="x") y = c.submit(baz, x, key="y") - yield y + await y z = c.submit(qux, y, key="z") del y - yield gen.sleep(0.1) + await asyncio.sleep(0.1) f = c.submit(bar, x, key="y") - yield f + await f @gen_cluster(client=True) -def test_collect_versions(c, s, a, b): +async def test_collect_versions(c, s, a, b): cs = s.clients[c.id] (w1, w2) = s.workers.values() assert cs.versions @@ -1538,12 +1535,12 @@ async def test_idle_timeout(c, s, a, b): with captured_logger("distributed.scheduler") as logs: start = time() while s.status != "closed": - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 start = time() while not (a.status == "closed" and b.status == "closed"): - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 assert "idle" in logs.getvalue() @@ -1587,7 +1584,7 @@ async def f(dask_worker): @gen_cluster() -def test_workerstate_clean(s, a, b): +async def test_workerstate_clean(s, a, b): ws = s.workers[a.address].clean() assert ws.address == a.address b = pickle.dumps(ws) @@ -1595,16 +1592,16 @@ def test_workerstate_clean(s, a, b): @gen_cluster(client=True) -def test_result_type(c, s, a, b): +async def test_result_type(c, s, a, b): x = c.submit(lambda: 1) - yield x + await x assert "int" in s.tasks[x.key].type @gen_cluster() -def test_close_workers(s, a, b): - yield s.close(close_workers=True) +async def test_close_workers(s, a, b): + await s.close(close_workers=True) assert a.status == "closed" assert b.status == "closed" @@ -1613,22 +1610,22 @@ def test_close_workers(s, a, b): not 
sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_test() -def test_host_address(): - s = yield Scheduler(host="127.0.0.2", port=0) +async def test_host_address(): + s = await Scheduler(host="127.0.0.2", port=0) assert "127.0.0.2" in s.address - yield s.close() + await s.close() @gen_test() -def test_dashboard_address(): +async def test_dashboard_address(): pytest.importorskip("bokeh") - s = yield Scheduler(dashboard_address="127.0.0.1:8901", port=0) + s = await Scheduler(dashboard_address="127.0.0.1:8901", port=0) assert s.services["dashboard"].port == 8901 - yield s.close() + await s.close() - s = yield Scheduler(dashboard_address="127.0.0.1", port=0) + s = await Scheduler(dashboard_address="127.0.0.1", port=0) assert s.services["dashboard"].port - yield s.close() + await s.close() @gen_cluster(client=True) @@ -1644,16 +1641,16 @@ async def test_adaptive_target(c, s, a, b): # Long task x = c.submit(slowinc, 1, delay=0.5) while x.key not in s.tasks: - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.adaptive_target(target_duration=".1s") == 1 # still one L = c.map(slowinc, range(100), delay=0.5) while len(s.tasks) < 100: - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert 10 < s.adaptive_target(target_duration=".1s") <= 100 del x, L while s.tasks: - await gen.sleep(0.01) + await asyncio.sleep(0.01) assert s.adaptive_target(target_duration=".1s") == 0 diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 8665ebead33..7bb2fd753c0 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -1,5 +1,4 @@ from contextlib import contextmanager -import sys try: import ssl @@ -149,15 +148,15 @@ def test_tls_config_for_role(): sec.get_tls_config_for_role("supervisor") +def assert_many_ciphers(ctx): + assert len(ctx.get_ciphers()) > 2 # Most likely + + def test_connection_args(): def basic_checks(ctx): assert ctx.verify_mode == ssl.CERT_REQUIRED assert ctx.check_hostname is False - def many_ciphers(ctx): - if sys.version_info >= (3, 6): - assert len(ctx.get_ciphers()) > 2 # Most likely - c = { "distributed.comm.tls.ca-file": ca_file, "distributed.comm.tls.scheduler.key": key1, @@ -171,12 +170,12 @@ def many_ciphers(ctx): assert not d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) d = sec.get_connection_args("worker") ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) # No cert defined => no TLS d = sec.get_connection_args("client") @@ -193,13 +192,12 @@ def many_ciphers(ctx): assert d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - if sys.version_info >= (3, 6): - supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] - assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] - if len(tls_13_ciphers): - assert len(tls_13_ciphers) == 3 + + supported_ciphers = ctx.get_ciphers() + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] + assert len(tls_12_ciphers) == 1 + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] + assert len(tls_13_ciphers) in (0, 3) def test_listen_args(): @@ -207,10 +205,6 @@ def basic_checks(ctx): assert ctx.verify_mode == ssl.CERT_REQUIRED assert ctx.check_hostname is False - def many_ciphers(ctx): - if sys.version_info >= (3, 6): - assert len(ctx.get_ciphers()) > 2 
# Most likely - c = { "distributed.comm.tls.ca-file": ca_file, "distributed.comm.tls.scheduler.key": key1, @@ -224,12 +218,12 @@ def many_ciphers(ctx): assert not d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) d = sec.get_listen_args("worker") ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) # No cert defined => no TLS d = sec.get_listen_args("client") @@ -246,13 +240,12 @@ def many_ciphers(ctx): assert d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - if sys.version_info >= (3, 6): - supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] - assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] - if len(tls_13_ciphers): - assert len(tls_13_ciphers) == 3 + + supported_ciphers = ctx.get_ciphers() + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] + assert len(tls_12_ciphers) == 1 + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] + assert len(tls_13_ciphers) in (0, 3) @pytest.mark.asyncio diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 7a36431042f..fc0a6172a85 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -327,17 +327,15 @@ async def test_oversubscribing_leases(c, s, a, b): accept new leases as long as the semaphore is oversubscribed. Oversubscription may occur if tasks hold the GIL for a longer time than the - lease-timeout is configured causing the lease refreshs to go stale and - timeout. + lease-timeout is configured causing the lease refresh to go stale and timeout. We cannot protect ourselves entirely from this but we can ensure that while a task with a timed out lease is still running, we block further acquisitions until we return to normal. An example would be a task which continuously locks the GIL for a longer - time than the lease timeout but this continous lock only makes up a + time than the lease timeout but this continuous lock only makes up a fraction of the tasks runtime. 
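The oversubscription scenario this docstring describes can be reproduced in miniature along the following lines. This is an illustrative sketch only, not part of the patch; the `distributed.scheduler.locks.lease-timeout` configuration key, the cluster parameters, and the busy loop standing in for a GIL-holding task are assumptions made for the example.

import time

import dask
from distributed import Client, Semaphore


def hold_lease_too_long(i, sem):
    # Busy-looping while holding the lease approximates a task that keeps the
    # GIL occupied, so the worker cannot refresh the lease before the short
    # timeout configured below expires.
    with sem:
        deadline = time.time() + 2.0
        while time.time() < deadline:
            pass
    return i


if __name__ == "__main__":
    with dask.config.set({"distributed.scheduler.locks.lease-timeout": "500ms"}):
        client = Client(processes=False, n_workers=1, threads_per_worker=2)
        sem = Semaphore(max_leases=1, name="resource")
        # Each task outlives its lease; the behaviour under test is that no
        # additional lease is handed out while the first task is still running.
        futures = client.map(hold_lease_too_long, range(4), sem=sem)
        print(client.gather(futures))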
- """ # GH3705 diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 1c9fe22e2e8..fb5c96e14e6 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1,3 +1,4 @@ +import asyncio import itertools import random import sys @@ -23,7 +24,6 @@ slowinc, ) from tlz import concat, sliding_window -from tornado import gen # Most tests here are timing-dependent setup_module = nodebug_setup_module @@ -34,70 +34,70 @@ not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster(client=True, nthreads=[("127.0.0.1", 2), ("127.0.0.2", 2)], timeout=20) -def test_work_stealing(c, s, a, b): - [x] = yield c._scatter([1], workers=a.address) +async def test_work_stealing(c, s, a, b): + [x] = await c._scatter([1], workers=a.address) futures = c.map(slowadd, range(50), [x] * 50) - yield wait(futures) + await wait(futures) assert len(a.data) > 10 assert len(b.data) > 10 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_dont_steal_expensive_data_fast_computation(c, s, a, b): +async def test_dont_steal_expensive_data_fast_computation(c, s, a, b): np = pytest.importorskip("numpy") x = c.submit(np.arange, 1000000, workers=a.address) - yield wait([x]) + await wait([x]) future = c.submit(np.sum, [1], workers=a.address) # learn that sum is fast - yield wait([future]) + await wait([future]) cheap = [ c.submit(np.sum, x, pure=False, workers=a.address, allow_other_workers=True) for i in range(10) ] - yield wait(cheap) + await wait(cheap) assert len(s.who_has[x.key]) == 1 assert len(b.data) == 0 assert len(a.data) == 12 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_steal_cheap_data_slow_computation(c, s, a, b): +async def test_steal_cheap_data_slow_computation(c, s, a, b): x = c.submit(slowinc, 100, delay=0.1) # learn that slowinc is slow - yield wait(x) + await wait(x) futures = c.map( slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True ) - yield wait(futures) + await wait(futures) assert abs(len(a.data) - len(b.data)) <= 5 @pytest.mark.avoid_travis @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_steal_expensive_data_slow_computation(c, s, a, b): +async def test_steal_expensive_data_slow_computation(c, s, a, b): np = pytest.importorskip("numpy") x = c.submit(slowinc, 100, delay=0.2, workers=a.address) - yield wait(x) # learn that slowinc is slow + await wait(x) # learn that slowinc is slow x = c.submit(np.arange, 1000000, workers=a.address) # put expensive data - yield wait(x) + await wait(x) slow = [c.submit(slowinc, x, delay=0.1, pure=False) for i in range(20)] - yield wait(slow) + await wait(slow) assert len(s.who_has[x.key]) > 1 assert b.data # not empty @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) -def test_worksteal_many_thieves(c, s, *workers): +async def test_worksteal_many_thieves(c, s, *workers): x = c.submit(slowinc, -1, delay=0.1) - yield x + await x xs = c.map(slowinc, [x] * 100, pure=False, delay=0.1) - yield wait(xs) + await wait(xs) for w, keys in s.has_what.items(): assert 2 < len(keys) < 30 @@ -107,31 +107,30 @@ def test_worksteal_many_thieves(c, s, *workers): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_dont_steal_unknown_functions(c, s, a, b): - futures = c.map(inc, [1, 2], workers=a.address, allow_other_workers=True) - yield wait(futures) - assert len(a.data) == 2, [len(a.data), len(b.data)] - assert len(b.data) == 0, [len(a.data), len(b.data)] +async def 
test_dont_steal_unknown_functions(c, s, a, b): + futures = c.map(inc, range(100), workers=a.address, allow_other_workers=True) + await wait(futures) + assert len(a.data) >= 95, [len(a.data), len(b.data)] @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_eventually_steal_unknown_functions(c, s, a, b): +async def test_eventually_steal_unknown_functions(c, s, a, b): futures = c.map( slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True ) - yield wait(futures) - assert len(a.data) >= 3 - assert len(b.data) >= 3 + await wait(futures) + assert len(a.data) >= 3, [len(a.data), len(b.data)] + assert len(b.data) >= 3, [len(a.data), len(b.data)] @pytest.mark.skip(reason="") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_steal_related_tasks(e, s, a, b, c): +async def test_steal_related_tasks(e, s, a, b, c): futures = e.map( slowinc, range(20), delay=0.05, workers=a.address, allow_other_workers=True ) - yield wait(futures) + await wait(futures) nearby = 0 for f1, f2 in sliding_window(2, futures): @@ -198,17 +197,17 @@ def fast_blacklisted(x, y=None): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)], timeout=20) -def test_new_worker_steals(c, s, a): - yield wait(c.submit(slowinc, 1, delay=0.01)) +async def test_new_worker_steals(c, s, a): + await wait(c.submit(slowinc, 1, delay=0.01)) futures = c.map(slowinc, range(100), delay=0.05) total = c.submit(sum, futures) while len(a.task_state) < 10: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) - b = yield Worker(s.address, loop=s.loop, nthreads=1, memory_limit=MEMORY_LIMIT) + b = await Worker(s.address, loop=s.loop, nthreads=1, memory_limit=MEMORY_LIMIT) - result = yield total + result = await total assert result == sum(map(inc, range(100))) for w in [a, b]: @@ -216,44 +215,44 @@ def test_new_worker_steals(c, s, a): assert b.data - yield b.close() + await b.close() @gen_cluster(client=True, timeout=20) -def test_work_steal_no_kwargs(c, s, a, b): - yield wait(c.submit(slowinc, 1, delay=0.05)) +async def test_work_steal_no_kwargs(c, s, a, b): + await wait(c.submit(slowinc, 1, delay=0.05)) futures = c.map( slowinc, range(100), workers=a.address, allow_other_workers=True, delay=0.05 ) - yield wait(futures) + await wait(futures) assert 20 < len(a.data) < 80 assert 20 < len(b.data) < 80 total = c.submit(sum, futures) - result = yield total + result = await total assert result == sum(map(inc, range(100))) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)]) -def test_dont_steal_worker_restrictions(c, s, a, b): +async def test_dont_steal_worker_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) - yield future + await future futures = c.map(slowinc, range(100), delay=0.1, workers=a.address) while len(a.task_state) + len(b.task_state) < 100: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(a.task_state) == 100 assert len(b.task_state) == 0 result = s.extensions["stealing"].balance() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert len(a.task_state) == 100 assert len(b.task_state) == 0 @@ -262,15 +261,15 @@ def test_dont_steal_worker_restrictions(c, s, a, b): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2), ("127.0.0.1", 2)] ) -def test_steal_worker_restrictions(c, s, wa, wb, wc): +async def test_steal_worker_restrictions(c, s, wa, wb, wc): future = c.submit(slowinc, 1, delay=0.1, workers={wa.address, wb.address}) - yield future + await future ntasks = 100 futures = c.map(slowinc, range(ntasks), 
delay=0.1, workers={wa.address, wb.address}) while sum(len(w.task_state) for w in [wa, wb, wc]) < ntasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert 0 < len(wa.task_state) < ntasks assert 0 < len(wb.task_state) < ntasks @@ -278,7 +277,7 @@ def test_steal_worker_restrictions(c, s, wa, wb, wc): s.extensions["stealing"].balance() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert 0 < len(wa.task_state) < ntasks assert 0 < len(wb.task_state) < ntasks @@ -289,19 +288,19 @@ def test_steal_worker_restrictions(c, s, wa, wb, wc): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 1)]) -def test_dont_steal_host_restrictions(c, s, a, b): +async def test_dont_steal_host_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) - yield future + await future futures = c.map(slowinc, range(100), delay=0.1, workers="127.0.0.1") while len(a.task_state) + len(b.task_state) < 100: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(a.task_state) == 100 assert len(b.task_state) == 0 result = s.extensions["stealing"].balance() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert len(a.task_state) == 100 assert len(b.task_state) == 0 @@ -310,25 +309,25 @@ def test_dont_steal_host_restrictions(c, s, a, b): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 2)]) -def test_steal_host_restrictions(c, s, wa, wb): +async def test_steal_host_restrictions(c, s, wa, wb): future = c.submit(slowinc, 1, delay=0.10, workers=wa.address) - yield future + await future ntasks = 100 futures = c.map(slowinc, range(ntasks), delay=0.1, workers="127.0.0.1") while len(wa.task_state) < ntasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(wa.task_state) == ntasks assert len(wb.task_state) == 0 - wc = yield Worker(s.address, nthreads=1) + wc = await Worker(s.address, nthreads=1) start = time() while not wc.task_state or len(wa.task_state) == ntasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert 0 < len(wa.task_state) < ntasks assert len(wb.task_state) == 0 assert 0 < len(wc.task_state) < ntasks @@ -337,19 +336,19 @@ def test_steal_host_restrictions(c, s, wa, wb): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)] ) -def test_dont_steal_resource_restrictions(c, s, a, b): +async def test_dont_steal_resource_restrictions(c, s, a, b): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) - yield future + await future futures = c.map(slowinc, range(100), delay=0.1, resources={"A": 1}) while len(a.task_state) + len(b.task_state) < 100: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(a.task_state) == 100 assert len(b.task_state) == 0 result = s.extensions["stealing"].balance() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert len(a.task_state) == 100 assert len(b.task_state) == 0 @@ -357,30 +356,30 @@ def test_dont_steal_resource_restrictions(c, s, a, b): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}})], timeout=3 ) -def test_steal_resource_restrictions(c, s, a): +async def test_steal_resource_restrictions(c, s, a): future = c.submit(slowinc, 1, delay=0.10, workers=a.address) - yield future + await future futures = c.map(slowinc, range(100), delay=0.2, resources={"A": 1}) 
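The worker, host, and resource restrictions exercised by these stealing tests are the same keywords exposed to user code. A brief usage sketch follows; the scheduler and worker addresses are hypothetical placeholders.

from distributed import Client

client = Client("tcp://127.0.0.1:8786")  # hypothetical scheduler address

# Prefer one worker but allow the scheduler to steal the task elsewhere:
loose = client.submit(sum, [1, 2, 3], workers=["tcp://127.0.0.1:35001"],
                      allow_other_workers=True)

# Require an abstract resource; only workers started with resources={"A": 1}
# (or `dask-worker --resources "A=1"`) may run this task:
strict = client.submit(sum, [4, 5, 6], resources={"A": 1})

print(client.gather([loose, strict]))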
while len(a.task_state) < 101: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert len(a.task_state) == 101 - b = yield Worker(s.address, loop=s.loop, nthreads=1, resources={"A": 4}) + b = await Worker(s.address, loop=s.loop, nthreads=1, resources={"A": 4}) start = time() while not b.task_state or len(a.task_state) == 101: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 3 assert len(b.task_state) > 0 assert len(a.task_state) < 101 - yield b.close() + await b.close() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 5, timeout=20) -def test_balance_without_dependencies(c, s, *workers): +async def test_balance_without_dependencies(c, s, *workers): s.extensions["stealing"]._pc.callback_time = 20 def slow(x): @@ -389,19 +388,19 @@ def slow(x): return y futures = c.map(slow, range(100)) - yield wait(futures) + await wait(futures) durations = [sum(w.data.values()) for w in workers] assert max(durations) / min(durations) < 3 @gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 2) -def test_dont_steal_executing_tasks(c, s, a, b): +async def test_dont_steal_executing_tasks(c, s, a, b): futures = c.map( slowinc, range(4), delay=0.1, workers=a.address, allow_other_workers=True ) - yield wait(futures) + await wait(futures) assert len(a.data) == 4 assert len(b.data) == 0 @@ -411,14 +410,14 @@ def test_dont_steal_executing_tasks(c, s, a, b): nthreads=[("127.0.0.1", 1)] * 10, config={"distributed.scheduler.default-task-durations": {"slowidentity": 0.2}}, ) -def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): +async def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): s.extensions["stealing"]._pc.callback_time = 20 x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB - yield wait(x) + await wait(x) futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(2)] - yield wait(futures) + await wait(futures) assert len(a.data) == 3 assert not any(w.task_state for w in rest) @@ -430,16 +429,16 @@ def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): worker_kwargs={"memory_limit": MEMORY_LIMIT}, config={"distributed.scheduler.default-task-durations": {"slowidentity": 0.2}}, ) -def test_steal_when_more_tasks(c, s, a, *rest): +async def test_steal_when_more_tasks(c, s, a, *rest): s.extensions["stealing"]._pc.callback_time = 20 x = c.submit(mul, b"0", 50000000, workers=a.address) # 50 MB - yield wait(x) + await wait(x) futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(20)] start = time() while not any(w.task_state for w in rest): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 @@ -453,20 +452,20 @@ def test_steal_when_more_tasks(c, s, a, *rest): } }, ) -def test_steal_more_attractive_tasks(c, s, a, *rest): +async def test_steal_more_attractive_tasks(c, s, a, *rest): def slow2(x): sleep(1) return x s.extensions["stealing"]._pc.callback_time = 20 x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB - yield wait(x) + await wait(x) futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(10)] future = c.submit(slow2, x, priority=-1) while not any(w.task_state for w in rest): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) # good future moves first assert any(future.key in w.task_state for w in rest) @@ -476,7 +475,7 @@ def func(x): sleep(1) -def assert_balanced(inp, expected, c, s, *workers): +async def assert_balanced(inp, expected, c, s, *workers): steal = s.extensions["stealing"] steal._pc.stop() @@ -488,7 
+487,7 @@ def assert_balanced(inp, expected, c, s, *workers): for w, ts in zip(workers, inp): for t in sorted(ts, reverse=True): if t: - [dat] = yield c.scatter([next(data_seq)], workers=w.address) + [dat] = await c.scatter([next(data_seq)], workers=w.address) ts = s.tasks[dat.key] # Ensure scheduler state stays consistent old_nbytes = ts.nbytes @@ -510,13 +509,13 @@ def assert_balanced(inp, expected, c, s, *workers): futures.append(f) while len(s.rprocessing) < len(futures): - yield gen.sleep(0.001) + await asyncio.sleep(0.001) for i in range(10): steal.balance() while steal.in_flight: - yield gen.sleep(0.001) + await asyncio.sleep(0.001) result = [ sorted([int(key_split(k)) for k in s.processing[w.address]], reverse=True) @@ -569,7 +568,9 @@ def assert_balanced(inp, expected, c, s, *workers): ], ) def test_balance(inp, expected): - test = lambda *args, **kwargs: assert_balanced(inp, expected, *args, **kwargs) + async def test(*args, **kwargs): + await assert_balanced(inp, expected, *args, **kwargs) + test = gen_cluster( client=True, nthreads=[("127.0.0.1", 1)] * len(inp), @@ -583,18 +584,18 @@ def test_balance(inp, expected): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=20) -def test_restart(c, s, a, b): +async def test_restart(c, s, a, b): futures = c.map( slowinc, range(100), delay=0.1, workers=a.address, allow_other_workers=True ) while not s.processing[b.worker_address]: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) steal = s.extensions["stealing"] assert any(st for st in steal.stealable_all) assert any(x for L in steal.stealable.values() for x in L) - yield c.restart(timeout=10) + await c.restart(timeout=10) assert not any(x for x in steal.stealable_all) assert not any(x for L in steal.stealable.values() for x in L) @@ -604,7 +605,7 @@ def test_restart(c, s, a, b): client=True, config={"distributed.scheduler.default-task-durations": {"slowadd": 0.001}}, ) -def test_steal_communication_heavy_tasks(c, s, a, b): +async def test_steal_communication_heavy_tasks(c, s, a, b): steal = s.extensions["stealing"] x = c.submit(mul, b"0", int(s.bandwidth), workers=a.address) y = c.submit(mul, b"1", int(s.bandwidth), workers=b.address) @@ -623,29 +624,29 @@ def test_steal_communication_heavy_tasks(c, s, a, b): ] while not any(f.key in s.rprocessing for f in futures): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) steal.balance() while steal.in_flight: - yield gen.sleep(0.001) + await asyncio.sleep(0.001) assert s.processing[b.address] @gen_cluster(client=True) -def test_steal_twice(c, s, a, b): +async def test_steal_twice(c, s, a, b): x = c.submit(inc, 1, workers=a.address) - yield wait(x) + await wait(x) futures = [c.submit(slowadd, x, i, delay=0.2) for i in range(100)] while len(s.tasks) < 100: # tasks are all allocated - yield gen.sleep(0.01) + await asyncio.sleep(0.01) # Army of new workers arrives to help - workers = yield [Worker(s.address, loop=s.loop) for _ in range(20)] + workers = await asyncio.gather(*[Worker(s.address, loop=s.loop) for _ in range(20)]) - yield wait(futures) + await wait(futures) has_what = dict(s.has_what) # take snapshot empty_workers = [w for w, keys in has_what.items() if not len(keys)] @@ -656,42 +657,42 @@ def test_steal_twice(c, s, a, b): ) assert max(map(len, has_what.values())) < 30 - yield c._close() - yield [w.close() for w in workers] + await c._close() + await asyncio.gather(*[w.close() for w in workers]) @gen_cluster(client=True) -def test_dont_steal_executing_tasks(c, s, a, b): +async def 
test_dont_steal_executing_tasks(c, s, a, b): steal = s.extensions["stealing"] future = c.submit(slowinc, 1, delay=0.5, workers=a.address) while not a.executing: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) steal.move_task_request( s.tasks[future.key], s.workers[a.address], s.workers[b.address] ) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert future.key in a.executing assert not b.executing @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_dont_steal_long_running_tasks(c, s, a, b): +async def test_dont_steal_long_running_tasks(c, s, a, b): def long(delay): with worker_client() as c: sleep(delay) - yield c.submit(long, 0.1) # learn duration - yield c.submit(inc, 1) # learn duration + await c.submit(long, 0.1) # learn duration + await c.submit(inc, 1) # learn duration long_tasks = c.map(long, [0.5, 0.6], workers=a.address, allow_other_workers=True) while sum(map(len, s.processing.values())) < 2: # let them start - yield gen.sleep(0.01) + await asyncio.sleep(0.01) start = time() while any(t.key in s.extensions["stealing"].key_stealable for t in long_tasks): - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 na = len(a.executing) @@ -699,9 +700,9 @@ def long(delay): incs = c.map(inc, range(100), workers=a.address, allow_other_workers=True) - yield gen.sleep(0.2) + await asyncio.sleep(0.2) - yield wait(long_tasks) + await wait(long_tasks) for t in long_tasks: assert ( @@ -716,19 +717,19 @@ def long(delay): strict=False, ) @gen_cluster(client=True, nthreads=[("127.0.0.1", 5)] * 2) -def test_cleanup_repeated_tasks(c, s, a, b): +async def test_cleanup_repeated_tasks(c, s, a, b): class Foo: pass s.extensions["stealing"]._pc.callback_time = 20 - yield c.submit(slowidentity, -1, delay=0.1) + await c.submit(slowidentity, -1, delay=0.1) objects = [c.submit(Foo, pure=False, workers=a.address) for _ in range(50)] x = c.map( slowidentity, objects, workers=a.address, allow_other_workers=True, delay=0.05 ) del objects - yield wait(x) + await wait(x) assert a.data and b.data assert len(a.data) + len(b.data) > 10 ws = weakref.WeakSet() @@ -738,7 +739,7 @@ class Foo: start = time() while a.data or b.data: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 assert not s.who_has @@ -748,7 +749,7 @@ class Foo: @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_lose_task(c, s, a, b): +async def test_lose_task(c, s, a, b): with captured_logger("distributed.stealing") as log: s.periodic_callbacks["stealing"].interval = 1 for i in range(100): @@ -760,7 +761,7 @@ def test_lose_task(c, s, a, b): workers=a.address, allow_other_workers=True, ) - yield gen.sleep(0.01) + await asyncio.sleep(0.01) del futures out = log.getvalue() @@ -768,7 +769,7 @@ def test_lose_task(c, s, a, b): @gen_cluster(client=True) -def test_worker_stealing_interval(c, s, a, b): +async def test_worker_stealing_interval(c, s, a, b): from distributed.scheduler import WorkStealing ws = WorkStealing(s) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index d5e1e62c574..707b93c03cf 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -1,8 +1,8 @@ -from operator import add +import asyncio import random import sys +from operator import add from time import sleep -import asyncio from dask import delayed import pytest @@ -27,7 +27,6 @@ nodebug_teardown_module, ) from distributed.client import wait -from tornado import gen # All tests here are slow in some way @@ -36,14 +35,14 @@ 
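The conversion applied throughout this file (and the other test modules in this patch) follows the usual Tornado-to-asyncio migration. A minimal before/after sketch for reference, illustrative only and not part of the diff:

import asyncio

from tornado import gen


# Old style, as removed by this patch:
@gen.coroutine
def old_poll(predicate):
    while not predicate():
        yield gen.sleep(0.01)


# New style, as introduced by this patch:
async def new_poll(predicate):
    while not predicate():
        await asyncio.sleep(0.01)


# Awaiting several coroutines at once: `yield [coro_a, coro_b]` in the old
# Tornado style becomes `await asyncio.gather(coro_a, coro_b)`.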
@gen_cluster(client=True) -def test_stress_1(c, s, a, b): +async def test_stress_1(c, s, a, b): n = 2 ** 6 seq = c.map(inc, range(n)) while len(seq) > 1: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) seq = [c.submit(add, seq[i], seq[i + 1]) for i in range(0, len(seq), 2)] - result = yield seq[0] + result = await seq[0] assert result == sum(map(inc, range(n))) @@ -62,18 +61,18 @@ def test_stress_gc(loop, func, n): sys.platform.startswith("win"), reason="test can leave dangling RPC objects" ) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 8, timeout=None) -def test_cancel_stress(c, s, *workers): +async def test_cancel_stress(c, s, *workers): da = pytest.importorskip("dask.array") x = da.random.random((50, 50), chunks=(2, 2)) x = c.persist(x) - yield wait([x]) + await wait([x]) y = (x.sum(axis=0) + x.sum(axis=1) + 1).std() n_todo = len(y.dask) - len(x.dask) for i in range(5): f = c.compute(y) while len(s.waiting) > (random.random() + 1) * 0.5 * n_todo: - yield gen.sleep(0.01) - yield c._cancel(f) + await asyncio.sleep(0.01) + await c._cancel(f) def test_cancel_stress_sync(loop): @@ -91,7 +90,7 @@ def test_cancel_stress_sync(loop): @gen_cluster(nthreads=[], client=True, timeout=None) -def test_stress_creation_and_deletion(c, s): +async def test_stress_creation_and_deletion(c, s): # Assertions are handled by the validate mechanism in the scheduler s.allowed_failures = 100000 da = pytest.importorskip("dask.array") @@ -101,28 +100,27 @@ def test_stress_creation_and_deletion(c, s): z = c.persist(y) - @gen.coroutine - def create_and_destroy_worker(delay): + async def create_and_destroy_worker(delay): start = time() while time() < start + 5: - n = yield Nanny(s.address, nthreads=2, loop=s.loop) - yield gen.sleep(delay) - yield n.close() + n = await Nanny(s.address, nthreads=2, loop=s.loop) + await asyncio.sleep(delay) + await n.close() print("Killed nanny") - yield asyncio.wait_for( + await asyncio.wait_for( All([create_and_destroy_worker(0.1 * i) for i in range(20)]), 60 ) @gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=60) -def test_stress_scatter_death(c, s, *workers): +async def test_stress_scatter_death(c, s, *workers): import random s.allowed_failures = 1000 np = pytest.importorskip("numpy") - L = yield c.scatter([np.random.random(10000) for i in range(len(workers))]) - yield c.replicate(L, n=2) + L = await c.scatter([np.random.random(10000) for i in range(len(workers))]) + await c.replicate(L, n=2) adds = [ delayed(slowadd, pure=True)( @@ -147,7 +145,7 @@ def test_stress_scatter_death(c, s, *workers): from distributed.scheduler import logger for i in range(7): - yield gen.sleep(0.1) + await asyncio.sleep(0.1) try: s.validate_state() except Exception as c: @@ -159,11 +157,11 @@ def test_stress_scatter_death(c, s, *workers): else: raise w = random.choice(alive) - yield w.close() + await w.close() alive.remove(w) with ignoring(CancelledError): - yield c.gather(futures) + await c.gather(futures) futures = None @@ -175,7 +173,7 @@ def vsum(*args): @pytest.mark.avoid_travis @pytest.mark.slow @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 80, timeout=1000) -def test_stress_communication(c, s, *workers): +async def test_stress_communication(c, s, *workers): s.validate = False # very slow otherwise da = pytest.importorskip("dask.array") # Test consumes many file descriptors and can hang if the limit is too low @@ -189,13 +187,13 @@ def test_stress_communication(c, s, *workers): future = c.compute(z.sum()) - result = yield future + result = await future assert 
isinstance(result, float) @pytest.mark.skip @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, timeout=60) -def test_stress_steal(c, s, *workers): +async def test_stress_steal(c, s, *workers): s.validate = False for w in workers: w.validate = False @@ -209,7 +207,7 @@ def test_stress_steal(c, s, *workers): future = c.compute(total) while future.status != "finished": - yield gen.sleep(0.1) + await asyncio.sleep(0.1) for i in range(3): a = random.choice(workers) b = random.choice(workers) @@ -221,7 +219,7 @@ def test_stress_steal(c, s, *workers): @pytest.mark.slow @gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=120) -def test_close_connections(c, s, *workers): +async def test_close_connections(c, s, *workers): da = pytest.importorskip("dask.array") x = da.random.random(size=(1000, 1000), chunks=(1000, 1)) for i in range(3): @@ -230,7 +228,7 @@ def test_close_connections(c, s, *workers): future = c.compute(x.sum()) while any(s.processing.values()): - yield gen.sleep(0.5) + await asyncio.sleep(0.5) worker = random.choice(list(workers)) for comm in worker._comms: comm.abort() @@ -238,7 +236,7 @@ def test_close_connections(c, s, *workers): # for w in workers: # print(w) - yield wait(future) + await wait(future) @pytest.mark.xfail( @@ -246,7 +244,7 @@ def test_close_connections(c, s, *workers): " https://github.com/tornadoweb/tornado/issues/2110" ) @gen_cluster(client=True, timeout=20, nthreads=[("127.0.0.1", 1)]) -def test_no_delay_during_large_transfer(c, s, w): +async def test_no_delay_during_large_transfer(c, s, w): pytest.importorskip("crick") np = pytest.importorskip("numpy") x = np.random.random(100000000) @@ -263,8 +261,8 @@ def test_no_delay_during_large_transfer(c, s, w): server._last_tick = time() with ResourceProfiler(dt=0.01) as rprof: - future = yield c.scatter(x, direct=True, hash=False) - yield gen.sleep(0.5) + future = await c.scatter(x, direct=True, hash=False) + await asyncio.sleep(0.5) rprof.close() x = None # lose ref diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 3a2bebf790d..67594f42926 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -2,7 +2,7 @@ Various functional tests for TLS networking. Most are taken from other test files and adapted. 
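Outside of the gen_tls_cluster test harness, an equivalent TLS-enabled client connection is configured through distributed.security.Security. A small sketch follows; the certificate file names and the scheduler address are assumptions.

from distributed import Client
from distributed.security import Security

security = Security(
    tls_ca_file="ca.pem",
    tls_client_cert="client-cert.pem",
    tls_client_key="client-key.pem",
    require_encryption=True,
)

# Note the tls:// scheme, matching the addresses asserted in the tests below.
client = Client("tls://127.0.0.1:8786", security=security)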
""" -from tornado import gen +import asyncio from distributed import Nanny, worker_client, Queue from distributed.client import wait @@ -12,72 +12,72 @@ @gen_tls_cluster(client=True) -def test_basic(c, s, a, b): +async def test_basic(c, s, a, b): pass @gen_tls_cluster(client=True) -def test_Queue(c, s, a, b): +async def test_Queue(c, s, a, b): assert s.address.startswith("tls://") - x = yield Queue("x") - y = yield Queue("y") + x = await Queue("x") + y = await Queue("y") - size = yield x.qsize() + size = await x.qsize() assert size == 0 future = c.submit(inc, 1) - yield x.put(future) + await x.put(future) - future2 = yield x.get() + future2 = await x.get() assert future.key == future2.key @gen_tls_cluster(client=True, timeout=None) -def test_client_submit(c, s, a, b): +async def test_client_submit(c, s, a, b): assert s.address.startswith("tls://") x = c.submit(inc, 10) - result = yield x + result = await x assert result == 11 yy = [c.submit(slowinc, i) for i in range(10)] results = [] for y in yy: - results.append((yield y)) + results.append(await y) assert results == list(range(1, 11)) @gen_tls_cluster(client=True) -def test_gather(c, s, a, b): +async def test_gather(c, s, a, b): assert s.address.startswith("tls://") x = c.submit(inc, 10) y = c.submit(inc, x) - result = yield c._gather(x) + result = await c._gather(x) assert result == 11 - result = yield c._gather([x]) + result = await c._gather([x]) assert result == [11] - result = yield c._gather({"x": x, "y": [y]}) + result = await c._gather({"x": x, "y": [y]}) assert result == {"x": 11, "y": [12]} @gen_tls_cluster(client=True) -def test_scatter(c, s, a, b): +async def test_scatter(c, s, a, b): assert s.address.startswith("tls://") - d = yield c._scatter({"y": 20}) + d = await c._scatter({"y": 20}) ts = s.tasks["y"] assert ts.who_has assert ts.nbytes > 0 - yy = yield c._gather([d["y"]]) + yy = await c._gather([d["y"]]) assert yy == [20] @gen_tls_cluster(client=True, Worker=Nanny) -def test_nanny(c, s, a, b): +async def test_nanny(c, s, a, b): assert s.address.startswith("tls://") for n in [a, b]: assert isinstance(n, Nanny) @@ -86,34 +86,34 @@ def test_nanny(c, s, a, b): assert s.nthreads == {n.worker_address: n.nthreads for n in [a, b]} x = c.submit(inc, 10) - result = yield x + result = await x assert result == 11 @gen_tls_cluster(client=True) -def test_rebalance(c, s, a, b): - x, y = yield c._scatter([1, 2], workers=[a.address]) +async def test_rebalance(c, s, a, b): + x, y = await c._scatter([1, 2], workers=[a.address]) assert len(a.data) == 2 assert len(b.data) == 0 - yield c._rebalance() + await c._rebalance() assert len(a.data) == 1 assert len(b.data) == 1 @gen_tls_cluster(client=True, nthreads=[("tls://127.0.0.1", 2)] * 2) -def test_work_stealing(c, s, a, b): - [x] = yield c._scatter([1], workers=a.address) +async def test_work_stealing(c, s, a, b): + [x] = await c._scatter([1], workers=a.address) futures = c.map(slowadd, range(50), [x] * 50, delay=0.1) - yield gen.sleep(0.1) - yield wait(futures) + await asyncio.sleep(0.1) + await wait(futures) assert len(a.data) > 10 assert len(b.data) > 10 @gen_tls_cluster(client=True) -def test_worker_client(c, s, a, b): +async def test_worker_client(c, s, a, b): def func(x): with worker_client() as c: x = c.submit(inc, x) @@ -122,14 +122,14 @@ def func(x): return result x, y = c.map(func, [10, 20]) - xx, yy = yield c._gather([x, y]) + xx, yy = await c._gather([x, y]) assert xx == 10 + 1 + (10 + 1) * 2 assert yy == 20 + 1 + (20 + 1) * 2 @gen_tls_cluster(client=True, 
nthreads=[("tls://127.0.0.1", 1)] * 2) -def test_worker_client_gather(c, s, a, b): +async def test_worker_client_gather(c, s, a, b): a_address = a.address b_address = b.address assert a_address.startswith("tls://") @@ -145,30 +145,30 @@ def func(): return xx, yy future = c.submit(func) - result = yield future + result = await future assert result == (2, 3) @gen_tls_cluster(client=True) -def test_worker_client_executor(c, s, a, b): +async def test_worker_client_executor(c, s, a, b): def mysum(): with worker_client() as c: with c.get_executor() as e: return sum(e.map(double, range(30))) future = c.submit(mysum) - result = yield future + result = await future assert result == 30 * 29 @gen_tls_cluster(client=True, Worker=Nanny) -def test_retire_workers(c, s, a, b): +async def test_retire_workers(c, s, a, b): assert set(s.workers) == {a.worker_address, b.worker_address} - yield c.retire_workers(workers=[a.worker_address], close_workers=True) + await c.retire_workers(workers=[a.worker_address], close_workers=True) assert set(s.workers) == {b.worker_address} start = time() while a.status != "closed": - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index e162b9fc2e1..86f1ca0c208 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -1,3 +1,4 @@ +import asyncio import array import datetime from functools import partial @@ -11,7 +12,6 @@ import numpy as np import pytest -from tornado import gen from tornado.ioloop import IOLoop import dask @@ -51,28 +51,23 @@ def test_All(loop): - @gen.coroutine - def throws(): + async def throws(): 1 / 0 - @gen.coroutine - def slow(): - yield gen.sleep(10) + async def slow(): + await asyncio.sleep(10) - @gen.coroutine - def inc(x): - raise gen.Return(x + 1) + async def inc(x): + return x + 1 - @gen.coroutine - def f(): - - results = yield All([inc(i) for i in range(10)]) + async def f(): + results = await All([inc(i) for i in range(10)]) assert results == list(range(1, 11)) start = time() for tasks in [[throws(), slow()], [slow(), throws()]]: try: - yield All(tasks) + await All(tasks) assert False except ZeroDivisionError: pass @@ -112,7 +107,7 @@ def function2(x): def test_sync_timeout(loop_in_thread): loop = loop_in_thread with pytest.raises(TimeoutError): - sync(loop_in_thread, gen.sleep, 0.5, callback_timeout=0.05) + sync(loop_in_thread, asyncio.sleep, 0.5, callback_timeout=0.05) def test_sync_closed_loop(): @@ -484,17 +479,17 @@ def test_two_loop_runners(loop_in_thread): @gen_test() -def test_loop_runner_gen(): +async def test_loop_runner_gen(): runner = LoopRunner(asynchronous=True) assert runner.loop is IOLoop.current() assert not runner.is_started() - yield gen.sleep(0.01) + await asyncio.sleep(0.01) runner.start() assert runner.is_started() - yield gen.sleep(0.01) + await asyncio.sleep(0.01) runner.stop() assert not runner.is_started() - yield gen.sleep(0.01) + await asyncio.sleep(0.01) def test_parse_bytes(): @@ -537,21 +532,20 @@ def test_parse_timedelta(): @gen_test() -def test_all_exceptions_logging(): - @gen.coroutine - def throws(): +async def test_all_exceptions_logging(): + async def throws(): raise Exception("foo1234") with captured_logger("") as sio: try: - yield All([throws() for _ in range(5)], quiet_exceptions=Exception) + await All([throws() for _ in range(5)], quiet_exceptions=Exception) except Exception: pass import gc gc.collect() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert "foo1234" not 
in sio.getvalue() diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 6a4a5ceaa5e..502b27b3013 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -1,10 +1,10 @@ +import asyncio from contextlib import contextmanager import socket import threading from time import sleep import pytest -from tornado import gen from distributed import Scheduler, Worker, Client, config, default_client from distributed.core import rpc @@ -43,7 +43,7 @@ def test_cluster(loop): @gen_cluster(client=True) -def test_gen_cluster(c, s, a, b): +async def test_gen_cluster(c, s, a, b): assert isinstance(c, Client) assert isinstance(s, Scheduler) for w in [a, b]: @@ -58,9 +58,9 @@ def test_gen_cluster_cleans_up_client(loop): assert not dask.config.get("get", None) @gen_cluster(client=True) - def f(c, s, a, b): + async def f(c, s, a, b): assert dask.config.get("get", None) - yield c.submit(inc, 1) + await c.submit(inc, 1) f() @@ -68,12 +68,17 @@ def f(c, s, a, b): @gen_cluster(client=False) -def test_gen_cluster_without_client(s, a, b): +async def test_gen_cluster_without_client(s, a, b): assert isinstance(s, Scheduler) for w in [a, b]: assert isinstance(w, Worker) assert s.nthreads == {w.address: w.nthreads for w in [a, b]} + async with Client(s.address, asynchronous=True) as c: + future = c.submit(lambda x: x + 1, 1) + result = await future + assert result == 2 + @gen_cluster( client=True, @@ -81,7 +86,7 @@ def test_gen_cluster_without_client(s, a, b): nthreads=[("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)], security=tls_only_security(), ) -def test_gen_cluster_tls(e, s, a, b): +async def test_gen_cluster_tls(e, s, a, b): assert isinstance(e, Client) assert isinstance(s, Scheduler) assert s.address.startswith("tls://") @@ -92,8 +97,8 @@ def test_gen_cluster_tls(e, s, a, b): @gen_test() -def test_gen_test(): - yield gen.sleep(0.01) +async def test_gen_test(): + await asyncio.sleep(0.01) @contextmanager @@ -154,8 +159,8 @@ def test_new_config(): def test_lingering_client(): @gen_cluster() - def f(s, a, b): - c = yield Client(s.address, asynchronous=True) + async def f(s, a, b): + await Client(s.address, asynchronous=True) f() @@ -177,16 +182,3 @@ def test_tls_cluster(tls_client): async def test_tls_scheduler(security, cleanup): async with Scheduler(security=security, host="localhost") as s: assert s.address.startswith("tls") - - -@gen_cluster() -async def test_gen_cluster_async(s, a, b): # flake8: noqa - async with Client(s.address, asynchronous=True) as c: - future = c.submit(lambda x: x + 1, 1) - result = await future - assert result == 2 - - -@gen_test() -async def test_gen_test_async(): # flake8: noqa - await gen.sleep(0.001) diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 0e450aa7a02..a60345d0abb 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -1,11 +1,9 @@ import asyncio import random from time import sleep -import sys import logging import pytest -from tornado import gen from tornado.ioloop import IOLoop from distributed import Client, Variable, worker_client, Nanny, wait, TimeoutError @@ -17,27 +15,27 @@ @gen_cluster(client=True) -def test_variable(c, s, a, b): +async def test_variable(c, s, a, b): x = Variable("x") xx = Variable("x") assert x.client is c future = c.submit(inc, 1) - yield x.set(future) - future2 = yield xx.get() + await x.set(future) + future2 = await xx.get() assert future.key == future2.key del future, future2 - yield 
gen.sleep(0.1) + await asyncio.sleep(0.1) assert s.tasks # future still present x.delete() start = time() while s.tasks: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @@ -53,13 +51,13 @@ async def test_delete_unset_variable(c, s, a, b): @gen_cluster(client=True) -def test_queue_with_data(c, s, a, b): +async def test_queue_with_data(c, s, a, b): x = Variable("x") xx = Variable("x") assert x.client is c - yield x.set((1, "hello")) - data = yield xx.get() + await x.set((1, "hello")) + data = await xx.get() assert data == (1, "hello") @@ -75,32 +73,32 @@ def test_sync(client): @gen_cluster() -def test_hold_futures(s, a, b): - c1 = yield Client(s.address, asynchronous=True) +async def test_hold_futures(s, a, b): + c1 = await Client(s.address, asynchronous=True) future = c1.submit(lambda x: x + 1, 10) x1 = Variable("x") - yield x1.set(future) + await x1.set(future) del x1 - yield c1.close() + await c1.close() - yield gen.sleep(0.1) + await asyncio.sleep(0.1) - c2 = yield Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) x2 = Variable("x") - future2 = yield x2.get() - result = yield future2 + future2 = await x2.get() + result = await future2 assert result == 11 - yield c2.close() + await c2.close() @gen_cluster(client=True) -def test_timeout(c, s, a, b): +async def test_timeout(c, s, a, b): v = Variable("v") start = IOLoop.current().time() with pytest.raises(TimeoutError): - yield v.get(timeout=0.2) + await v.get(timeout=0.2) stop = IOLoop.current().time() if WINDOWS: # timing is weird with asyncio and Windows @@ -109,7 +107,7 @@ def test_timeout(c, s, a, b): assert 0.2 < stop - start < 2.0 with pytest.raises(TimeoutError): - yield v.get(timeout=0.01) + await v.get(timeout=0.01) def test_timeout_sync(client): @@ -139,10 +137,10 @@ async def test_cleanup(c, s, a, b): await v.set(x) del x - await gen.sleep(0.1) + await asyncio.sleep(0.1) t_future = xx = asyncio.ensure_future(vv._get()) - await gen.sleep(0) + await asyncio.sleep(0) asyncio.ensure_future(v.set(y)) future = await t_future @@ -162,22 +160,21 @@ def f(x): @gen_cluster(client=True) -def test_timeout_get(c, s, a, b): +async def test_timeout_get(c, s, a, b): v = Variable("v") tornado_future = v.get() vv = Variable("v") - yield vv.set(1) + await vv.set(1) - result = yield tornado_future + result = await tornado_future assert result == 1 -@pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") @pytest.mark.slow @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) -def test_race(c, s, *workers): +async def test_race(c, s, *workers): NITERS = 50 def f(i): @@ -194,63 +191,63 @@ def f(i): return result v = Variable("x", client=c) - x = yield c.scatter(1) - yield v.set(x) + x = await c.scatter(1) + await v.set(x) futures = c.map(f, range(15)) - results = yield c.gather(futures) + results = await c.gather(futures) assert all(r > NITERS * 0.8 for r in results) start = time() while len(s.wants_what["variable-x"]) != 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 2 @gen_cluster(client=True) -def test_Future_knows_status_immediately(c, s, a, b): - x = yield c.scatter(123) +async def test_Future_knows_status_immediately(c, s, a, b): + x = await c.scatter(123) v = Variable("x") - yield v.set(x) + await v.set(x) - c2 = yield Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) v2 = Variable("x", client=c2) - future = yield v2.get() + future = await v2.get() assert 
future.status == "finished" x = c.submit(div, 1, 0) - yield wait(x) - yield v.set(x) + await wait(x) + await v.set(x) - future2 = yield v2.get() + future2 = await v2.get() assert future2.status == "error" with pytest.raises(Exception): - yield future2 + await future2 start = time() while True: # we learn about the true error eventually try: - yield future2 + await future2 except ZeroDivisionError: break except Exception: assert time() < start + 5 - yield gen.sleep(0.05) + await asyncio.sleep(0.05) - yield c2.close() + await c2.close() @gen_cluster(client=True) -def test_erred_future(c, s, a, b): +async def test_erred_future(c, s, a, b): future = c.submit(div, 1, 0) var = Variable() - yield var.set(future) - yield gen.sleep(0.1) - future2 = yield var.get() + await var.set(future) + await asyncio.sleep(0.1) + future2 = await var.get() with pytest.raises(ZeroDivisionError): - yield future2.result() + await future2.result() - exc = yield future2.exception() + exc = await future2.exception() assert isinstance(exc, ZeroDivisionError) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index a5e364ec0cd..a1f2e46295c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -16,14 +16,11 @@ from dask.system import CPU_COUNT import pytest from tlz import pluck, sliding_window, first -import tornado -from tornado import gen from distributed import ( Client, Nanny, get_client, - wait, default_client, get_worker, Reschedule, @@ -65,7 +62,7 @@ async def test_worker_nthreads(cleanup): @gen_cluster() -def test_str(s, a, b): +async def test_str(s, a, b): assert a.address in str(a) assert a.address in repr(a) assert str(a.nthreads) in str(a) @@ -85,7 +82,7 @@ async def test_identity(cleanup): @gen_cluster(client=True) -def test_worker_bad_args(c, s, a, b): +async def test_worker_bad_args(c, s, a, b): class NoReprObj: """ This object cannot be properly represented as a string. """ @@ -96,7 +93,7 @@ def __repr__(self): raise ValueError("I have no repr representation.") x = c.submit(NoReprObj, workers=a.address) - yield wait(x) + await wait(x) assert not a.executing assert a.data @@ -127,20 +124,17 @@ def reset(self): logger.setLevel(logging.DEBUG) logger.addHandler(hdlr) y = c.submit(bad_func, x, k=x, workers=b.address) - yield wait(y) + await wait(y) assert not b.executing assert y.status == "error" # Make sure job died because of bad func and not because of bad # argument. 
with pytest.raises(ZeroDivisionError): - yield y + await y - if sys.version_info[0] >= 3: - tb = yield y._traceback() - assert any( - "1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line - ) + tb = await y._traceback() + assert any("1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line) assert "Compute Failed" in hdlr.messages["warning"][0] logger.setLevel(old_level) @@ -149,14 +143,14 @@ def reset(self): xx = c.submit(add, 1, 2, workers=a.address) yy = c.submit(add, 3, 4, workers=b.address) - results = yield c._gather([xx, yy]) + results = await c._gather([xx, yy]) assert tuple(results) == (3, 7) @pytest.mark.slow @gen_cluster() -def dont_test_delete_data_with_missing_worker(c, a, b): +async def dont_test_delete_data_with_missing_worker(c, a, b): bad = "127.0.0.1:9001" # this worker doesn't exist c.who_has["z"].add(bad) c.who_has["z"].add(a.address) @@ -166,26 +160,26 @@ def dont_test_delete_data_with_missing_worker(c, a, b): cc = rpc(ip=c.ip, port=c.port) - yield cc.delete_data(keys=["z"]) # TODO: this hangs for a while + await cc.delete_data(keys=["z"]) # TODO: this hangs for a while assert "z" not in a.data assert not c.who_has["z"] assert not c.has_what[bad] assert not c.has_what[a.address] - yield cc.close_rpc() + await cc.close_rpc() @gen_cluster(client=True) -def test_upload_file(c, s, a, b): +async def test_upload_file(c, s, a, b): assert not os.path.exists(os.path.join(a.local_directory, "foobar.py")) assert not os.path.exists(os.path.join(b.local_directory, "foobar.py")) assert a.local_directory != b.local_directory with rpc(a.address) as aa, rpc(b.address) as bb: - yield [ + await asyncio.gather( aa.upload_file(filename="foobar.py", data=b"x = 123"), bb.upload_file(filename="foobar.py", data="x = 123"), - ] + ) assert os.path.exists(os.path.join(a.local_directory, "foobar.py")) assert os.path.exists(os.path.join(b.local_directory, "foobar.py")) @@ -196,17 +190,17 @@ def g(): return foobar.x future = c.submit(g, workers=a.address) - result = yield future + result = await future assert result == 123 - yield c.close() - yield s.close(close_workers=True) + await c.close() + await s.close(close_workers=True) assert not os.path.exists(os.path.join(a.local_directory, "foobar.py")) @pytest.mark.skip(reason="don't yet support uploading pyc files") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_upload_file_pyc(c, s, w): +async def test_upload_file_pyc(c, s, w): with tmpfile() as dirname: os.mkdir(dirname) with open(os.path.join(dirname, "foo.py"), mode="w") as f: @@ -219,7 +213,7 @@ def test_upload_file_pyc(c, s, w): assert foo.f() == 123 pyc = importlib.util.cache_from_source(os.path.join(dirname, "foo.py")) assert os.path.exists(pyc) - yield c.upload_file(pyc) + await c.upload_file(pyc) def g(): import foo @@ -227,21 +221,21 @@ def g(): return foo.x future = c.submit(g) - result = yield future + result = await future assert result == 123 finally: sys.path.remove(dirname) @gen_cluster(client=True) -def test_upload_egg(c, s, a, b): +async def test_upload_egg(c, s, a, b): eggname = "testegg-1.0.0-py3.4.egg" local_file = __file__.replace("test_worker.py", eggname) assert not os.path.exists(os.path.join(a.local_directory, eggname)) assert not os.path.exists(os.path.join(b.local_directory, eggname)) assert a.local_directory != b.local_directory - yield c.upload_file(filename=local_file) + await c.upload_file(filename=local_file) assert os.path.exists(os.path.join(a.local_directory, eggname)) assert 
os.path.exists(os.path.join(b.local_directory, eggname)) @@ -252,25 +246,25 @@ def g(x): return testegg.inc(x) future = c.submit(g, 10, workers=a.address) - result = yield future + result = await future assert result == 10 + 1 - yield c.close() - yield s.close() - yield a.close() - yield b.close() + await c.close() + await s.close() + await a.close() + await b.close() assert not os.path.exists(os.path.join(a.local_directory, eggname)) @gen_cluster(client=True) -def test_upload_pyz(c, s, a, b): +async def test_upload_pyz(c, s, a, b): pyzname = "mytest.pyz" local_file = __file__.replace("test_worker.py", pyzname) assert not os.path.exists(os.path.join(a.local_directory, pyzname)) assert not os.path.exists(os.path.join(b.local_directory, pyzname)) assert a.local_directory != b.local_directory - yield c.upload_file(filename=local_file) + await c.upload_file(filename=local_file) assert os.path.exists(os.path.join(a.local_directory, pyzname)) assert os.path.exists(os.path.join(b.local_directory, pyzname)) @@ -281,42 +275,42 @@ def g(x): return mytest.inc(x) future = c.submit(g, 10, workers=a.address) - result = yield future + result = await future assert result == 10 + 1 - yield c.close() - yield s.close() - yield a.close() - yield b.close() + await c.close() + await s.close() + await a.close() + await b.close() assert not os.path.exists(os.path.join(a.local_directory, pyzname)) @pytest.mark.xfail(reason="Still lose time to network I/O") @gen_cluster(client=True) -def test_upload_large_file(c, s, a, b): +async def test_upload_large_file(c, s, a, b): pytest.importorskip("crick") - yield gen.sleep(0.05) + await asyncio.sleep(0.05) with rpc(a.address) as aa: - yield aa.upload_file(filename="myfile.dat", data=b"0" * 100000000) - yield gen.sleep(0.05) + await aa.upload_file(filename="myfile.dat", data=b"0" * 100000000) + await asyncio.sleep(0.05) assert a.digests["tick-duration"].components[0].max() < 0.050 @gen_cluster() -def test_broadcast(s, a, b): +async def test_broadcast(s, a, b): with rpc(s.address) as cc: - results = yield cc.broadcast(msg={"op": "ping"}) + results = await cc.broadcast(msg={"op": "ping"}) assert results == {a.address: b"pong", b.address: b"pong"} @gen_test() -def test_worker_with_port_zero(): - s = yield Scheduler(port=8007) - w = yield Worker(s.address) +async def test_worker_with_port_zero(): + s = await Scheduler(port=8007) + w = await Worker(s.address) assert isinstance(w.port, int) assert w.port > 1024 - yield w.close() + await w.close() @pytest.mark.slow @@ -334,10 +328,10 @@ async def test_worker_waits_for_scheduler(cleanup): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_worker_task_data(c, s, w): +async def test_worker_task_data(c, s, w): x = delayed(2) xx = c.persist(x) - yield wait(xx) + await wait(xx) assert w.data[x.key] == 2 @@ -370,7 +364,7 @@ def __str__(self): @gen_cluster(client=True) -def test_chained_error_message(c, s, a, b): +async def test_chained_error_message(c, s, a, b): def chained_exception_fn(): class MyException(Exception): def __init__(self, msg): @@ -389,18 +383,18 @@ def __str__(self): f = c.submit(chained_exception_fn) try: - yield f + await f except Exception as e: assert e.__cause__ is not None assert "Bar" in str(e.__cause__) @gen_cluster() -def test_gather(s, a, b): +async def test_gather(s, a, b): b.data["x"] = 1 b.data["y"] = 2 with rpc(a.address) as aa: - resp = yield aa.gather(who_has={"x": [b.address], "y": [b.address]}) + resp = await aa.gather(who_has={"x": [b.address], "y": [b.address]}) assert resp["status"] == 
"OK" assert a.data["x"] == b.data["x"] @@ -415,9 +409,9 @@ async def test_io_loop(cleanup): @gen_cluster(client=True, nthreads=[]) -def test_spill_to_disk(c, s): +async def test_spill_to_disk(c, s): np = pytest.importorskip("numpy") - w = yield Worker( + w = await Worker( s.address, loop=s.loop, memory_limit=1200 / 0.6, @@ -426,79 +420,75 @@ def test_spill_to_disk(c, s): ) x = c.submit(np.random.randint, 0, 255, size=500, dtype="u1", key="x") - yield wait(x) + await wait(x) y = c.submit(np.random.randint, 0, 255, size=500, dtype="u1", key="y") - yield wait(y) + await wait(y) assert set(w.data) == {x.key, y.key} assert set(w.data.memory) == {x.key, y.key} assert set(w.data.fast) == set(w.data.memory) z = c.submit(np.random.randint, 0, 255, size=500, dtype="u1", key="z") - yield wait(z) + await wait(z) assert set(w.data) == {x.key, y.key, z.key} assert set(w.data.memory) == {y.key, z.key} assert set(w.data.disk) == {x.key} or set(w.data.slow) == {x.key, y.key} assert set(w.data.fast) == set(w.data.memory) assert set(w.data.slow) == set(w.data.disk) - yield x + await x assert set(w.data.memory) == {x.key, z.key} assert set(w.data.disk) == {y.key} or set(w.data.slow) == {x.key, y.key} assert set(w.data.fast) == set(w.data.memory) assert set(w.data.slow) == set(w.data.disk) - yield w.close() + await w.close() @gen_cluster(client=True) -def test_access_key(c, s, a, b): +async def test_access_key(c, s, a, b): def f(i): from distributed.worker import thread_state return thread_state.key futures = [c.submit(f, i, key="x-%d" % i) for i in range(20)] - results = yield c._gather(futures) + results = await c._gather(futures) assert list(results) == ["x-%d" % i for i in range(20)] @gen_cluster(client=True) -def test_run_dask_worker(c, s, a, b): +async def test_run_dask_worker(c, s, a, b): def f(dask_worker=None): return dask_worker.id - response = yield c._run(f) + response = await c._run(f) assert response == {a.address: a.id, b.address: b.id} @gen_cluster(client=True) -def test_run_coroutine_dask_worker(c, s, a, b): - if sys.version_info < (3,) and tornado.version_info < (4, 5): - pytest.skip("test needs Tornado 4.5+ on Python 2.7") - - @gen.coroutine - def f(dask_worker=None): - yield gen.sleep(0.001) - raise gen.Return(dask_worker.id) +async def test_run_coroutine_dask_worker(c, s, a, b): + async def f(dask_worker=None): + await asyncio.sleep(0.001) + return dask_worker.id - response = yield c.run(f) + response = await c.run(f) assert response == {a.address: a.id, b.address: b.id} @gen_cluster(client=True, nthreads=[]) -def test_Executor(c, s): +async def test_Executor(c, s): with ThreadPoolExecutor(2) as e: w = Worker(s.address, executor=e) assert w.executor is e - w = yield w + w = await w future = c.submit(inc, 1) - result = yield future + result = await future assert result == 2 assert e._threads # had to do some work - yield w.close() + await w.close() @pytest.mark.skip( @@ -510,22 +500,22 @@ def test_Executor(c, s): timeout=30, worker_kwargs={"memory_limit": 10e6}, ) -def test_spill_by_default(c, s, w): +async def test_spill_by_default(c, s, w): da = pytest.importorskip("dask.array") x = da.ones(int(10e6 * 0.7), chunks=1e6, dtype="u1") y = c.persist(x) - yield wait(y) + await wait(y) assert len(w.data.disk) # something is on disk del x, y @gen_cluster(nthreads=[("127.0.0.1", 1)], worker_kwargs={"reconnect": False}) -def test_close_on_disconnect(s, w): - yield s.close() +async def test_close_on_disconnect(s, w): + await s.close() start = time() while w.status != "closed": - yield 
gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 @@ -547,20 +537,20 @@ async def test_memory_limit_auto(): @gen_cluster(client=True) -def test_inter_worker_communication(c, s, a, b): - [x, y] = yield c._scatter([1, 2], workers=a.address) +async def test_inter_worker_communication(c, s, a, b): + [x, y] = await c._scatter([1, 2], workers=a.address) future = c.submit(add, x, y, workers=b.address) - result = yield future + result = await future assert result == 3 @gen_cluster(client=True) -def test_clean(c, s, a, b): +async def test_clean(c, s, a, b): x = c.submit(inc, 1, workers=a.address) y = c.submit(inc, x, workers=b.address) - yield y + await y collections = [ a.tasks, @@ -580,21 +570,20 @@ def test_clean(c, s, a, b): y.release() while x.key in a.task_state: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) for c in collections: assert not c -@pytest.mark.skipif(sys.version_info[:2] == (3, 4), reason="mul bytes fails") @gen_cluster(client=True) -def test_message_breakup(c, s, a, b): +async def test_message_breakup(c, s, a, b): n = 100000 a.target_message_size = 10 * n b.target_message_size = 10 * n xs = [c.submit(mul, b"%d" % i, n, workers=a.address) for i in range(30)] y = c.submit(lambda *args: None, xs, workers=b.address) - yield y + await y assert 2 <= len(b.incoming_transfer_log) <= 20 assert 2 <= len(a.outgoing_transfer_log) <= 20 @@ -604,29 +593,29 @@ def test_message_breakup(c, s, a, b): @gen_cluster(client=True) -def test_types(c, s, a, b): +async def test_types(c, s, a, b): assert not a.types assert not b.types x = c.submit(inc, 1, workers=a.address) - yield wait(x) + await wait(x) assert a.types[x.key] == int y = c.submit(inc, x, workers=b.address) - yield wait(y) + await wait(y) assert b.types == {x.key: int, y.key: int} - yield c._cancel(y) + await c._cancel(y) start = time() while y.key in b.data: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 5 assert y.key not in b.types @gen_cluster() -def test_system_monitor(s, a, b): +async def test_system_monitor(s, a, b): assert b.monitor b.monitor.update() @@ -634,38 +623,38 @@ def test_system_monitor(s, a, b): @gen_cluster( client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}}), ("127.0.0.1", 1)] ) -def test_restrictions(c, s, a, b): +async def test_restrictions(c, s, a, b): # Resource restrictions x = c.submit(inc, 1, resources={"A": 1}) - yield x + await x assert a.resource_restrictions == {x.key: {"A": 1}} - yield c._cancel(x) + await c._cancel(x) while x.key in a.task_state: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert a.resource_restrictions == {} @pytest.mark.xfail @gen_cluster(client=True) -def test_clean_nbytes(c, s, a, b): +async def test_clean_nbytes(c, s, a, b): L = [delayed(inc)(i) for i in range(10)] for i in range(5): L = [delayed(add)(x, y) for x, y in sliding_window(2, L)] total = delayed(sum)(L) future = c.compute(total) - yield wait(future) + await wait(future) - yield gen.sleep(1) + await asyncio.sleep(1) assert len(a.nbytes) + len(b.nbytes) == 1 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 20) -def test_gather_many_small(c, s, a, *workers): +async def test_gather_many_small(c, s, a, *workers): a.total_out_connections = 2 - futures = yield c._scatter(list(range(100))) + futures = await c._scatter(list(range(100))) assert all(w.data for w in workers) @@ -673,7 +662,7 @@ def f(*args): return 10 future = c.submit(f, *futures, workers=a.address) - yield wait(future) + await wait(future) types = list(pluck(0, a.log)) req = [i for i, 
t in enumerate(types) if t == "request-dep"] @@ -684,12 +673,12 @@ def f(*args): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_multiple_transfers(c, s, w1, w2, w3): +async def test_multiple_transfers(c, s, w1, w2, w3): x = c.submit(inc, 1, workers=w1.address) y = c.submit(inc, 2, workers=w2.address) z = c.submit(add, x, y, workers=w3.address) - yield wait(z) + await wait(z) r = w3.startstops[z.key] transfers = [t for t in r if t["action"] == "transfer"] @@ -697,25 +686,25 @@ def test_multiple_transfers(c, s, w1, w2, w3): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -def test_share_communication(c, s, w1, w2, w3): +async def test_share_communication(c, s, w1, w2, w3): x = c.submit(mul, b"1", int(w3.target_message_size + 1), workers=w1.address) y = c.submit(mul, b"2", int(w3.target_message_size + 1), workers=w2.address) - yield wait([x, y]) - yield c._replicate([x, y], workers=[w1.address, w2.address]) + await wait([x, y]) + await c._replicate([x, y], workers=[w1.address, w2.address]) z = c.submit(add, x, y, workers=w3.address) - yield wait(z) + await wait(z) assert len(w3.incoming_transfer_log) == 2 assert w1.outgoing_transfer_log assert w2.outgoing_transfer_log @gen_cluster(client=True) -def test_dont_overlap_communications_to_same_worker(c, s, a, b): +async def test_dont_overlap_communications_to_same_worker(c, s, a, b): x = c.submit(mul, b"1", int(b.target_message_size + 1), workers=a.address) y = c.submit(mul, b"2", int(b.target_message_size + 1), workers=a.address) - yield wait([x, y]) + await wait([x, y]) z = c.submit(add, x, y, workers=b.address) - yield wait(z) + await wait(z) assert len(b.incoming_transfer_log) == 2 l1, l2 = b.incoming_transfer_log @@ -724,7 +713,7 @@ def test_dont_overlap_communications_to_same_worker(c, s, a, b): @pytest.mark.avoid_travis @gen_cluster(client=True) -def test_log_exception_on_failed_task(c, s, a, b): +async def test_log_exception_on_failed_task(c, s, a, b): with tmpfile() as fn: fh = logging.FileHandler(fn) try: @@ -733,9 +722,9 @@ def test_log_exception_on_failed_task(c, s, a, b): logger.addHandler(fh) future = c.submit(div, 1, 0) - yield wait(future) + await wait(future) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) fh.flush() with open(fn) as f: text = f.read() @@ -747,7 +736,7 @@ def test_log_exception_on_failed_task(c, s, a, b): @gen_cluster(client=True) -def test_clean_up_dependencies(c, s, a, b): +async def test_clean_up_dependencies(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(2) xx = delayed(inc)(x) @@ -755,26 +744,26 @@ def test_clean_up_dependencies(c, s, a, b): z = delayed(add)(xx, yy) zz = c.persist(z) - yield wait(zz) + await wait(zz) start = time() while len(a.data) + len(b.data) > 1: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 assert set(a.data) | set(b.data) == {zz.key} @gen_cluster(client=True) -def test_hold_onto_dependents(c, s, a, b): +async def test_hold_onto_dependents(c, s, a, b): x = c.submit(inc, 1, workers=a.address) y = c.submit(inc, x, workers=b.address) - yield wait(y) + await wait(y) assert x.key in b.data - yield c._cancel(y) - yield gen.sleep(0.1) + await c._cancel(y) + await asyncio.sleep(0.1) assert x.key in b.data @@ -796,20 +785,20 @@ async def test_worker_death_timeout(s): @gen_cluster(client=True) -def test_stop_doing_unnecessary_work(c, s, a, b): +async def test_stop_doing_unnecessary_work(c, s, a, b): futures = c.map(slowinc, range(1000), delay=0.01) - yield gen.sleep(0.1) + await asyncio.sleep(0.1) del futures start = time() 
while a.executing: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() - start < 0.5 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) -def test_priorities(c, s, w): +async def test_priorities(c, s, w): values = [] for i in range(10): a = delayed(slowinc)(i, dask_key_name="a-%d" % i, delay=0.01) @@ -821,7 +810,7 @@ def test_priorities(c, s, w): values.append(b1) futures = c.compute(values) - yield wait(futures) + await wait(futures) log = [ t[0] @@ -833,12 +822,12 @@ def test_priorities(c, s, w): @gen_cluster(client=True) -def test_heartbeats(c, s, a, b): +async def test_heartbeats(c, s, a, b): x = s.workers[a.address].last_seen start = time() - yield gen.sleep(a.periodic_callbacks["heartbeat"].callback_time / 1000 + 0.1) + await asyncio.sleep(a.periodic_callbacks["heartbeat"].callback_time / 1000 + 0.1) while s.workers[a.address].last_seen == x: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 2 assert a.periodic_callbacks["heartbeat"].callback_time < 1000 @@ -848,7 +837,7 @@ def test_worker_dir(worker): with tmpfile() as fn: @gen_cluster(client=True, worker_kwargs={"local_directory": fn}) - def test_worker_dir(c, s, a, b): + async def test_worker_dir(c, s, a, b): directories = [w.local_directory for w in s.workers.values()] assert all(d.startswith(fn) for d in directories) assert len(set(directories)) == 2 # distinct @@ -857,7 +846,7 @@ def test_worker_dir(c, s, a, b): @gen_cluster(client=True) -def test_dataframe_attribute_error(c, s, a, b): +async def test_dataframe_attribute_error(c, s, a, b): class BadSize: def __init__(self, data): self.data = data @@ -866,12 +855,12 @@ def __sizeof__(self): raise TypeError("Hello") future = c.submit(BadSize, 123) - result = yield future + result = await future assert result.data == 123 @gen_cluster(client=True) -def test_fail_write_to_disk(c, s, a, b): +async def test_fail_write_to_disk(c, s, a, b): class Bad: def __getstate__(self): raise TypeError() @@ -880,15 +869,15 @@ def __sizeof__(self): return int(100e9) future = c.submit(Bad) - yield wait(future) + await wait(future) assert future.status == "error" with pytest.raises(TypeError): - yield future + await future futures = c.map(inc, range(10)) - results = yield c._gather(futures) + results = await c._gather(futures) assert results == list(map(inc, range(10))) @@ -896,9 +885,9 @@ def __sizeof__(self): @gen_cluster( nthreads=[("127.0.0.1", 2)], client=True, worker_kwargs={"memory_limit": 10e9} ) -def test_fail_write_many_to_disk(c, s, a): +async def test_fail_write_many_to_disk(c, s, a): a.validate = False - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert not a.paused class Bad: @@ -914,23 +903,23 @@ def __sizeof__(self): futures = c.map(Bad, range(11)) future = c.submit(lambda *args: 123, *futures) - yield wait(future) + await wait(future) with pytest.raises(Exception) as info: - yield future + await future # workers still operational - result = yield c.submit(inc, 1, workers=a.address) + result = await c.submit(inc, 1, workers=a.address) assert result == 2 @gen_cluster() -def test_pid(s, a, b): +async def test_pid(s, a, b): assert s.workers[a.address].pid == os.getpid() @gen_cluster(client=True) -def test_get_client(c, s, a, b): +async def test_get_client(c, s, a, b): def f(x): cc = get_client() future = cc.submit(inc, x) @@ -939,7 +928,7 @@ def f(x): assert default_client() is c future = c.submit(f, 10, workers=a.address) - result = yield future + result = await future assert result == 11 assert a._client @@ -951,7 +940,7 @@ def f(x): 
a_client = a._client for i in range(10): - yield wait(c.submit(f, i)) + await wait(c.submit(f, i)) assert a._client is a_client @@ -967,32 +956,30 @@ def f(x): @gen_cluster(client=True) -def test_get_client_coroutine(c, s, a, b): - @gen.coroutine - def f(): - client = yield get_client() +async def test_get_client_coroutine(c, s, a, b): + async def f(): + client = await get_client() future = client.submit(inc, 10) - result = yield future - raise gen.Return(result) + result = await future + return result - results = yield c.run(f) + results = await c.run(f) assert results == {a.address: 11, b.address: 11} def test_get_client_coroutine_sync(client, s, a, b): - @gen.coroutine - def f(): - client = yield get_client() + async def f(): + client = await get_client() future = client.submit(inc, 10) - result = yield future - raise gen.Return(result) + result = await future + return result results = client.run(f) assert results == {a["address"]: 11, b["address"]: 11} @gen_cluster() -def test_global_workers(s, a, b): +async def test_global_workers(s, a, b): n = len(Worker._instances) w = first(Worker._instances) assert w is a or w is b @@ -1000,24 +987,24 @@ def test_global_workers(s, a, b): @pytest.mark.skipif(WINDOWS, reason="file descriptors") @gen_cluster(nthreads=[]) -def test_worker_fds(s): +async def test_worker_fds(s): psutil = pytest.importorskip("psutil") - yield gen.sleep(0.05) + await asyncio.sleep(0.05) start = psutil.Process().num_fds() - worker = yield Worker(s.address, loop=s.loop) - yield gen.sleep(0.1) + worker = await Worker(s.address, loop=s.loop) + await asyncio.sleep(0.1) middle = psutil.Process().num_fds() start = time() while middle > start: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 1 - yield worker.close() + await worker.close() start = time() while psutil.Process().num_fds() > start: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 0.5 @@ -1045,28 +1032,28 @@ async def test_start_services(s): @gen_test() -def test_scheduler_file(): +async def test_scheduler_file(): with tmpfile() as fn: - s = yield Scheduler(scheduler_file=fn, port=8009) - w = yield Worker(scheduler_file=fn) + s = await Scheduler(scheduler_file=fn, port=8009) + w = await Worker(scheduler_file=fn) assert set(s.workers) == {w.address} - yield w.close() + await w.close() s.stop() @gen_cluster(client=True) -def test_scheduler_delay(c, s, a, b): +async def test_scheduler_delay(c, s, a, b): old = a.scheduler_delay assert abs(a.scheduler_delay) < 0.3 assert abs(b.scheduler_delay) < 0.3 - yield gen.sleep(a.periodic_callbacks["heartbeat"].callback_time / 1000 + 0.3) + await asyncio.sleep(a.periodic_callbacks["heartbeat"].callback_time / 1000 + 0.3) assert a.scheduler_delay != old @gen_cluster(client=True) -def test_statistical_profiling(c, s, a, b): +async def test_statistical_profiling(c, s, a, b): futures = c.map(slowinc, range(10), delay=0.1) - yield wait(futures) + await wait(futures) profile = a.profile_keys["slowinc"] assert profile["count"] @@ -1082,14 +1069,14 @@ def test_statistical_profiling(c, s, a, b): "distributed.worker.profile.cycle": "100ms", }, ) -def test_statistical_profiling_2(c, s, a, b): +async def test_statistical_profiling_2(c, s, a, b): da = pytest.importorskip("dask.array") while True: x = da.random.random(1000000, chunks=(10000,)) y = (x + x * 2) - x.sum().persist() - yield wait(y) + await wait(y) - profile = yield a.get_profile() + profile = await a.get_profile() text = str(profile) if profile["count"] and "sum" in text and 
"random" in text: break @@ -1100,7 +1087,7 @@ def test_statistical_profiling_2(c, s, a, b): client=True, worker_kwargs={"memory_monitor_interval": 10}, ) -def test_robust_to_bad_sizeof_estimates(c, s, a): +async def test_robust_to_bad_sizeof_estimates(c, s, a): np = pytest.importorskip("numpy") memory = psutil.Process().memory_info().rss a.memory_limit = memory / 0.7 + 400e6 @@ -1121,7 +1108,7 @@ def f(n): start = time() while not a.data.disk: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 5 @@ -1142,7 +1129,7 @@ def f(n): }, timeout=20, ) -def test_pause_executor(c, s, a): +async def test_pause_executor(c, s, a): memory = psutil.Process().memory_info().rss a.memory_limit = memory / 0.5 + 200e6 np = pytest.importorskip("numpy") @@ -1157,7 +1144,7 @@ def f(): start = time() while not a.paused: - yield gen.sleep(0.01) + await asyncio.sleep(0.01) assert time() < start + 4, ( format_bytes(psutil.Process().memory_info().rss), format_bytes(a.memory_limit), @@ -1169,41 +1156,41 @@ def f(): assert sum(f.status == "finished" for f in futures) < 4 - yield wait(futures) + await wait(futures) @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": "50 ms"}) -def test_statistical_profiling_cycle(c, s, a, b): +async def test_statistical_profiling_cycle(c, s, a, b): futures = c.map(slowinc, range(20), delay=0.05) - yield wait(futures) - yield gen.sleep(0.01) + await wait(futures) + await asyncio.sleep(0.01) end = time() assert len(a.profile_history) > 3 - x = yield a.get_profile(start=time() + 10, stop=time() + 20) + x = await a.get_profile(start=time() + 10, stop=time() + 20) assert not x["count"] - x = yield a.get_profile(start=0, stop=time() + 10) + x = await a.get_profile(start=0, stop=time() + 10) recent = a.profile_recent["count"] actual = sum(p["count"] for _, p in a.profile_history) + a.profile_recent["count"] - x2 = yield a.get_profile(start=0, stop=time() + 10) + x2 = await a.get_profile(start=0, stop=time() + 10) assert x["count"] <= actual <= x2["count"] - y = yield a.get_profile(start=end - 0.300, stop=time()) + y = await a.get_profile(start=end - 0.300, stop=time()) assert 0 < y["count"] <= x["count"] @gen_cluster(client=True) -def test_get_current_task(c, s, a, b): +async def test_get_current_task(c, s, a, b): def some_name(): return get_worker().get_current_task() - result = yield c.submit(some_name) + result = await c.submit(some_name) assert result.startswith("some_name") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_reschedule(c, s, a, b): +async def test_reschedule(c, s, a, b): s.extensions["stealing"]._pc.stop() a_address = a.address @@ -1214,7 +1201,7 @@ def f(x): futures = c.map(f, range(4)) futures2 = c.map(slowinc, range(10), delay=0.1, workers=a.address) - yield wait(futures) + await wait(futures) assert all(f.key in b.data for f in futures) @@ -1234,20 +1221,20 @@ async def test_deque_handler(cleanup): @gen_cluster(nthreads=[], client=True) -def test_avoid_memory_monitor_if_zero_limit(c, s): - worker = yield Worker( +async def test_avoid_memory_monitor_if_zero_limit(c, s): + worker = await Worker( s.address, loop=s.loop, memory_limit=0, memory_monitor_interval=10 ) assert type(worker.data) is dict assert "memory" not in worker.periodic_callbacks future = c.submit(inc, 1) - assert (yield future) == 2 - yield gen.sleep(worker.memory_monitor_interval / 1000) + assert (await future) == 2 + await asyncio.sleep(worker.memory_monitor_interval / 1000) - yield c.submit(inc, 2) # worker doesn't pause + await c.submit(inc, 2) # 
worker doesn't pause - yield worker.close() + await worker.close() @gen_cluster( @@ -1257,7 +1244,7 @@ def test_avoid_memory_monitor_if_zero_limit(c, s): "distributed.worker.memory.target": False, }, ) -def test_dict_data_if_no_spill_to_disk(s, w): +async def test_dict_data_if_no_spill_to_disk(s, w): assert type(w.data) is dict @@ -1277,27 +1264,27 @@ def func(dask_scheduler): @gen_cluster(nthreads=[("127.0.0.1", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) -def test_parse_memory_limit(s, w): +async def test_parse_memory_limit(s, w): assert w.memory_limit == 2e9 @gen_cluster(nthreads=[], client=True) -def test_scheduler_address_config(c, s): +async def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): - worker = yield Worker(loop=s.loop) + worker = await Worker(loop=s.loop) assert worker.scheduler.address == s.address - yield worker.close() + await worker.close() @pytest.mark.slow @gen_cluster(client=True) -def test_wait_for_outgoing(c, s, a, b): +async def test_wait_for_outgoing(c, s, a, b): np = pytest.importorskip("numpy") x = np.random.random(10000000) - future = yield c.scatter(x, workers=a.address) + future = await c.scatter(x, workers=a.address) y = c.submit(inc, future, workers=b.address) - yield wait(y) + await wait(y) assert len(b.incoming_transfer_log) == len(a.outgoing_transfer_log) == 1 bb = b.incoming_transfer_log[0]["duration"] @@ -1313,11 +1300,11 @@ def test_wait_for_outgoing(c, s, a, b): @gen_cluster( nthreads=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], client=True ) -def test_prefer_gather_from_local_address(c, s, w1, w2, w3): - x = yield c.scatter(123, workers=[w1.address, w3.address], broadcast=True) +async def test_prefer_gather_from_local_address(c, s, w1, w2, w3): + x = await c.scatter(123, workers=[w1.address, w3.address], broadcast=True) y = c.submit(inc, x, workers=[w2.address]) - yield wait(y) + await wait(y) assert any(d["who"] == w2.address for d in w1.outgoing_transfer_log) assert not any(d["who"] == w2.address for d in w3.outgoing_transfer_log) @@ -1329,14 +1316,14 @@ def test_prefer_gather_from_local_address(c, s, w1, w2, w3): timeout=30, config={"distributed.worker.connections.incoming": 1}, ) -def test_avoid_oversubscription(c, s, *workers): +async def test_avoid_oversubscription(c, s, *workers): np = pytest.importorskip("numpy") x = c.submit(np.random.random, 1000000, workers=[workers[0].address]) - yield wait(x) + await wait(x) futures = [c.submit(len, x, pure=False, workers=[w.address]) for w in workers[1:]] - yield wait(futures) + await wait(futures) # Original worker not responsible for all transfers assert len(workers[0].outgoing_transfer_log) < len(workers) - 2 @@ -1346,13 +1333,13 @@ def test_avoid_oversubscription(c, s, *workers): @gen_cluster(client=True, worker_kwargs={"metrics": {"my_port": lambda w: w.port}}) -def test_custom_metrics(c, s, a, b): +async def test_custom_metrics(c, s, a, b): assert s.workers[a.address].metrics["my_port"] == a.port assert s.workers[b.address].metrics["my_port"] == b.port @gen_cluster(client=True) -def test_register_worker_callbacks(c, s, a, b): +async def test_register_worker_callbacks(c, s, a, b): # preload function to run def mystartup(dask_worker): dask_worker.init_variable = 1 @@ -1374,81 +1361,81 @@ def test_startup2(): return os.getenv("MY_ENV_VALUE", None) == "WORKER_ENV_VALUE" # Nothing has been run yet - result = yield c.run(test_import) + result = await c.run(test_import) assert list(result.values()) == [False] * 2 - result = yield 
c.run(test_startup2) + result = await c.run(test_startup2) assert list(result.values()) == [False] * 2 # Start a worker and check that startup is not run - worker = yield Worker(s.address, loop=s.loop) - result = yield c.run(test_import, workers=[worker.address]) + worker = await Worker(s.address, loop=s.loop) + result = await c.run(test_import, workers=[worker.address]) assert list(result.values()) == [False] - yield worker.close() + await worker.close() # Add a preload function - response = yield c.register_worker_callbacks(setup=mystartup) + response = await c.register_worker_callbacks(setup=mystartup) assert len(response) == 2 # Check it has been ran on existing worker - result = yield c.run(test_import) + result = await c.run(test_import) assert list(result.values()) == [True] * 2 # Start a worker and check it is ran on it - worker = yield Worker(s.address, loop=s.loop) - result = yield c.run(test_import, workers=[worker.address]) + worker = await Worker(s.address, loop=s.loop) + result = await c.run(test_import, workers=[worker.address]) assert list(result.values()) == [True] - yield worker.close() + await worker.close() # Register another preload function - response = yield c.register_worker_callbacks(setup=mystartup2) + response = await c.register_worker_callbacks(setup=mystartup2) assert len(response) == 2 # Check it has been run - result = yield c.run(test_startup2) + result = await c.run(test_startup2) assert list(result.values()) == [True] * 2 # Start a worker and check it is ran on it - worker = yield Worker(s.address, loop=s.loop) - result = yield c.run(test_import, workers=[worker.address]) + worker = await Worker(s.address, loop=s.loop) + result = await c.run(test_import, workers=[worker.address]) assert list(result.values()) == [True] - result = yield c.run(test_startup2, workers=[worker.address]) + result = await c.run(test_startup2, workers=[worker.address]) assert list(result.values()) == [True] - yield worker.close() + await worker.close() @gen_cluster(client=True) -def test_register_worker_callbacks_err(c, s, a, b): +async def test_register_worker_callbacks_err(c, s, a, b): with pytest.raises(ZeroDivisionError): - yield c.register_worker_callbacks(setup=lambda: 1 / 0) + await c.register_worker_callbacks(setup=lambda: 1 / 0) @gen_cluster(nthreads=[]) -def test_data_types(s): - w = yield Worker(s.address, data=dict) +async def test_data_types(s): + w = await Worker(s.address, data=dict) assert isinstance(w.data, dict) - yield w.close() + await w.close() data = dict() - w = yield Worker(s.address, data=data) + w = await Worker(s.address, data=data) assert w.data is data - yield w.close() + await w.close() class Data(dict): def __init__(self, x, y): self.x = x self.y = y - w = yield Worker(s.address, data=(Data, {"x": 123, "y": 456})) + w = await Worker(s.address, data=(Data, {"x": 123, "y": 456})) assert w.data.x == 123 assert w.data.y == 456 - yield w.close() + await w.close() @gen_cluster(nthreads=[]) -def test_local_directory(s): +async def test_local_directory(s): with tmpfile() as fn: with dask.config.set(temporary_directory=fn): - w = yield Worker(s.address) + w = await Worker(s.address) assert w.local_directory.startswith(fn) assert "dask-worker-space" in w.local_directory @@ -1457,15 +1444,15 @@ def test_local_directory(s): not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost" ) @gen_cluster(nthreads=[], client=True) -def test_host_address(c, s): - w = yield Worker(s.address, host="127.0.0.2") +async def test_host_address(c, s): + w = 
await Worker(s.address, host="127.0.0.2") assert "127.0.0.2" in w.address - yield w.close() + await w.close() - n = yield Nanny(s.address, host="127.0.0.3") + n = await Nanny(s.address, host="127.0.0.3") assert "127.0.0.3" in n.address assert "127.0.0.3" in n.worker_address - yield n.close() + await n.close() def test_resource_limit(monkeypatch): @@ -1537,7 +1524,7 @@ async def test_worker_listens_on_same_interface_by_default(Worker): async def test_close_gracefully(c, s, a, b): futures = c.map(slowinc, range(200), delay=0.1) while not b.data: - await gen.sleep(0.1) + await asyncio.sleep(0.1) mem = set(b.data) proc = set(b.executing) @@ -1558,7 +1545,7 @@ async def test_lifetime(cleanup): async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b: async with Client(s.address, asynchronous=True) as c: futures = c.map(slowinc, range(200), delay=0.1) - await gen.sleep(1.5) + await asyncio.sleep(1.5) assert b.status != "running" await b.finished() diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index 14a2d30f7d5..09ae20e8f20 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -1,3 +1,4 @@ +import asyncio import random import threading from time import sleep @@ -6,7 +7,6 @@ import dask from dask import delayed import pytest -from tornado import gen from distributed import ( worker_client, @@ -22,7 +22,7 @@ @gen_cluster(client=True) -def test_submit_from_worker(c, s, a, b): +async def test_submit_from_worker(c, s, a, b): def func(x): with worker_client() as c: x = c.submit(inc, x) @@ -31,7 +31,7 @@ def func(x): return result x, y = c.map(func, [10, 20]) - xx, yy = yield c._gather([x, y]) + xx, yy = await c._gather([x, y]) assert xx == 10 + 1 + (10 + 1) * 2 assert yy == 20 + 1 + (20 + 1) * 2 @@ -41,7 +41,7 @@ def func(x): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_scatter_from_worker(c, s, a, b): +async def test_scatter_from_worker(c, s, a, b): def func(): with worker_client() as c: futures = c.scatter([1, 2, 3, 4, 5]) @@ -56,7 +56,7 @@ def func(): return total.result() future = c.submit(func) - result = yield future + result = await future assert result == sum([1, 2, 3, 4, 5]) def func(): @@ -72,17 +72,17 @@ def func(): return correct future = c.submit(func) - result = yield future + result = await future assert result is True start = time() while not all(v == 1 for v in s.nthreads.values()): - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 5 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_scatter_singleton(c, s, a, b): +async def test_scatter_singleton(c, s, a, b): np = pytest.importorskip("numpy") def func(): @@ -91,11 +91,11 @@ def func(): future = c.scatter(x) assert future.type == np.ndarray - yield c.submit(func) + await c.submit(func) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -def test_gather_multi_machine(c, s, a, b): +async def test_gather_multi_machine(c, s, a, b): a_address = a.address b_address = b.address assert a_address != b_address @@ -109,19 +109,19 @@ def func(): return xx, yy future = c.submit(func) - result = yield future + result = await future assert result == (2, 3) @gen_cluster(client=True) -def test_same_loop(c, s, a, b): +async def test_same_loop(c, s, a, b): def f(): with worker_client() as lc: return lc.loop is get_worker().loop future = c.submit(f) - result = yield future + result = await future assert result @@ -140,7 +140,7 @@ def mysum(): 
@gen_cluster(client=True) -def test_async(c, s, a, b): +async def test_async(c, s, a, b): def mysum(): result = 0 sub_tasks = [delayed(double)(i) for i in range(100)] @@ -152,16 +152,16 @@ def mysum(): return result future = c.compute(delayed(mysum)()) - yield future + await future start = time() while len(a.data) + len(b.data) > 1: - yield gen.sleep(0.1) + await asyncio.sleep(0.1) assert time() < start + 3 @gen_cluster(client=True, nthreads=[("127.0.0.1", 3)]) -def test_separate_thread_false(c, s, a): +async def test_separate_thread_false(c, s, a): a.count = 0 def f(i): @@ -174,19 +174,19 @@ def f(i): return i futures = c.map(f, range(20)) - results = yield c._gather(futures) + results = await c._gather(futures) assert list(results) == list(range(20)) @gen_cluster(client=True) -def test_client_executor(c, s, a, b): +async def test_client_executor(c, s, a, b): def mysum(): with worker_client() as c: with c.get_executor() as e: return sum(e.map(double, range(30))) future = c.submit(mysum) - result = yield future + result = await future assert result == 30 * 29 @@ -211,7 +211,7 @@ def f(x): @gen_cluster(client=True) -def test_local_client_warning(c, s, a, b): +async def test_local_client_warning(c, s, a, b): from distributed import local_client def func(x): @@ -223,18 +223,18 @@ def func(x): return result future = c.submit(func, 10) - result = yield future + result = await future assert result == 11 @gen_cluster(client=True) -def test_closing_worker_doesnt_close_client(c, s, a, b): +async def test_closing_worker_doesnt_close_client(c, s, a, b): def func(x): get_client() return - yield wait(c.map(func, range(10))) - yield a.close() + await wait(c.map(func, range(10))) + await a.close() assert c.status == "running" @@ -260,15 +260,15 @@ def test_secede_without_stealing_issue_1262(): # run the loop as an inner function so all workers are closed # and exceptions can be examined @gen_cluster(client=True, scheduler_kwargs={"extensions": extensions}) - def secede_test(c, s, a, b): + async def secede_test(c, s, a, b): def func(x): with worker_client() as wc: y = wc.submit(lambda: 1 + x) return wc.gather(y) - f = yield c.gather(c.submit(func, 1)) + f = await c.gather(c.submit(func, 1)) - raise gen.Return((c, s, a, b, f)) + return c, s, a, b, f c, s, a, b, f = secede_test() @@ -278,40 +278,40 @@ def func(x): @gen_cluster(client=True) -def test_compute_within_worker_client(c, s, a, b): +async def test_compute_within_worker_client(c, s, a, b): @dask.delayed def f(): with worker_client(): return dask.delayed(lambda x: x)(1).compute() - result = yield c.compute(f()) + result = await c.compute(f()) assert result == 1 @gen_cluster(client=True) -def test_worker_client_rejoins(c, s, a, b): +async def test_worker_client_rejoins(c, s, a, b): def f(): with worker_client(): pass return threading.current_thread() in get_worker().executor._threads - result = yield c.submit(f) + result = await c.submit(f) assert result @gen_cluster() -def test_submit_different_names(s, a, b): +async def test_submit_different_names(s, a, b): # https://github.com/dask/distributed/issues/2058 da = pytest.importorskip("dask.array") - c = yield Client( + c = await Client( "localhost:" + s.address.split(":")[-1], loop=s.loop, asynchronous=True ) try: X = c.persist(da.random.uniform(size=(100, 10), chunks=50)) - yield wait(X) + await wait(X) - fut = yield c.submit(lambda x: x.sum().compute(), X) + fut = await c.submit(lambda x: x.sum().compute(), X) assert fut > 0 finally: - yield c.close() + await c.close() diff --git 
a/distributed/utils.py b/distributed/utils.py index adc20d4f368..30af57a25fc 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -202,6 +202,7 @@ def ignoring(*exceptions): pass +# FIXME: this breaks if changed to async def... @gen.coroutine def ignore_exceptions(coroutines, *exceptions): """ Process list of coroutines, ignoring certain exceptions diff --git a/distributed/utils_test.py b/distributed/utils_test.py index e1db066b732..e466322eddf 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1,5 +1,6 @@ import asyncio import collections +import gc from contextlib import contextmanager import copy import functools @@ -33,7 +34,6 @@ import dask from tlz import merge, memoize, assoc -from tornado import gen from tornado.ioloop import IOLoop from . import system @@ -761,18 +761,16 @@ def gen_test(timeout=10): """ Coroutine test @gen_test(timeout=5) - def test_foo(): - yield ... # use tornado coroutines + async def test_foo(): + await ... # use tornado coroutines """ def _(func): def test_func(): with clean() as loop: - if iscoroutinefunction(func): - cor = func - else: - cor = gen.coroutine(func) - loop.run_sync(cor, timeout=timeout) + if not iscoroutinefunction(func): + raise ValueError("@gen_test should wrap async def functions") + loop.run_sync(func, timeout=timeout) return test_func @@ -856,14 +854,15 @@ def gen_cluster( active_rpc_timeout=1, config={}, clean_kwargs={}, + allow_unclosed=False, ): from distributed import Client """ Coroutine test with small cluster @gen_cluster() - def test_foo(scheduler, worker1, worker2): - yield ... # use tornado coroutines + async def test_foo(scheduler, worker1, worker2): + await ... # use tornado coroutines See also: start @@ -878,10 +877,10 @@ def test_foo(scheduler, worker1, worker2): ) def _(func): - if not iscoroutinefunction(func): - func = gen.coroutine(func) - def test_func(): + if not iscoroutinefunction(func): + raise ValueError("@gen_cluster should wrap async def functions") + result = None workers = [] with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop: @@ -905,6 +904,7 @@ async def coro(): "Failed to start gen_cluster, retrying", exc_info=True, ) + await asyncio.sleep(1) else: workers[:] = ws args = [s] + workers @@ -940,16 +940,28 @@ async def coro(): else: await c._close(fast=True) - for i in range(5): - if all(c.closed() for c in Comm._instances): - break - else: + def get_unclosed(): + return [c for c in Comm._instances if not c.closed()] + [ + c + for c in _global_clients.values() + if c.status != "closed" + ] + + try: + start = time() + while time() < start + 5: + gc.collect() + if not get_unclosed(): + break await asyncio.sleep(0.05) - else: - L = [c for c in Comm._instances if not c.closed()] + else: + if allow_unclosed: + print(f"Unclosed Comms: {get_unclosed()}") + else: + raise RuntimeError("Unclosed Comms", get_unclosed()) + finally: Comm._instances.clear() - # raise ValueError("Unclosed Comms", L) - print("Unclosed Comms", L) + _global_clients.clear() return result diff --git a/distributed/worker.py b/distributed/worker.py index a50103bacab..c6734bbce93 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2310,6 +2310,7 @@ def rescind_key(self, key): # Execute Task # ################ + # FIXME: this breaks if changed to async def... 
@gen.coroutine def executor_submit(self, key, function, args=(), kwargs=None, executor=None): """ Safely run function in thread pool executor diff --git a/docs/source/asynchronous.rst b/docs/source/asynchronous.rst index a49788e1fbe..342981833db 100644 --- a/docs/source/asynchronous.rst +++ b/docs/source/asynchronous.rst @@ -64,18 +64,12 @@ function to run the asynchronous function: client.sync(f) -Python 2 Compatibility ----------------------- - -Everything here works with Python 2 if you replace ``await`` with ``yield``. -See more extensive comparison in the example below. - Example ------- This self-contained example starts an asynchronous client, submits a trivial -job, waits on the result, and then shuts down the client. You can see -implementations for Python 2 and 3 and for Asyncio and Tornado. +job, waits on the result, and then shuts down the client. You can see +implementations for Asyncio and Tornado. Python 3 with Tornado or Asyncio ++++++++++++++++++++++++++++++++ @@ -100,25 +94,6 @@ Python 3 with Tornado or Asyncio asyncio.get_event_loop().run_until_complete(f()) -Python 2/3 with Tornado -+++++++++++++++++++++++ - -.. code-block:: python - - from dask.distributed import Client - from tornado import gen - - @gen.coroutine - def f(): - client = yield Client(asynchronous=True) - future = client.submit(lambda x: x + 1, 10) - result = yield future - yield client.close() - raise gen.Return(result) - - from tornado.ioloop import IOLoop - IOLoop().run_sync(f) - Use Cases --------- From 79546ce43d24924a1b55fdde014b0de290465239 Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Mon, 20 Apr 2020 11:04:44 -0400 Subject: [PATCH 0798/1550] Fix copy-paste in docs (#3728) --- distributed/deploy/spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index c6338d3b93f..eb9f0f0043e 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -551,7 +551,7 @@ def adapt( minimum_memory : str Minimum amount of memory to keep around in the cluster Expressed as a string like "100 GiB" - maximum_cores : int + maximum_memory : str Maximum amount of memory to keep around in the cluster Expressed as a string like "100 GiB" From 8534e84bba401f61f0339b72db100e637f7b729c Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Mon, 20 Apr 2020 11:12:23 -0400 Subject: [PATCH 0799/1550] Configurable polling interval for cluster widget (#3723) Customize interval with which to do callbacks to get cluster status for the widget --- distributed/client.py | 8 +++++++- distributed/deploy/cluster.py | 7 ++++++- distributed/distributed.yaml | 4 +++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 5ba05a84a3b..a48c2367112 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -676,9 +676,15 @@ def __init__( heartbeat_interval = dask.config.get("distributed.client.heartbeat") heartbeat_interval = parse_timedelta(heartbeat_interval, default="ms") + scheduler_info_interval = parse_timedelta( + dask.config.get("distributed.client.scheduler-info-interval", default="ms") + ) + self._periodic_callbacks = dict() self._periodic_callbacks["scheduler-info"] = PeriodicCallback( - self._update_scheduler_info, 2000, io_loop=self.loop + self._update_scheduler_info, + scheduler_info_interval * 1000, + io_loop=self.loop, ) self._periodic_callbacks["heartbeat"] = PeriodicCallback( self._heartbeat, heartbeat_interval * 1000, io_loop=self.loop diff --git 
a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 7164b17b076..592195443c1 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -3,6 +3,7 @@ import threading import warnings +import dask.config from dask.utils import format_bytes from .adaptive import Adaptive @@ -16,6 +17,7 @@ Logs, thread_state, format_dashboard_link, + parse_timedelta, ) @@ -319,7 +321,10 @@ def scale_cb(b): def update(): status.value = self._widget_status() - pc = PeriodicCallback(update, 500, io_loop=self.loop) + cluster_repr_interval = parse_timedelta( + dask.config.get("distributed.deploy.cluster-repr-interval", default="ms") + ) + pc = PeriodicCallback(update, cluster_repr_interval * 1000, io_loop=self.loop) self.periodic_callbacks["cluster-repr"] = pc pc.start() diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 71ecd840a10..4f95a179bc3 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -96,10 +96,12 @@ distributed: preload-argv: [] # See https://docs.dask.org/en/latest/setup/custom-startup.html client: - heartbeat: 5s # time between client heartbeats + heartbeat: 5s # Interval between client heartbeats + scheduler-info-interval: 2s # Interval between scheduler-info updates deploy: lost-worker-timeout: 15s # Interval after which to hard-close a lost worker job + cluster-repr-interval: 500ms # Interval between calls to update cluster-repr for the widget adaptive: interval: 1s # Interval between scaling evaluations From 8376f227e288757c3fd9b1e8742ca350e6e57b25 Mon Sep 17 00:00:00 2001 From: Abdulelah Bin Mahfoodh Date: Mon, 20 Apr 2020 22:10:16 +0300 Subject: [PATCH 0800/1550] Add remote_python option in ssh cmd (#3709) * Add remote_python option in ssh cmd * Add remote_python option in ssh cmd --- distributed/deploy/ssh.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 4f0e713ffa9..22364ecbacf 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -110,7 +110,7 @@ async def start(self): cmd = " ".join( [ set_env, - sys.executable, + self.remote_python or sys.executable, "-m", self.worker_module, self.scheduler, @@ -186,7 +186,12 @@ async def start(self): ) cmd = " ".join( - [set_env, sys.executable, "-m", "distributed.cli.dask_scheduler",] + [ + set_env, + self.remote_python or sys.executable, + "-m", + "distributed.cli.dask_scheduler", + ] + cli_keywords(self.kwargs, cls=_Scheduler) ) self.proc = await self.connection.create_process(cmd) From 35dc9409f8cf99f82c354b97302846705dfbcc4a Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 20 Apr 2020 21:18:48 -0500 Subject: [PATCH 0801/1550] Use PeriodicCallback class from tornado (#3725) --- distributed/client.py | 9 +-- distributed/core.py | 15 ++--- distributed/counter.py | 4 +- distributed/deploy/adaptive_core.py | 4 +- distributed/deploy/cluster.py | 4 +- distributed/nanny.py | 5 +- distributed/scheduler.py | 7 +-- distributed/semaphore.py | 13 ++--- distributed/stealing.py | 14 ++--- distributed/utils.py | 88 ++++++++++++----------------- distributed/worker.py | 33 ++++------- 11 files changed, 77 insertions(+), 119 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index a48c2367112..52c0e2b420e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -37,7 +37,7 @@ except ImportError: single_key = first from tornado import gen -from tornado.ioloop import IOLoop +from tornado.ioloop import IOLoop, PeriodicCallback from 
.batched import BatchedSend from .utils_comm import ( @@ -72,7 +72,6 @@ key_split, thread_state, no_default, - PeriodicCallback, LoopRunner, parse_timedelta, shutting_down, @@ -682,12 +681,10 @@ def __init__( self._periodic_callbacks = dict() self._periodic_callbacks["scheduler-info"] = PeriodicCallback( - self._update_scheduler_info, - scheduler_info_interval * 1000, - io_loop=self.loop, + self._update_scheduler_info, scheduler_info_interval * 1000, ) self._periodic_callbacks["heartbeat"] = PeriodicCallback( - self._heartbeat, heartbeat_interval * 1000, io_loop=self.loop + self._heartbeat, heartbeat_interval * 1000 ) self._start_arg = address diff --git a/distributed/core.py b/distributed/core.py index dd5e18d0007..df0a55780e7 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -12,7 +12,7 @@ import tblib from tlz import merge from tornado import gen -from tornado.ioloop import IOLoop +from tornado.ioloop import IOLoop, PeriodicCallback from .comm import ( connect, @@ -31,7 +31,6 @@ truncate_exception, ignoring, shutting_down, - PeriodicCallback, parse_timedelta, has_keyword, CancelledError, @@ -176,18 +175,14 @@ def stop(): self.periodic_callbacks = dict() - pc = PeriodicCallback(self.monitor.update, 500, io_loop=self.io_loop) + pc = PeriodicCallback(self.monitor.update, 500) self.periodic_callbacks["monitor"] = pc self._last_tick = time() - pc = PeriodicCallback( - self._measure_tick, - parse_timedelta( - dask.config.get("distributed.admin.tick.interval"), default="ms" - ) - * 1000, - io_loop=self.io_loop, + measure_tick_interval = parse_timedelta( + dask.config.get("distributed.admin.tick.interval"), default="ms" ) + pc = PeriodicCallback(self._measure_tick, measure_tick_interval * 1000) self.periodic_callbacks["tick"] = pc self.thread_id = 0 diff --git a/distributed/counter.py b/distributed/counter.py index ebc8cda6104..feffb69ce8c 100644 --- a/distributed/counter.py +++ b/distributed/counter.py @@ -1,8 +1,6 @@ from collections import defaultdict -from tornado.ioloop import IOLoop - -from .utils import PeriodicCallback +from tornado.ioloop import IOLoop, PeriodicCallback try: diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index 192e244bd08..7d15cb4c2c7 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -1,11 +1,11 @@ import collections import math -from tornado.ioloop import IOLoop +from tornado.ioloop import IOLoop, PeriodicCallback import tlz as toolz from ..metrics import time -from ..utils import parse_timedelta, PeriodicCallback +from ..utils import parse_timedelta class AdaptiveCore: diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 592195443c1..35e0b97c613 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -2,6 +2,7 @@ import logging import threading import warnings +from tornado.ioloop import PeriodicCallback import dask.config from dask.utils import format_bytes @@ -9,7 +10,6 @@ from .adaptive import Adaptive from ..utils import ( - PeriodicCallback, log_errors, ignoring, sync, @@ -324,7 +324,7 @@ def update(): cluster_repr_interval = parse_timedelta( dask.config.get("distributed.deploy.cluster-repr-interval", default="ms") ) - pc = PeriodicCallback(update, cluster_repr_interval * 1000, io_loop=self.loop) + pc = PeriodicCallback(update, cluster_repr_interval * 1000) self.periodic_callbacks["cluster-repr"] = pc pc.start() diff --git a/distributed/nanny.py b/distributed/nanny.py index 3f7c20f98f9..f3a355dca89 100644 --- 
a/distributed/nanny.py +++ b/distributed/nanny.py @@ -11,7 +11,7 @@ import dask from dask.system import CPU_COUNT -from tornado.ioloop import IOLoop +from tornado.ioloop import IOLoop, PeriodicCallback from tornado import gen from .comm import get_address_host, unparse_host_port @@ -28,7 +28,6 @@ mp_context, silence_logging, json_load_robust, - PeriodicCallback, parse_timedelta, ignoring, TimeoutError, @@ -202,7 +201,7 @@ def __init__( self.scheduler = self.rpc(self.scheduler_addr) if self.memory_limit: - pc = PeriodicCallback(self.memory_monitor, 100, io_loop=self.loop) + pc = PeriodicCallback(self.memory_monitor, 100) self.periodic_callbacks["memory"] = pc if ( diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 86dd6b9203e..521415c7c25 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -32,7 +32,7 @@ groupby, concat, ) -from tornado.ioloop import IOLoop +from tornado.ioloop import IOLoop, PeriodicCallback import dask @@ -64,7 +64,6 @@ no_default, parse_timedelta, parse_bytes, - PeriodicCallback, shutting_down, key_split_group, empty_context, @@ -1357,11 +1356,11 @@ def __init__( ) if self.worker_ttl: - pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl, io_loop=loop) + pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl) self.periodic_callbacks["worker-ttl"] = pc if self.idle_timeout: - pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4, io_loop=loop) + pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4) self.periodic_callbacks["idle-timeout"] = pc if extensions is None: diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 976f54704c4..263619c9073 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -3,7 +3,8 @@ import asyncio import dask from asyncio import TimeoutError -from .utils import PeriodicCallback, log_errors, parse_timedelta +from tornado.ioloop import PeriodicCallback +from .utils import log_errors, parse_timedelta from .worker import get_client from .metrics import time import warnings @@ -66,14 +67,12 @@ def __init__(self, scheduler): self.scheduler.extensions["semaphores"] = self - validation_callback_time = 1000 * parse_timedelta( + validation_callback_time = parse_timedelta( dask.config.get("distributed.scheduler.locks.lease-validation-interval"), default="s", ) self._pc_lease_timeout = PeriodicCallback( - self._check_lease_timeout, - validation_callback_time, - io_loop=self.scheduler.loop, + self._check_lease_timeout, validation_callback_time * 1000, ) self._pc_lease_timeout.start() self.lease_timeout = parse_timedelta( @@ -344,9 +343,7 @@ def __init__(self, max_leases=1, name=None, client=None): ) self._refreshing_leases = False pc = PeriodicCallback( - self._refresh_leases, - callback_time=1000 * refresh_leases_interval, - io_loop=self.client.io_loop, + self._refresh_leases, callback_time=refresh_leases_interval * 1000 ) self.refresh_callback = pc # Registering the pc to the client here is important for proper cleanup diff --git a/distributed/stealing.py b/distributed/stealing.py index 0d552d1689f..874ca98ce77 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -3,11 +3,13 @@ from math import log from time import time +from tornado.ioloop import PeriodicCallback + import dask from .comm.addressing import get_address_host from .core import CommClosedError from .diagnostics.plugin import SchedulerPlugin -from .utils import log_errors, parse_timedelta, PeriodicCallback +from .utils import log_errors, parse_timedelta from tlz import topk 
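
The edits above all apply the same conversion: an interval is read from configuration, parsed into seconds with parse_timedelta, and multiplied by 1000 only at the point where tornado's PeriodicCallback (which expects milliseconds) is constructed. A rough sketch of that pattern outside any class; balance is a stand-in callback, and the work-stealing interval is used only as an example key:

    import dask
    from tornado.ioloop import PeriodicCallback
    from distributed.utils import parse_timedelta

    def balance():
        # stand-in for the real periodic work
        pass

    # parse_timedelta returns seconds; PeriodicCallback expects milliseconds
    interval = parse_timedelta(
        dask.config.get("distributed.scheduler.work-stealing-interval"),
        default="ms",
    )
    pc = PeriodicCallback(balance, callback_time=interval * 1000)
    pc.start()  # the callback only fires once an IOLoop is running

Keeping the multiplication next to the constructor call leaves the unit conversion visible at the one place that actually cares about milliseconds.
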
@@ -36,16 +38,12 @@ def __init__(self, scheduler): for worker in scheduler.workers: self.add_worker(worker=worker) - # `callback_time` is in milliseconds - callback_time = 1000 * parse_timedelta( + callback_time = parse_timedelta( dask.config.get("distributed.scheduler.work-stealing-interval"), default="ms", ) - pc = PeriodicCallback( - callback=self.balance, - callback_time=callback_time, - io_loop=self.scheduler.loop, - ) + # `callback_time` is in milliseconds + pc = PeriodicCallback(callback=self.balance, callback_time=callback_time * 1000) self._pc = pc self.scheduler.periodic_callbacks["stealing"] = pc self.scheduler.plugins.append(self) diff --git a/distributed/utils.py b/distributed/utils.py index 30af57a25fc..46bd4c245e8 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -48,7 +48,6 @@ ) import tlz as toolz -import tornado from tornado import gen from tornado.ioloop import IOLoop @@ -1118,17 +1117,6 @@ def nbytes(frame, _bytes_like=(bytes, bytearray)): return len(frame) -def PeriodicCallback(callback, callback_time, io_loop=None): - """ - Wrapper around tornado.IOLoop.PeriodicCallback, for compatibility - with removal of the `io_loop` parameter in Tornado 5.0. - """ - if tornado.version_info >= (5,): - return tornado.ioloop.PeriodicCallback(callback, callback_time) - else: - return tornado.ioloop.PeriodicCallback(callback, callback_time, io_loop) - - @contextmanager def time_warn(duration, text): start = time() @@ -1191,49 +1179,47 @@ def reset_logger_locks(): handler.createLock() -if tornado.version_info[0] >= 5: - - is_server_extension = False +is_server_extension = False - if "notebook" in sys.modules: - import traitlets - from notebook.notebookapp import NotebookApp - - is_server_extension = traitlets.config.Application.initialized() and isinstance( - traitlets.config.Application.instance(), NotebookApp - ) +if "notebook" in sys.modules: + import traitlets + from notebook.notebookapp import NotebookApp - if not is_server_extension: - is_kernel_and_no_running_loop = False + is_server_extension = traitlets.config.Application.initialized() and isinstance( + traitlets.config.Application.instance(), NotebookApp + ) - if is_kernel(): - try: - get_running_loop() - except RuntimeError: - is_kernel_and_no_running_loop = True - - if not is_kernel_and_no_running_loop: - - # TODO: Use tornado's AnyThreadEventLoopPolicy, instead of class below, - # once tornado > 6.0.3 is available. - if WINDOWS and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): - # WindowsProactorEventLoopPolicy is not compatible with tornado 6 - # fallback to the pre-3.8 default of Selector - # https://github.com/tornadoweb/tornado/issues/2608 - BaseEventLoopPolicy = asyncio.WindowsSelectorEventLoopPolicy - else: - BaseEventLoopPolicy = asyncio.DefaultEventLoopPolicy +if not is_server_extension: + is_kernel_and_no_running_loop = False - class AnyThreadEventLoopPolicy(BaseEventLoopPolicy): - def get_event_loop(self): - try: - return super().get_event_loop() - except (RuntimeError, AssertionError): - loop = self.new_event_loop() - self.set_event_loop(loop) - return loop - - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + if is_kernel(): + try: + get_running_loop() + except RuntimeError: + is_kernel_and_no_running_loop = True + + if not is_kernel_and_no_running_loop: + + # TODO: Use tornado's AnyThreadEventLoopPolicy, instead of class below, + # once tornado > 6.0.3 is available. 
+ if WINDOWS and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # WindowsProactorEventLoopPolicy is not compatible with tornado 6 + # fallback to the pre-3.8 default of Selector + # https://github.com/tornadoweb/tornado/issues/2608 + BaseEventLoopPolicy = asyncio.WindowsSelectorEventLoopPolicy + else: + BaseEventLoopPolicy = asyncio.DefaultEventLoopPolicy + + class AnyThreadEventLoopPolicy(BaseEventLoopPolicy): + def get_event_loop(self): + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) @functools.lru_cache(1000) diff --git a/distributed/worker.py b/distributed/worker.py index c6734bbce93..ef95c1f4b7f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -24,7 +24,7 @@ from tlz import pluck, merge, first, keymap from tornado import gen -from tornado.ioloop import IOLoop +from tornado.ioloop import IOLoop, PeriodicCallback from . import profile, comm, system from .batched import BatchedSend @@ -55,7 +55,6 @@ json_load_robust, key_split, offload, - PeriodicCallback, parse_bytes, parse_timedelta, iscoroutinefunction, @@ -474,9 +473,6 @@ def __init__( self.available_resources = (resources or {}).copy() self.death_timeout = parse_timedelta(death_timeout) - self.memory_monitor_interval = parse_timedelta( - memory_monitor_interval, default="ms" - ) self.extensions = dict() if silence_logs: silence_logging(level=silence_logs) @@ -659,23 +655,22 @@ def __init__( "worker": self, } - pc = PeriodicCallback(self.heartbeat, 1000, io_loop=self.io_loop) + pc = PeriodicCallback(self.heartbeat, 1000) self.periodic_callbacks["heartbeat"] = pc pc = PeriodicCallback( - lambda: self.batched_stream.send({"op": "keep-alive"}), - 60000, - io_loop=self.io_loop, + lambda: self.batched_stream.send({"op": "keep-alive"}), 60000, ) self.periodic_callbacks["keep-alive"] = pc self._address = contact_address + self.memory_monitor_interval = parse_timedelta( + memory_monitor_interval, default="ms" + ) if self.memory_limit: self._memory_monitoring = False pc = PeriodicCallback( - self.memory_monitor, - self.memory_monitor_interval * 1000, - io_loop=self.io_loop, + self.memory_monitor, self.memory_monitor_interval * 1000, ) self.periodic_callbacks["memory"] = pc @@ -688,19 +683,13 @@ def __init__( setproctitle("dask-worker [not started]") - pc = PeriodicCallback( - self.trigger_profile, - parse_timedelta( - dask.config.get("distributed.worker.profile.interval"), default="ms" - ) - * 1000, - io_loop=self.io_loop, + profile_trigger_interval = parse_timedelta( + dask.config.get("distributed.worker.profile.interval"), default="ms" ) + pc = PeriodicCallback(self.trigger_profile, profile_trigger_interval * 1000) self.periodic_callbacks["profile"] = pc - pc = PeriodicCallback( - self.cycle_profile, profile_cycle_interval * 1000, io_loop=self.io_loop - ) + pc = PeriodicCallback(self.cycle_profile, profile_cycle_interval * 1000) self.periodic_callbacks["profile-cycle"] = pc self.plugins = {} From 83c27fa8dae2cf389b4d8f142721255c9c6ff0e9 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 21 Apr 2020 15:05:44 +0100 Subject: [PATCH 0802/1550] Reuse CI scripts for local installation process (#3698) This commit streamlines the CI installation script, with the main purpose of making it easier to replicate on a Linux/MacOSX dev box the exact same environment that exists on CI. 
Apply safe defaults for environment variables if the script is not invoked by travis Do not install miniconda if the 'conda' command is already available Explicitly state conda channel every time to avoid relying on ~/.condarc Recommend using conda in developer docs --- .github/workflows/ci-windows.yaml | 4 +- .travis.yml | 8 +-- ...nvironment.yml => environment-windows.yml} | 13 ++--- continuous_integration/travis/install.sh | 51 +++++++++++-------- continuous_integration/travis/run_tests.sh | 6 +-- docs/source/develop.rst | 27 +++++----- 6 files changed, 60 insertions(+), 49 deletions(-) rename continuous_integration/{environment.yml => environment-windows.yml} (76%) diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index e0c95d0f234..2e536a79663 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -19,8 +19,8 @@ jobs: with: miniconda-version: "latest" python-version: ${{ matrix.python-version }} - environment-file: continuous_integration/environment.yml - activate-environment: testenv + environment-file: continuous_integration/environment-windows.yml + activate-environment: dask-distributed auto-activate-base: false - name: Install tornado diff --git a/.travis.yml b/.travis.yml index e8f2afc5057..b7995a2a034 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,9 +6,9 @@ dist: trusty env: matrix: - - PYTHON=3.6 TESTS=true COVERAGE=true PACKAGES="scikit-learn lz4" TORNADO=5 CRICK=true - - PYTHON=3.7 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 - - PYTHON=3.8 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 + - PYTHON=3.6 TESTS=true COVERAGE=true PACKAGES="lz4" TORNADO=5 CRICK=true + - PYTHON=3.7 TESTS=true PACKAGES="python-snappy python-blosc" TORNADO=6 + - PYTHON=3.8 TESTS=true PACKAGES="python-snappy python-blosc" TORNADO=6 matrix: fast_finish: true @@ -18,7 +18,7 @@ matrix: python: 3.6 env: LINT=true - os: osx - env: PYTHON=3.7 TESTS=true PACKAGES="scikit-learn python-snappy python-blosc" TORNADO=6 + env: PYTHON=3.7 TESTS=true PACKAGES="python-snappy python-blosc" TORNADO=6 if: type != pull_request OR commit_message =~ test-osx # Skip on PRs unless the commit message contains "test-osx" allow_failures: diff --git a/continuous_integration/environment.yml b/continuous_integration/environment-windows.yml similarity index 76% rename from continuous_integration/environment.yml rename to continuous_integration/environment-windows.yml index 5f09525caae..2cede561425 100644 --- a/continuous_integration/environment.yml +++ b/continuous_integration/environment-windows.yml @@ -1,6 +1,7 @@ -name: testenv +name: dask-distributed channels: - conda-forge + - defaults dependencies: - zstandard - bokeh!=2.0.0 @@ -16,18 +17,18 @@ dependencies: - prometheus_client - psutil - pytest + - pytest-asyncio + - pytest-repeat + - pytest-timeout + - pytest-faulthandler - requests + - sortedcollections - toolz - tblib - zict - fsspec - pip - pip: - - pytest-repeat - - pytest-timeout - - pytest-faulthandler - - sortedcollections - - pytest-asyncio - git+https://github.com/dask/dask - git+https://github.com/joblib/joblib.git - git+https://github.com/dask/zict diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 4ee0790f6c5..e362ea7f079 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -5,6 +5,12 @@ # Note we disable progress bars to make Travis log loading much faster +# Set default variable values if 
unset +# (useful when this script is not invoked by Travis) +: ${PYTHON:=3.8} +: ${TORNADO:=6} +: ${PACKAGES:=python-snappy python-blosc} + # Install conda case "$(uname -s)" in 'Darwin') @@ -16,18 +22,17 @@ case "$(uname -s)" in *) ;; esac -wget https://repo.continuum.io/miniconda/$MINICONDA_FILENAME -O miniconda.sh -bash miniconda.sh -b -p $HOME/miniconda -export PATH="$HOME/miniconda/bin:$PATH" -conda config --set always_yes yes --set changeps1 no -conda update -q conda +if ! which conda; then + wget https://repo.continuum.io/miniconda/$MINICONDA_FILENAME -O miniconda.sh + bash miniconda.sh -b -p $HOME/miniconda + export PATH="$HOME/miniconda/bin:$PATH" +fi -# Create conda environment -conda create -q -n test-environment python=$PYTHON -source activate test-environment +conda config --set always_yes yes --set quiet yes --set changeps1 no +conda update conda -# Install dependencies -conda install -c conda-forge -q \ +# Create conda environment +conda create -n dask-distributed -c conda-forge -c defaults \ asyncssh \ bokeh \ click \ @@ -39,45 +44,49 @@ conda install -c conda-forge -q \ ipywidgets \ joblib \ jupyter_client \ - msgpack-python>=0.6.0 \ + 'msgpack-python>=0.6.0' \ netcdf4 \ paramiko \ prometheus_client \ psutil \ - pytest>=4 \ + 'pytest>=4' \ + pytest-asyncio \ + pytest-faulthandler \ + pytest-repeat \ pytest-timeout \ python=$PYTHON \ requests \ + scikit-learn \ scipy \ - tblib>=1.5.0 \ + sortedcollections \ + 'tblib>=1.5.0' \ toolz \ tornado=$TORNADO \ zstandard \ $PACKAGES +source activate dask-distributed + # stacktrace is not currently avaiable for Python 3.8. # Remove the version check block below when it is avaiable. if [[ $PYTHON != 3.8 ]]; then # For low-level profiler, install libunwind and stacktrace from conda-forge # For stacktrace we use --no-deps to avoid upgrade of python - conda install -c defaults -c conda-forge libunwind - conda install --no-deps -c defaults -c numba -c conda-forge stacktrace -fi; - -python -m pip install -q "pytest>=4" pytest-repeat pytest-faulthandler pytest-asyncio + conda install -c conda-forge -c defaults libunwind + conda install --no-deps -c conda-forge -c defaults -c numba stacktrace +fi python -m pip install -q git+https://github.com/dask/dask.git --upgrade --no-deps python -m pip install -q git+https://github.com/joblib/joblib.git --upgrade --no-deps python -m pip install -q git+https://github.com/intake/filesystem_spec.git --upgrade --no-deps python -m pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps python -m pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps -python -m pip install -q sortedcollections --no-deps python -m pip install -q keras --upgrade --no-deps if [[ $CRICK == true ]]; then - conda install -q cython + conda install -c conda-forge -c defaults cython python -m pip install -q git+https://github.com/jcrist/crick.git -fi; +fi # Install distributed python -m pip install --no-deps -e . 
diff --git a/continuous_integration/travis/run_tests.sh b/continuous_integration/travis/run_tests.sh index 14c3db7750a..1bf86545cef 100644 --- a/continuous_integration/travis/run_tests.sh +++ b/continuous_integration/travis/run_tests.sh @@ -19,7 +19,7 @@ echo "--" ulimit -a -H if [[ $COVERAGE == true ]]; then - coverage run $(which py.test) distributed -m "not avoid_travis" $PYTEST_OPTIONS; + coverage run $(which py.test) distributed -m "not avoid_travis" $PYTEST_OPTIONS else - py.test -m "not avoid_travis" distributed $PYTEST_OPTIONS; -fi; + py.test -m "not avoid_travis" distributed $PYTEST_OPTIONS +fi diff --git a/docs/source/develop.rst b/docs/source/develop.rst index 8d0a02fd73d..254eb914aaa 100644 --- a/docs/source/develop.rst +++ b/docs/source/develop.rst @@ -12,25 +12,26 @@ guidelines`_ in the main documentation. Install ------- -After setting up an environment as described in the `Dask developer -guidelines`_ you can clone this repository with git:: +Clone this repository with git:: git clone git@github.com:dask/distributed.git + cd distributed -and install it from source:: +Install all dependencies: - cd distributed - python setup.py install +On Linux / MacOSX:: -Using conda, for example:: + source continuous_integration/travis/install.sh - git clone git@github.com:{your-fork}/distributed.git - cd distributed - conda create -y -n distributed python=3.6 - conda activate distributed - python -m pip install -U -r requirements.txt - python -m pip install -U -r dev-requirements.txt - python -m pip install -e . +On Windows: + +1. Install anaconda or miniconda +2. :: + + conda create -n dask-distributed -c conda-forge -c defaults python=3.8 tornado=6 + conda activate dask-distributed + conda env update --file continuous_integration/environment-windows.yml + python -m pip install . To keep a fork in sync with the upstream source:: From 9cfd06685fa889e9b4705e9943cbc4c4a0fd6643 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 21 Apr 2020 10:34:02 -0700 Subject: [PATCH 0803/1550] Add Configuration Schema (#3696) * Add Configuration Schema This adds type and description information to the configuration using the jsonschema spec. So far this is only a proof of concept, and touches only a couple of entries. * Add schema test to CI * Try using sphinx-jsonschema * Add more configuration descriptions * add more descriptions * add a bunch more descriptions * use multi-line text blocks * add more descriptions * Test completeness of schema * add new config value * Revert changes adding a docpage * Add informative error when config and schema are out of sync --- distributed/distributed-schema.yaml | 796 ++++++++++++++++++++++++++++ distributed/distributed.yaml | 2 +- distributed/scheduler.py | 4 +- distributed/tests/test_config.py | 48 ++ 4 files changed, 848 insertions(+), 2 deletions(-) create mode 100644 distributed/distributed-schema.yaml diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml new file mode 100644 index 00000000000..60dfbf54e05 --- /dev/null +++ b/distributed/distributed-schema.yaml @@ -0,0 +1,796 @@ +properties: + distributed: + type: object + properties: + + version: + type: integer + + scheduler: + type: object + properties: + + allowed-failures: + type: integer + minimum: 0 + description: | + The number of retries before a task is considered bad + + When a worker dies when a task is running that task is rerun elsewhere. + If many workers die while running this same task then we call the task bad, and raise a KilledWorker exception. 
+ This is the number of workers that are allowed to die before this task is marked as bad. + + bandwidth: + type: + - integer + - string + description: | + The expected bandwidth between any pair of workers + + This is used when making scheduling decisions. + The scheduler will use this value as a baseline, but also learn it over time. + + blocked-handlers: + type: array + description: | + A list of handlers to exclude + + The scheduler operates by receiving messages from various workers and clients + and then performing operations based on those messages. + Each message has an operation like "close-worker" or "task-finished". + In some high security situations administrators may choose to block certain handlers + from running. Those handlers can be listed here. + + For a list of handlers see the `dask.distributed.Scheduler.handlers` attribute. + + default-data-size: + type: + - string + - integer + description: | + The default size of a piece of data if we don't know anything about it. + + This is used by the scheduler in some scheduling decisions + + events-cleanup-delay: + type: string + description: | + The amount of time to wait until workers or clients are removed from the event log + after they have been removed from the scheduler + + idle-timeout: + type: + - string + - "null" + description: | + Shut down the scheduler after this duration if no activity has occured + + This can be helpful to reduce costs and stop zombie processes from roaming the earth. + + transition-log-length: + type: integer + minimum: 0 + description: | + How long should we keep the transition log + + Every time a task transitions states (like "waiting", "processing", "memory", "released") + we record that transition in a log. + + To make sure that we don't run out of memory + we will clear out old entries after a certain length. + This is that length. + + work-stealing: + type: boolean + description: | + Whether or not to balance work between workers dynamically + + Some times one worker has more work than we expected. + The scheduler will move these tasks around as necessary by default. + Set this to false to disable this behavior + + work-stealing-interval: + type: string + description: | + How frequently to balance worker loads + + worker-ttl: + type: + - string + - "null" + description: | + Time to live for workers. + + If we don't receive a heartbeat faster than this then we assume that the worker has died. + + pickle: + type: boolean + description: | + Is the scheduler allowed to deserialize arbitrary bytestrings? + + The scheduler almost never deserializes user data. + However there are some cases where the user can submit functions to run directly on the scheduler. + This can be convenient for debugging, but also introduces some security risk. + By setting this to false we ensure that the user is unable to run arbitrary code on the scheduler. + + preload: + type: array + description: | + Run custom modules during the lifetime of the scheduler + + You can run custom modules when the scheduler starts up and closes down. + See https://docs.dask.org/en/latest/setup/custom-startup.html for more information + + preload-argv: + type: array + description: | + Arguments to pass into the preload scripts described above + + See https://docs.dask.org/en/latest/setup/custom-startup.html for more information + + unknown-task-duration: + type: string + description: | + Default duration for all tasks with unknown durations + + Over time the scheduler learns a duration for tasks. 
+ However when it sees a new type of task for the first time it has to make a guess + as to how long it will take. This value is that guess. + + default-task-durations: + type: object + description: | + How long we expect function names to run + + Over time the scheduler will learn these values, but these give it a good starting point. + + validate: + type: boolean + description: | + Whether or not to run consistency checks during execution. + This is typically only used for debugging. + + dashboard: + type: object + description: | + Configuration options for Dask's real-time dashboard + + properties: + status: + type: object + description: The main status page of the dashboard + properties: + task-stream-length: + type: integer + minimum: 0 + description: | + The maximum number of tasks to include in the task stream plot + tasks: + type: object + description: | + The page which includes the full task stream history + properties: + task-stream-length: + type: integer + minimum: 0 + description: | + The maximum number of tasks to include in the task stream plot + tls: + type: object + description: | + Settings around securing the dashboard + properties: + ca-file: + type: + - string + - "null" + key: + type: + - string + - "null" + cert: + type: + - string + - "null" + bokeh-application: + type: object + description: | + Keywords to pass to the BokehTornado application + locks: + type: object + description: | + Settings for Dask's distributed Lock object + + See https://docs.dask.org/en/latest/futures.html#locks for more information + properties: + lease-validation-interval: + type: string + description: | + The time to wait until an acquired semaphore is released if the Client goes out of scope + lease-timeout: + type: string + description: | + The timeout after which a lease will be released if not refreshed + + http: + type: object + decription: Settings for Dask's embedded HTTP Server + properties: + routes: + type: array + description: | + A list of modules like "prometheus" and "health" that can be included or excluded as desired + + These modules will have a ``routes`` keyword that gets added to the main HTTP Server. + This is also a list that can be extended with user defined modules. + + + worker: + type: object + description: | + Configuration settings for Dask Workers + properties: + blocked-handlers: + type: array + description: | + A list of handlers to exclude + + The scheduler operates by receiving messages from various workers and clients + and then performing operations based on those messages. + Each message has an operation like "close-worker" or "task-finished". + In some high security situations administrators may choose to block certain handlers + from running. Those handlers can be listed here. + + For a list of handlers see the `dask.distributed.Scheduler.handlers` attribute. + + multiprocessing-method: + type: string + description: | + How we create new workers, one of "spawn", "forkserver", or "fork" + + This is passed to the ``multiprocessing.get_context`` function. + use-file-locking: + type: boolean + description: | + Whether or not to use lock files when creating workers + + Workers create a local directory in which to place temporary files. + When many workers are created on the same process at once + these workers can conflict with each other by trying to create this directory all at the same time. + + To avoid this, Dask usually used a file-based lock. + However, on some systems file-based locks don't work. 
+ This is particularly common on HPC NFS systems, where users may want to set this to false. + connections: + type: object + description: | + The number of concurrent connections to allow to other workers + properties: + incoming: + type: integer + minimum: 0 + outgoing: + type: integer + minimum: 0 + + preload: + type: array + description: | + Run custom modules during the lifetime of the worker + + You can run custom modules when the worker starts up and closes down. + See https://docs.dask.org/en/latest/setup/custom-startup.html for more information + + preload-argv: + type: array + description: | + Arguments to pass into the preload scripts described above + + See https://docs.dask.org/en/latest/setup/custom-startup.html for more information + + daemon: + type: boolean + description: | + Whether or not to run our process as a daemon process + + validate: + type: boolean + description: | + Whether or not to run consistency checks during execution. + This is typically only used for debugging. + + lifetime: + type: object + description: | + The worker may choose to gracefully close itself down after some pre-determined time. + + This is particularly useful if you know that your worker job has a time limit on it. + This is particularly common in HPC job schedulers. + + For example if your worker has a walltime of one hour, + then you may want to set the lifetime.duration to "55 minutes" + properties: + duration: + type: + - string + - "null" + description: | + The time after creation to close the worker, like "1 hour" + stagger: + type: string + description: | + Random amount by which to stagger lifetimes + + If you create many workers at the same time, + you may want to avoid having them kill themselves all at the same time. + To avoid this you might want to set a stagger time, + so that they close themselves with some random variation, like "5 minutes" + + That way some workers can die, new ones can be brought up, + and data can be transferred over smoothly. + restart: + type: boolean + description: | + Do we try to resurrect the worker after the lifetime deadline? + + + profile: + type: object + description: | + The workers periodically poll every worker thread to see what they are working on. + This data gets collected into statistical profiling information, + which is then periodically bundled together and sent along to the scheduler. + properties: + interval: + type: string + description: | + The time between polling the worker threads, typically short like 10ms + cycle: + type: string + description: | + The time between bundling together this data and sending it to the scheduler + + This controls the granularity at which people can query the profile information + on the time axis. + low-level: + type: boolean + description: | + Whether or not to use the libunwind and stacktrace libraries + to gather profiling information at the lower level (beneath Python) + + To get this to work you will need to install the experimental stacktrace library at + + conda install -c numba stacktrace + + See https://github.com/numba/stacktrace + + memory: + type: object + description: | + When Dask workers have more data than memory they spill this data to disk. + They do this at a few conditions. 
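
The threshold fractions that follow in this schema (target, spill, pause, terminate) are ordinary configuration values, so they can be overridden before a worker starts. A small sketch; the numbers are arbitrary examples, not recommended settings:

    import dask

    # fractions of the worker memory limit at which each reaction kicks in
    dask.config.set({
        "distributed.worker.memory.target": 0.60,     # start spilling stored data to disk
        "distributed.worker.memory.spill": 0.70,      # spill based on process memory
        "distributed.worker.memory.pause": 0.80,      # stop starting new tasks
        "distributed.worker.memory.terminate": 0.95,  # the nanny, if present, kills the worker
    })
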
+ properties: + target: + type: number + minimum: 0 + maximum: 1 + description: | + Target fraction below which to try to keep memory + + spill: + type: number + minimum: 0 + maximum: 1 + description: | + When the process memory (as observed by the operating system) gets above this amount we spill data to disk. + + pause: + type: number + minimum: 0 + maximum: 1 + description: | + When the process memory (as observed by the operating system) gets above this amount + we no longer start new tasks on this worker. + + terminate: + type: number + minimum: 0 + maximum: 1 + description: | + When the process memory reaches this level the nanny process will kill the worker + (if a nanny is present) + + http: + type: object + decription: Settings for Dask's embedded HTTP Server + properties: + routes: + type: array + description: | + A list of modules like "prometheus" and "health" that can be included or excluded as desired + + These modules will have a ``routes`` keyword that gets added to the main HTTP Server. + This is also a list that can be extended with user defined modules. + http: + type: object + decription: Settings for Dask's embedded HTTP Server + properties: + routes: + type: array + description: | + A list of modules like "prometheus" and "health" that can be included or excluded as desired + + These modules will have a ``routes`` keyword that gets added to the main HTTP Server. + This is also a list that can be extended with user defined modules. + + nanny: + type: object + description: | + Configuration settings for Dask Nannies + properties: + + preload: + type: array + description: | + Run custom modules during the lifetime of the scheduler + + You can run custom modules when the scheduler starts up and closes down. + See https://docs.dask.org/en/latest/setup/custom-startup.html for more information + + preload-argv: + type: array + description: | + Arguments to pass into the preload scripts described above + + See https://docs.dask.org/en/latest/setup/custom-startup.html for more information + + client: + type: object + description: | + Configuration settings for Dask Clients + + properties: + heartbeat: + type: string + description: + This value is the time between heartbeats + + The client sends a periodic heartbeat message to the scheduler. + If it misses enough of these then the scheduler assumes that it has gone. + + scheduler-info-interval: + type: string + description: Interval between scheduler-info updates + + deploy: + type: object + description: Configuration settings for general Dask deployment + properties: + lost-worker-timeout: + type: string + description: | + Interval after which to hard-close a lost worker job + + Otherwise we wait for a while to see if a worker will reappear + + cluster-repr-interval: + type: string + description: Interval between calls to update cluster-repr for the widget + + adaptive: + type: object + description: Configuration settings for Dask's adaptive scheduling + properties: + interval: + type: string + description: | + The duration between checking in with adaptive scheduling load + + The adaptive system periodically checks scheduler load and determines + if it should scale the cluster up or down. + This is the timing between those checks. + + target-duration: + type: string + description: | + The desired time for the entire computation to run + + The adaptive system will try to start up enough workers to run + the computation in about this time. 
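
These adaptive settings feed the controller that cluster objects expose through their adapt() method. A minimal sketch assuming a LocalCluster; the bounds are arbitrary examples:

    from dask.distributed import Client, LocalCluster

    cluster = LocalCluster(n_workers=0)
    # scale between 0 and 10 workers; load is re-evaluated on the configured interval
    cluster.adapt(minimum=0, maximum=10)
    client = Client(cluster)
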
+ + minimum: + type: integer + minimum: 0 + description: | + The minimum number of workers to keep around + + maximum: + type: number + minimum: 0 + description: | + The maximum number of workers to keep around + + wait-count: + type: integer + minimum: 1 + description: | + The number of times a worker should be suggested for removal before removing it + + This helps to smooth out the number of deployed workers + + comm: + type: object + description: Configuration settings for Dask communications + properties: + + retry: + type: object + description: | + Some operations (such as gathering data) are subject to re-tries with the below parameters + properties: + + count: + type: integer + minimum: 0 + description: | + The number of times to retry a connection + + delay: + type: object + properties: + min: + type: string + description: The first non-zero delay between retry attempts + max: + type: string + description: The maximum delay between retries + + compression: + type: string + description: | + The compression algorithm to use + + This could be one of lz4, snappy, zstd, or blosc + + offload: + type: + - boolean + - string + description: | + The size of message after which we choose to offload serialization to another thread + + In some cases, you may also choose to disable this altogether with the value false + This is useful if you want to include serialization in profiling data, + or if you have data types that are particularly sensitive to deserialization + + socket-backlog: + type: integer + description: | + When shuffling data between workers, there can + really be O(cluster size) connection requests + on a single worker socket, make sure the backlog + is large enough not to lose any. + + zstd: + type: object + description: Options for the Z Standard compression scheme + properties: + level: + type: integer + minimum: 1 + maximum: 22 + description: Compression level, between 1 and 22. + threads: + type: integer + minimum: -1 + description: | + Number of threads to use. + + 0 for single-threaded, -1 to infer from cpu count. + + timeouts: + type: object + properties: + connect: + type: string + tcp: + type: string + + require-encryption: + type: boolean + description: | + Whether to require encryption on non-local comms + + default-scheme: + type: string + description: The default protocol to use, like tcp or tls + + recent-messages-log-length: + type: integer + minimum: 0 + description: number of messages to keep for debugging + + tls: + type: object + properties: + ciphers: + type: + - string + - "null" + descsription: Allowed ciphers, specified as an OpenSSL cipher string. + + ca-file: + type: + - string + - "null" + description: Path to a CA file, in pem format + + scheduler: + type: object + description: TLS information for the scheduler + properties: + cert: + type: + - string + - "null" + description: Path to certificate file + key: + type: + - string + - "null" + description: | + Path to key file. + + Alternatively, the key can be appended to the cert file + above, and this field left blank + + worker: + type: object + description: TLS information for the worker + properties: + cert: + type: + - string + - "null" + description: Path to certificate file + key: + type: + - string + - "null" + description: | + Path to key file. 
+ + Alternatively, the key can be appended to the cert file + above, and this field left blank + + client: + type: object + description: TLS information for the client + properties: + cert: + type: + - string + - "null" + description: Path to certificate file + key: + type: + - string + - "null" + description: | + Path to key file. + + Alternatively, the key can be appended to the cert file + above, and this field left blank + + dashboard: + type: object + properties: + link: + type: string + description: | + The form for the dashboard links + + This is used wherever we print out the link for the dashboard + It is filled in with relevant information like the schema, host, and port number + graph-max-items: + type: integer + minimum: 0 + description: maximum number of tasks to try to plot in "graph" view + + export-tool: + type: boolean + + admin: + type: object + description: | + Options for logs, event loops, and so on + properties: + tick: + type: object + description: | + Time between event loop health checks + + We set up a periodic callback to run on the event loop and check in fairly frequently. + (by default, this is every 20 milliseconds) + + If this periodic callback sees that the last time it checked in was several seconds ago + (by default, this is 3 seconds) + then it logs a warning saying that something has been stopping the event loop from smooth operation. + This is typically caused by GIL holding operations, + but could also be several other things. + + properties: + interval: + type: string + description: The time between ticks, default 20ms + limit : + type: string + description: The time allowed before triggering a warning + + max-error-length: + type: integer + minimum: 0 + description: | + Maximum length of traceback as text + + Some Python tracebacks can be very very long + (particularly in stack overflow errors) + + If the traceback is larger than this size (in bytes) then we truncate it. + + log-length: + type: integer + minimum: 0 + description: | + Default length of logs to keep in memory + + The scheduler and workers keep the last 10000 or so log entries in memory. + + log-format: + type: string + description: | + The log format to emit. 
+ + See https://docs.python.org/3/library/logging.html#logrecord-attributes + + pdb-on-err: + type: boolean + description: Enter Python Debugger on scheduling error + + rmm: + type: object + description: | + Configuration options for the RAPIDS Memory Manager + properties: + pool-size: + type: + - integer + - "null" + description: + The size of the memory pool in bytes + ucx: + type: object + description: | + UCX provides access to other network interconnects like Infiniband and NVLINK + properties: + tcp: + type: + - boolean + - "null" + nvlink: + type: + - boolean + - "null" + infiniband: + type: + - boolean + - "null" + cuda_copy: + type: + - boolean + - "null" + net-devices: + type: + - string + - "null" + description: Define which Infiniband device to use diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 4f95a179bc3..4103d592e2b 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -12,7 +12,7 @@ distributed: allowed-failures: 3 # number of retries before a task is considered bad bandwidth: 100000000 # 100 MB/s estimated worker-worker bandwidth blocked-handlers: [] - default-data-size: 1000 + default-data-size: 1kiB # Number of seconds to wait until workers or clients are removed from the events log # after they have been removed from the scheduler events-cleanup-delay: 1h diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 521415c7c25..82e5f812dd6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -90,7 +90,9 @@ LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -DEFAULT_DATA_SIZE = dask.config.get("distributed.scheduler.default-data-size") +DEFAULT_DATA_SIZE = parse_bytes( + dask.config.get("distributed.scheduler.default-data-size") +) DEFAULT_EXTENSIONS = [ LockExtension, diff --git a/distributed/tests/test_config.py b/distributed/tests/test_config.py index 2017bb239f7..01cd6eec57b 100644 --- a/distributed/tests/test_config.py +++ b/distributed/tests/test_config.py @@ -3,6 +3,7 @@ import sys import tempfile import os +import yaml import pytest @@ -265,3 +266,50 @@ def test_logging_file_config(): """ subprocess.check_call([sys.executable, "-c", code]) os.remove(logging_config.name) + + +def test_schema(): + jsonschema = pytest.importorskip("jsonschema") + config_fn = os.path.join(os.path.dirname(__file__), "..", "distributed.yaml") + schema_fn = os.path.join(os.path.dirname(__file__), "..", "distributed-schema.yaml") + + with open(config_fn) as f: + config = yaml.safe_load(f) + + with open(schema_fn) as f: + schema = yaml.safe_load(f) + + jsonschema.validate(config, schema) + + +def test_schema_is_complete(): + config_fn = os.path.join(os.path.dirname(__file__), "..", "distributed.yaml") + schema_fn = os.path.join(os.path.dirname(__file__), "..", "distributed-schema.yaml") + + with open(config_fn) as f: + config = yaml.safe_load(f) + + with open(schema_fn) as f: + schema = yaml.safe_load(f) + + skip = {"default-task-durations", "bokeh-application"} + + def test_matches(c, s): + if set(c) != set(s["properties"]): + raise ValueError( + "\nThe distributed.yaml and distributed-schema.yaml files are not in sync.\n" + "This usually happens when we add a new configuration value,\n" + "but don't add the schema of that value to the distributed-schema.yaml file\n" + "Please modify these files to include the missing values: \n\n" + " distributed.yaml: {}\n" + " distributed-schema.yaml: {}\n\n" + "Examples in these files should be a good start, \n" + "even if you are not familiar with the 
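
With the scheduler change above, default-data-size is written as a human-readable string and converted once with parse_bytes when the scheduler module is imported. A quick illustration of the helper; importing distributed is what registers the configuration defaults:

    import distributed  # registers the distributed.yaml defaults with dask.config
    import dask
    from dask.utils import parse_bytes

    size = parse_bytes(dask.config.get("distributed.scheduler.default-data-size"))
    assert isinstance(size, int)

    parse_bytes("100MB")   # 100_000_000
    parse_bytes("1kiB")    # 1024
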
jsonschema spec".format( + sorted(c), sorted(s["properties"]) + ) + ) + for k, v in c.items(): + if isinstance(v, dict) and k not in skip: + test_matches(c[k], s["properties"][k]) + + test_matches(config, schema) From 6db09f32d8ca58a56a50385a49150ef16a5d51b0 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 21 Apr 2020 16:35:34 -0700 Subject: [PATCH 0804/1550] Relax NumPy requirement in UCX (#3731) * Make `device_array`'s shape a `tuple` While it works to have this be a single `int` (as it will be coerced to a `tuple`), go ahead and make it a `tuple` for clarity and to match more closely to the Numba case. * Use `"u1"` to specify `uint8` typed arrays This is equivalent to using NumPy's `uint8`, but has the added benefit of not requiring NumPy be imported to work. * Rename `is_cudas` to `cuda_frames` Matches the variable name in the `send` case to make things easier to follow. * Use `pack`/`unpack` for UCX frame metadata As `struct.pack` and `struct.unpack` are able to build `bytes` objects containing the frame metadata needed by UCX easily, just use these functions instead of creating NumPy arrays each time. Helps soften the NumPy requirement a bit. * Rename `cuda_array` to `device_array` Matches more closely to the name used by RMM and Numba. * Create function to allocate arrays on host To relax the NumPy requirement completely, add a function to allocate arrays on host. If NumPy is not present, this falls back to just allocating `bytearray` objects, which work just as well. * Fix formatting with black * Define `cuda_frames` with other frame definitions * Store `nframes` for simplicity Avoids multiple calls to `len(frames)`, is a bit easier to read, and matches the receive code path more closely. * Collect sizes along with other frame info * Use `sizes` to pick out non-trivial frames to send * Simply call `sum` on `sizes` for bytes sent * Use `host_array` to make buffers to receive into * Pack per frame metadata into one message To send fewer and larger messages, pack both which frames are on device and how large each frame is into one message. * Note what `struct` lines are packing/unpacking --- distributed/comm/ucx.py | 83 ++++++++++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 4e6ca8116c8..7761afef7a1 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -6,10 +6,10 @@ .. _UCX: https://github.com/openucx/ucx """ import logging +import struct import weakref import dask -import numpy as np from .addressing import parse_host_port, unparse_host_port from .core import Comm, Connector, Listener, CommClosedError @@ -33,7 +33,8 @@ # required to ensure Dask configuration gets propagated to UCX, which needs # variables to be set before being imported. 
ucp = None -cuda_array = None +host_array = None +device_array = None def synchronize_stream(stream=0): @@ -46,7 +47,7 @@ def synchronize_stream(stream=0): def init_once(): - global ucp, cuda_array + global ucp, host_array, device_array if ucp is not None: return @@ -59,34 +60,42 @@ def init_once(): ucp.init(options=ucx_config, env_takes_precedence=True) + # Find the function, `host_array()`, to use when allocating new host arrays + try: + import numpy + + host_array = lambda n: numpy.empty((n,), dtype="u1") + except ImportError: + host_array = lambda n: bytearray(n) + # Find the function, `cuda_array()`, to use when allocating new CUDA arrays try: import rmm if hasattr(rmm, "DeviceBuffer"): - cuda_array = lambda n: rmm.DeviceBuffer(size=n) + device_array = lambda n: rmm.DeviceBuffer(size=n) else: # pre-0.11.0 import numba.cuda - def rmm_cuda_array(n): - a = rmm.device_array(n, dtype=np.uint8) + def rmm_device_array(n): + a = rmm.device_array(n, dtype="u1") weakref.finalize(a, numba.cuda.current_context) return a - cuda_array = rmm_cuda_array + device_array = rmm_device_array except ImportError: try: import numba.cuda - def numba_cuda_array(n): - a = numba.cuda.device_array((n,), dtype=np.uint8) + def numba_device_array(n): + a = numba.cuda.device_array((n,), dtype="u1") weakref.finalize(a, numba.cuda.current_context) return a - cuda_array = numba_cuda_array + device_array = numba_device_array except ImportError: - def cuda_array(n): + def device_array(n): raise RuntimeError( "In order to send/recv CUDA arrays, Numba or RMM is required" ) @@ -169,19 +178,25 @@ async def write( frames = await to_frames( msg, serializers=serializers, on_error=on_error ) + nframes = len(frames) + cuda_frames = tuple( + hasattr(f, "__cuda_array_interface__") for f in frames + ) + sizes = tuple(nbytes(f) for f in frames) send_frames = [ - each_frame for each_frame in frames if len(each_frame) > 0 + each_frame + for each_frame, each_size in zip(frames, sizes) + if each_size ] # Send meta data - cuda_frames = np.array( - [hasattr(f, "__cuda_array_interface__") for f in frames], - dtype=np.bool, - ) - await self.ep.send(np.array([len(frames)], dtype=np.uint64)) - await self.ep.send(cuda_frames) + + # Send # of frames (uint64) + await self.ep.send(struct.pack("Q", nframes)) + # Send which frames are CUDA (bool) and + # how large each frame is (uint64) await self.ep.send( - np.array([nbytes(f) for f in frames], dtype=np.uint64) + struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) ) # Send frames @@ -191,12 +206,12 @@ async def write( # syncing the default stream will wait for other non-blocking CUDA streams. # Note this is only sufficient if the memory being sent is not currently in use on # non-blocking CUDA streams. 
- if cuda_frames.any(): + if any(cuda_frames): synchronize_stream(0) for each_frame in send_frames: await self.ep.send(each_frame) - return sum(map(nbytes, send_frames)) + return sum(sizes) except (ucp.exceptions.UCXBaseException): self.abort() raise CommClosedError("While writing, the connection was closed") @@ -211,22 +226,28 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): try: # Recv meta data - nframes = np.empty(1, dtype=np.uint64) + + # Recv # of frames (uint64) + nframes_fmt = "Q" + nframes = host_array(struct.calcsize(nframes_fmt)) await self.ep.recv(nframes) - is_cudas = np.empty(nframes[0], dtype=np.bool) - await self.ep.recv(is_cudas) - sizes = np.empty(nframes[0], dtype=np.uint64) - await self.ep.recv(sizes) + (nframes,) = struct.unpack(nframes_fmt, nframes) + + # Recv which frames are CUDA (bool) and + # how large each frame is (uint64) + header_fmt = nframes * "?" + nframes * "Q" + header = host_array(struct.calcsize(header_fmt)) + await self.ep.recv(header) + header = struct.unpack(header_fmt, header) + cuda_frames, sizes = header[:nframes], header[nframes:] except (ucp.exceptions.UCXBaseException, CancelledError): self.abort() raise CommClosedError("While reading, the connection was closed") else: # Recv frames frames = [ - cuda_array(each_size) - if is_cuda - else np.empty(each_size, dtype=np.uint8) - for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist()) + device_array(each_size) if is_cuda else host_array(each_size) + for is_cuda, each_size in zip(cuda_frames, sizes) ] recv_frames = [ each_frame for each_frame in frames if len(each_frame) > 0 @@ -234,7 +255,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): # It is necessary to first populate `frames` with CUDA arrays and synchronize # the default stream before starting receiving to ensure buffers have been allocated - if is_cudas.any(): + if any(cuda_frames): synchronize_stream(0) for each_frame in recv_frames: From ecdcb33720b2794c21b5bd322e4d2ccd8fec2787 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 24 Apr 2020 04:51:11 +0100 Subject: [PATCH 0805/1550] Reinstate support for legacy @gen_cluster functions (#3738) * gen_cluster and gen_test to accept legacy coroutines again * Legacy python cleanup --- distributed/_version.py | 4 +- distributed/comm/tests/test_comms.py | 7 +--- distributed/deploy/tests/test_local.py | 9 +---- distributed/metrics.py | 5 +-- distributed/protocol/pickle.py | 6 +-- distributed/protocol/tests/test_numpy.py | 3 -- distributed/protocol/tests/test_protocol.py | 3 -- distributed/tests/test_steal.py | 1 + distributed/tests/test_utils_test.py | 34 ++++++++++++++++ distributed/utils.py | 12 +----- distributed/utils_test.py | 45 ++++++--------------- 11 files changed, 57 insertions(+), 72 deletions(-) diff --git a/distributed/_version.py b/distributed/_version.py index f48634810c3..79f2770dd9c 100644 --- a/distributed/_version.py +++ b/distributed/_version.py @@ -96,9 +96,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() + stdout = p.communicate()[0].strip().decode() if p.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 035a95513fb..ac633c3aa86 100644 --- a/distributed/comm/tests/test_comms.py +++ 
b/distributed/comm/tests/test_comms.py @@ -646,8 +646,7 @@ async def handle_comm(comm): ) await comm.write({"x": "foo"}) # TODO: why is this necessary in Tornado 6 ? - # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028 - if sys.version_info >= (3,) and os.name != "nt": + if os.name != "nt": try: # See https://serverfault.com/questions/793260/what-does-tlsv1-alert-unknown-ca-mean assert "unknown ca" in str(excinfo.value) @@ -670,9 +669,7 @@ async def handle_comm(comm): await connect( listener.contact_address, timeout=2, ssl_context=cli_ctx, ) - # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028 - if sys.version_info >= (3,): - assert "certificate verify failed" in str(excinfo.value) + assert "certificate verify failed" in str(excinfo.value) # diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 94a6016dd2a..0867968a894 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -67,7 +67,6 @@ def test_local_cluster_supports_blocked_handlers(loop): ) -@pytest.mark.skipif("sys.version_info[0] == 2", reason="fork issues") def test_close_twice(): with LocalCluster() as cluster: with Client(cluster.scheduler_address) as client: @@ -81,7 +80,6 @@ def test_close_twice(): assert not log -@pytest.mark.skipif("sys.version_info[0] == 2", reason="multi-loop") def test_procs(): with LocalCluster( 2, @@ -173,13 +171,11 @@ def test_transports_tcp_port(): assert e.submit(inc, 4).result() == 5 -@pytest.mark.skipif("sys.version_info[0] == 2", reason="") class LocalTest(ClusterTest, unittest.TestCase): Cluster = partial(LocalCluster, silence_logs=False, dashboard_address=None) kwargs = {"dashboard_address": None, "processes": False} -@pytest.mark.skipif("sys.version_info[0] == 2", reason="") def test_Client_with_local(loop): with LocalCluster( 1, scheduler_port=0, silence_logs=False, dashboard_address=None, loop=loop @@ -429,7 +425,6 @@ def test_bokeh(loop, processes): requests.get(url, timeout=0.2) -@pytest.mark.skipif(sys.version_info < (3, 6), reason="Unknown") def test_blocks_until_full(loop): with Client(loop=loop) as c: assert len(c.nthreads()) > 0 @@ -462,7 +457,8 @@ async def test_scale_up_and_down(): @pytest.mark.xfail( sys.version_info >= (3, 8) and LooseVersion(tornado.version) < "6.0.3", - reason="Known issue with Python 3.8 and Tornado < 6.0.3. See https://github.com/tornadoweb/tornado/pull/2683.", + reason="Known issue with Python 3.8 and Tornado < 6.0.3. 
" + "See https://github.com/tornadoweb/tornado/pull/2683.", strict=True, ) def test_silent_startup(): @@ -549,7 +545,6 @@ def test_death_timeout_raises(loop): LocalCluster._instances.clear() # ignore test hygiene checks -@pytest.mark.skipif(sys.version_info < (3, 6), reason="Unknown") @pytest.mark.asyncio async def test_bokeh_kwargs(cleanup): pytest.importorskip("bokeh") diff --git a/distributed/metrics.py b/distributed/metrics.py index f28e9f2ac7f..0f7d78a8129 100755 --- a/distributed/metrics.py +++ b/distributed/metrics.py @@ -49,10 +49,7 @@ def __init__(self): self.delta = None self.last_resync = float("-inf") - if sys.version_info >= (3,): - perf_counter = timemod.perf_counter - else: - perf_counter = timemod.clock + perf_counter = timemod.perf_counter def time(self): delta = self.delta diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 629fb962fbf..9a1f135444f 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -1,12 +1,8 @@ import logging -import sys +import pickle import cloudpickle -if sys.version_info.major == 2: - import cPickle as pickle -else: - import pickle logger = logging.getLogger(__name__) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 08a7c2df244..70ee582fd70 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -1,4 +1,3 @@ -import sys from zlib import crc32 import numpy as np @@ -189,7 +188,6 @@ def test_itemsize(dt, size): assert itemsize(np.dtype(dt)) == size -@pytest.mark.skipif(sys.version_info[0] < 3, reason="numpy doesnt use memoryviews") def test_compress_numpy(): pytest.importorskip("lz4") x = np.ones(10000000, dtype="i4") @@ -238,7 +236,6 @@ async def test_dumps_large_blosc(c, s, a, b): await x -@pytest.mark.skipif(sys.version_info[0] < 3, reason="numpy doesnt use memoryviews") def test_compression_takes_advantage_of_itemsize(): pytest.importorskip("lz4") blosc = pytest.importorskip("blosc") diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index d3536933a96..6c8296edecb 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -1,5 +1,3 @@ -import sys - import dask import pytest @@ -209,7 +207,6 @@ def test_dumps_loads_Serialized(): assert result == result3 -@pytest.mark.skipif(sys.version_info[0] < 3, reason="NumPy doesnt use memoryviews") def test_maybe_compress_memoryviews(): np = pytest.importorskip("numpy") pytest.importorskip("lz4") diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index fb5c96e14e6..5ef3e5330ec 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -106,6 +106,7 @@ async def test_worksteal_many_thieves(c, s, *workers): assert sum(map(len, s.has_what.values())) < 150 +@pytest.mark.xfail(reason="GH#3574") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) async def test_dont_steal_unknown_functions(c, s, a, b): futures = c.map(inc, range(100), workers=a.address, allow_other_workers=True) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 502b27b3013..4e9e776b590 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -5,6 +5,7 @@ from time import sleep import pytest +from tornado import gen from distributed import Scheduler, Worker, Client, config, default_client from distributed.core import rpc @@ -49,6 +50,28 @@ 
async def test_gen_cluster(c, s, a, b): for w in [a, b]: assert isinstance(w, Worker) assert s.nthreads == {w.address: w.nthreads for w in [a, b]} + assert await c.submit(lambda: 123) == 123 + + +@gen_cluster(client=True) +def test_gen_cluster_legacy_implicit(c, s, a, b): + assert isinstance(c, Client) + assert isinstance(s, Scheduler) + for w in [a, b]: + assert isinstance(w, Worker) + assert s.nthreads == {w.address: w.nthreads for w in [a, b]} + assert (yield c.submit(lambda: 123)) == 123 + + +@gen_cluster(client=True) +@gen.coroutine +def test_gen_cluster_legacy_explicit(c, s, a, b): + assert isinstance(c, Client) + assert isinstance(s, Scheduler) + for w in [a, b]: + assert isinstance(w, Worker) + assert s.nthreads == {w.address: w.nthreads for w in [a, b]} + assert (yield c.submit(lambda: 123)) == 123 @pytest.mark.skip(reason="This hangs on travis") @@ -101,6 +124,17 @@ async def test_gen_test(): await asyncio.sleep(0.01) +@gen_test() +def test_gen_test_legacy_implicit(): + yield asyncio.sleep(0.01) + + +@gen_test() +@gen.coroutine +def test_gen_test_legacy_explicit(): + yield asyncio.sleep(0.01) + + @contextmanager def _listen(delay=0): serv = socket.socket() diff --git a/distributed/utils.py b/distributed/utils.py index 46bd4c245e8..ea333833b08 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -8,7 +8,6 @@ import functools from hashlib import md5 import html -import inspect import json import logging import multiprocessing @@ -1039,10 +1038,7 @@ def import_file(path): if ext in (".egg", ".zip", ".pyz"): if path not in sys.path: sys.path.insert(0, path) - if sys.version_info >= (3, 6): - names = (mod_info.name for mod_info in pkgutil.iter_modules([path])) - else: - names = (mod_info[1] for mod_info in pkgutil.iter_modules([path])) + names = (mod_info.name for mod_info in pkgutil.iter_modules([path])) names_to_import.extend(names) loaded = [] @@ -1285,11 +1281,7 @@ def color_of(x, palette=palette): def iscoroutinefunction(f): - if gen.is_coroutine_function(f): - return True - if sys.version_info >= (3, 5) and inspect.iscoroutinefunction(f): - return True - return False + return inspect.iscoroutinefunction(f) or gen.is_coroutine_function(f) @contextmanager diff --git a/distributed/utils_test.py b/distributed/utils_test.py index e466322eddf..05467bbeb49 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -18,7 +18,6 @@ import subprocess import sys import tempfile -import textwrap import threading from time import sleep import uuid @@ -34,6 +33,7 @@ import dask from tlz import merge, memoize, assoc +from tornado import gen from tornado.ioloop import IOLoop from . 
import system @@ -397,25 +397,9 @@ async def geninc(x, delay=0.02): return x + 1 -def compile_snippet(code, dedent=True): - if dedent: - code = textwrap.dedent(code) - code = compile(code, "", "exec") - ns = globals() - exec(code, ns, ns) - - -if sys.version_info >= (3, 5): - compile_snippet( - """ - async def asyncinc(x, delay=0.02): - await asyncio.sleep(delay) - return x + 1 - """ - ) - assert asyncinc # noqa: F821 -else: - asyncinc = None +async def asyncinc(x, delay=0.02): + await asyncio.sleep(delay) + return x + 1 _readone_queues = {} @@ -768,9 +752,11 @@ async def test_foo(): def _(func): def test_func(): with clean() as loop: - if not iscoroutinefunction(func): - raise ValueError("@gen_test should wrap async def functions") - loop.run_sync(func, timeout=timeout) + if iscoroutinefunction(func): + cor = func + else: + cor = gen.coroutine(func) + loop.run_sync(cor, timeout=timeout) return test_func @@ -877,10 +863,10 @@ async def test_foo(scheduler, worker1, worker2): ) def _(func): - def test_func(): - if not iscoroutinefunction(func): - raise ValueError("@gen_cluster should wrap async def functions") + if not iscoroutinefunction(func): + func = gen.coroutine(func) + def test_func(): result = None workers = [] with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop: @@ -1001,12 +987,7 @@ def terminate_process(proc): else: proc.send_signal(signal.SIGINT) try: - if sys.version_info[0] == 3: - proc.wait(10) - else: - start = time() - while proc.poll() is None and time() < start + 10: - sleep(0.02) + proc.wait(10) finally: # Make sure we don't leave the process lingering around with ignoring(OSError): From 4199c546154a75afa5404a7cfbaa8e864286dcf3 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 24 Apr 2020 23:22:18 -0500 Subject: [PATCH 0806/1550] bump version to 2.15.0 --- docs/source/changelog.rst | 42 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index c1bcab71eb5..628a1d2147a 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,46 @@ Changelog ========= +2.15.0 - 2020-04-24 +------------------- + +- Reinstate support for legacy ``@gen_cluster`` functions (:pr:`3738`) `crusaderky`_ +- Relax NumPy requirement in UCX (:pr:`3731`) `jakirkham`_ +- Add Configuration Schema (:pr:`3696`) `Matthew Rocklin`_ +- Reuse CI scripts for local installation process (:pr:`3698`) `crusaderky`_ +- Use ``PeriodicCallback`` class from tornado (:pr:`3725`) `James Bourbeau`_ +- Add ``remote_python`` option in ssh cmd (:pr:`3709`) `Abdulelah Bin Mahfoodh`_ +- Configurable polling interval for cluster widget (:pr:`3723`) `Julia Signell`_ +- Fix copy-paste in docs (:pr:`3728`) `Julia Signell`_ +- Replace ``gen.coroutine`` with async-await in tests (:pr:`3706`) `crusaderky`_ +- Fix flaky ``test_oversubscribing_leases`` (:pr:`3726`) `Florian Jetter`_ +- Add ``batch_size`` to ``Client.map`` (:pr:`3650`) `Tom Augspurger`_ +- Adjust semaphore test timeouts (:pr:`3720`) `Florian Jetter`_ +- Dask-serialize dicts longer than five elements (:pr:`3689`) `Richard J Zamora`_ +- Force ``threads_per_worker`` (:pr:`3715`) `crusaderky`_ +- Idempotent semaphore acquire with retries (:pr:`3690`) `Florian Jetter`_ +- Always use ``readinto`` in TCP (:pr:`3711`) `jakirkham`_ +- Avoid ``DeprecationWarning`` from pandas (:pr:`3712`) `Tom Augspurger`_ +- Allow modification of ``distributed.comm.retry`` at runtime (:pr:`3705`) `Florian Jetter`_ +- Do not log an error on unset variable delete 
(:pr:`3652`) `Jonathan J. Helmus`_ +- Add ``remote_python`` keyword to the new ``SSHCluster`` (:pr:`3701`) `Abdulelah Bin Mahfoodh`_ +- Replace Example with Examples in docstrings (:pr:`3697`) `Matthew Rocklin`_ +- Add ``Cluster`` ``__enter__`` and ``__exit__`` methods (:pr:`3699`) `Matthew Rocklin`_ +- Fix propagating inherit config in ``SSHCluster`` for non-bash shells (:pr:`3688`) `Abdulelah Bin Mahfoodh`_ +- Add ``Client.wait_to_workers`` to ``Client`` autosummary table (:pr:`3692`) `James Bourbeau`_ +- Replace Bokeh Server with Tornado HTTPServer (:pr:`3658`) `Matthew Rocklin`_ +- Fix ``dask-ssh`` after removing ``local-directory`` from ``dask_scheduler`` cli (:pr:`3684`) `Abdulelah Bin Mahfoodh`_ +- Support preload modules in ``Nanny`` (:pr:`3678`) `Matthew Rocklin`_ +- Refactor semaphore internals: make ``_get_lease`` synchronous (:pr:`3679`) `Lucas Rademaker`_ +- Don't make task graphs too big (:pr:`3671`) `Martin Durant`_ +- Pass through ``connection``/``listen_args`` as splatted keywords (:pr:`3674`) `Matthew Rocklin`_ +- Run preload at import, start, and teardown (:pr:`3673`) `Matthew Rocklin`_ +- Use relative URL in scheduler dashboard (:pr:`3676`) `Nicholas Smith`_ +- Expose ``Security`` object as public API (:pr:`3675`) `Matthew Rocklin`_ +- Add zoom tools to profile plots (:pr:`3672`) `James Bourbeau`_ +- Update ``Scheduler.rebalance`` return value when data is missing (:pr:`3670`) `James Bourbeau`_ + + 2.14.0 - 2020-04-03 ------------------- @@ -1672,3 +1712,5 @@ significantly without many new features. .. _`Prasun Anand`: https://github.com/prasunanand .. _`Jonathan J. Helmus`: https://github.com/jjhelmus .. _`Rami Chowdhury`: https://github.com/necaris +.. _`crusaderky`: https://github.com/crusaderky +.. _`Nicholas Smith`: https://github.com/nsmith- From 2d54ef947a7fe372b1fafc40d2f2b89ff88e9449 Mon Sep 17 00:00:00 2001 From: Dillon Niederhut Date: Mon, 27 Apr 2020 10:21:18 -0500 Subject: [PATCH 0807/1550] BUG: allows logging config under distributed key (#2952) * BUG: allows logging config under distributed key The logging documentation specifies a configuration path to logging info like config['distributed']['logging'], but the config module looks in config['logging']. This commit allows both by looking first in config['distributed'], then falling back to config. Closes https://github.com/dask/distributed/issues/2937 --- distributed/config.py | 28 ++++++++++++++++++----- distributed/tests/test_config.py | 39 ++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/distributed/config.py b/distributed/config.py index a313f18416b..43e545576e5 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -80,7 +80,8 @@ def _initialize_logging_old_style(config): "tornado": "critical", "tornado.application": "error", } - loggers.update(config.get("logging", {})) + base_config = _find_logging_config(config) + loggers.update(base_config.get("logging", {})) handler = logging.StreamHandler(sys.stderr) handler.setFormatter( @@ -103,7 +104,8 @@ def _initialize_logging_new_style(config): Initialize logging using logging's "Configuration dictionary schema". 
(ref.: https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema) """ - logging.config.dictConfig(config.get("logging")) + base_config = _find_logging_config(config) + logging.config.dictConfig(base_config.get("logging")) def _initialize_logging_file_config(config): @@ -111,20 +113,34 @@ def _initialize_logging_file_config(config): Initialize logging using logging's "Configuration file format". (ref.: https://docs.python.org/3/howto/logging.html#configuring-logging) """ + base_config = _find_logging_config(config) logging.config.fileConfig( - config.get("logging-file-config"), disable_existing_loggers=False + base_config.get("logging-file-config"), disable_existing_loggers=False ) +def _find_logging_config(config): + """ + Look for the dictionary containing logging-specific configurations, + starting in the 'distributed' dictionary and then trying the top-level + """ + logging_keys = {"logging-file-config", "logging"} + if logging_keys & config.get("distributed", {}).keys(): + return config["distributed"] + else: + return config + + def initialize_logging(config): - if "logging-file-config" in config: - if "logging" in config: + base_config = _find_logging_config(config) + if "logging-file-config" in base_config: + if "logging" in base_config: raise RuntimeError( "Config options 'logging-file-config' and 'logging' are mutually exclusive." ) _initialize_logging_file_config(config) else: - log_config = config.get("logging", {}) + log_config = base_config.get("logging", {}) if "version" in log_config: # logging module mandates version to be an int log_config["version"] = int(log_config["version"]) diff --git a/distributed/tests/test_config.py b/distributed/tests/test_config.py index 01cd6eec57b..74b57b1f011 100644 --- a/distributed/tests/test_config.py +++ b/distributed/tests/test_config.py @@ -109,6 +109,45 @@ def test_logging_empty_simple(): test_logging_default() +def test_logging_simple_under_distributed(): + """ + Test simple ("old-style") logging configuration under the distributed key. + """ + c = { + "distributed": { + "logging": {"distributed.foo": "info", "distributed.foo.bar": "error"} + } + } + # Must test using a subprocess to avoid wrecking pre-existing configuration + with new_config_file(c): + code = """if 1: + import logging + import dask + + from distributed.utils_test import captured_handler + + d = logging.getLogger('distributed') + assert len(d.handlers) == 1 + assert isinstance(d.handlers[0], logging.StreamHandler) + df = logging.getLogger('distributed.foo') + dfb = logging.getLogger('distributed.foo.bar') + + with captured_handler(d.handlers[0]) as distributed_log: + df.info("1: info") + dfb.warning("2: warning") + dfb.error("3: error") + + distributed_log = distributed_log.getvalue().splitlines() + + assert distributed_log == [ + "distributed.foo - INFO - 1: info", + "distributed.foo.bar - ERROR - 3: error", + ], (dask.config.config, distributed_log) + """ + + subprocess.check_call([sys.executable, "-c", code]) + + def test_logging_simple(): """ Test simple ("old-style") logging configuration. 
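The heart of the patch above is the lookup order in the new _find_logging_config helper: prefer a logging section nested under the "distributed" key, otherwise fall back to the top level. A condensed, standalone sketch of that behaviour (the example dicts are invented for illustration):

    def find_logging_config(config):
        # Mirror of the helper added in this patch: check the 'distributed'
        # sub-dictionary first, then fall back to the top-level config.
        logging_keys = {"logging", "logging-file-config"}
        if logging_keys & config.get("distributed", {}).keys():
            return config["distributed"]
        return config

    nested = {"distributed": {"logging": {"distributed.foo": "info"}}}
    flat = {"logging": {"distributed.foo": "info"}}
    assert find_logging_config(nested)["logging"] == find_logging_config(flat)["logging"]
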
From 3de9973cafaf85809f712a719111b8db3839975d Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 27 Apr 2020 13:00:29 -0400 Subject: [PATCH 0808/1550] Memoryview serialisation (#3743) * Serialise memview without copy * black * Special-case deser memview with one frame --- distributed/protocol/serialize.py | 11 ++++++++++- distributed/protocol/tests/test_serialize.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index a1b35ec4463..4d02bc65207 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -556,7 +556,7 @@ def normalize_Serialized(o): # Teach serialize how to handle bytestrings -@dask_serialize.register((bytes, bytearray)) +@dask_serialize.register((bytes, bytearray, memoryview)) def _serialize_bytes(obj): header = {} # no special metadata frames = [obj] @@ -568,6 +568,15 @@ def _deserialize_bytes(header, frames): return b"".join(frames) +@dask_deserialize.register(memoryview) +def _serialize_memoryview(header, frames): + if len(frames) == 1: + out = frames[0] + else: + out = b"".join(frames) + return memoryview(out) + + ######################### # Descend into __dict__ # ######################### diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 6a5af842ddd..4cad5a3653b 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -409,6 +409,7 @@ def _(x): ([MyObj([0, 1, 2]), 1], True), (tuple([MyObj(None)]), True), ({("x", i): MyObj(5) for i in range(100)}, True), + (memoryview(b"hello"), True), ], ) def test_check_dask_serializable(data, is_serializable): @@ -428,3 +429,12 @@ def test_serialize_lists(serializers): data_out = deserialize(header, frames) assert data_in == data_out + + +def test_deser_memoryview(): + data_in = memoryview(b"hello") + header, frames = serialize(data_in) + assert header["type"] == "builtins.memoryview" + assert frames[0] is data_in + data_out = deserialize(header, frames) + assert data_in == data_out From 26a9fd6256c24098c4020bbe9deab2e6fc914cca Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 27 Apr 2020 11:50:22 -0700 Subject: [PATCH 0809/1550] Warn if cluster closes before starting (#3735) Otherwise users get an odd message like the following: ```python-traceback File "/home/XXX/.local/lib/python3.6/site-packages/tornado/ioloop.py", line 743, in _run_callback ret = callback() File "/home/XXX/.local/lib/python3.6/site-packages/tornado/ioloop.py", line 767, in _discard_future_result future.result() File "/home/XXX/.local/lib/python3.6/site-packages/distributed/deploy/spec.py", line 386, in _close await self.scheduler_comm.close(close_workers=True) AttributeError: 'NoneType' object has no attribute 'close' ``` --- distributed/deploy/spec.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index eb9f0f0043e..8419659ca3a 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -383,7 +383,10 @@ async def _close(self): await future async with self._lock: with ignoring(CommClosedError): - await self.scheduler_comm.close(close_workers=True) + if self.scheduler_comm: + await self.scheduler_comm.close(close_workers=True) + else: + logger.warning("Cluster closed without starting up") await self.scheduler.close() for w in self._created: From 7c57f853bf271e83e80983f53cb874b95bddf7b9 Mon Sep 17 00:00:00 2001 
From: James Bourbeau Date: Mon, 27 Apr 2020 19:15:31 -0500 Subject: [PATCH 0810/1550] Ensure BokehTornado uses prefix (#3746) --- distributed/dashboard/core.py | 6 ++---- .../dashboard/tests/test_scheduler_bokeh.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index 6843b0659b3..6c3c7e919c2 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -27,10 +27,7 @@ def BokehApplication(applications, server, prefix="/", template_variables={}): extra = toolz.merge({"prefix": prefix}, template_variables) - apps = { - prefix + k.lstrip("/"): functools.partial(v, server, extra) - for k, v in applications.items() - } + apps = {k: functools.partial(v, server, extra) for k, v in applications.items()} apps = {k: Application(FunctionHandler(v)) for k, v in apps.items()} kwargs = dask.config.get("distributed.scheduler.dashboard.bokeh-application").copy() extra_websocket_origins = create_hosts_whitelist( @@ -39,6 +36,7 @@ def BokehApplication(applications, server, prefix="/", template_variables={}): application = BokehTornado( apps, + prefix=prefix, use_index=False, extra_websocket_origins=extra_websocket_origins, **kwargs, diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 8ed1bb0f8a1..f943807d4df 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -8,6 +8,7 @@ import pytest pytest.importorskip("bokeh") +from bokeh.server.server import BokehTornado from tlz import first from tornado.httpclient import AsyncHTTPClient, HTTPRequest @@ -712,3 +713,21 @@ async def test_memory_by_key(c, s, a, b): mbk.update() assert mbk.source.data["name"] == ["add", "inc"] assert mbk.source.data["nbytes"] == [x.nbytes, sys.getsizeof(1)] + + +@gen_cluster(scheduler_kwargs={"http_prefix": "foo-bar", "dashboard": True}) +async def test_prefix_bokeh(s, a, b): + prefix = "foo-bar" + http_client = AsyncHTTPClient() + response = await http_client.fetch( + f"http://localhost:{s.http_server.port}/{prefix}/status" + ) + assert response.code == 200 + assert ( + f' - + + +
        - {% block content %} - {% endblock %} + {% block content %} + {% endblock %}
        - - + + + \ No newline at end of file From d419e41952c4da376e584c0874dbf91014c14ec8 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 29 Jun 2021 15:57:47 +0200 Subject: [PATCH 1339/1550] Ensure shuffle split operations are blacklisted from work stealing (#4964) If shuffle split tasks are not blacklisted from work stealing, this can have catastrophic effects on performance. See also https://github.com/dask/distributed/issues/4962 --- distributed/stealing.py | 2 +- distributed/tests/test_steal.py | 35 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index 1929661abc5..e3398b4c9a1 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -453,4 +453,4 @@ def _can_steal(thief, ts, victim): return True -fast_tasks = {"shuffle-split"} +fast_tasks = {"split-shuffle"} diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index ee2695cea87..fbabd2a6086 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -827,3 +827,38 @@ async def test_balance_with_longer_task(c, s, a, b): ) # a task after y, suggesting a, but open to b await z assert z.key in b.data + + +@gen_cluster(client=True) +async def test_blacklist_shuffle_split(c, s, a, b): + + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") + npart = 10 + df = dd.from_pandas(pd.DataFrame({"A": range(100), "B": 1}), npartitions=npart) + graph = df.shuffle( + "A", + shuffle="tasks", + # If we don't have enough partitions, we'll fall back to a simple shuffle + max_branch=npart - 1, + ).sum() + res = c.compute(graph) + + while not s.tasks: + await asyncio.sleep(0.005) + prefixes = set(s.task_prefixes.keys()) + from distributed.stealing import fast_tasks + + blacklisted = fast_tasks & prefixes + assert blacklisted + assert any(["split" in prefix for prefix in blacklisted]) + + stealable = s.extensions["stealing"].stealable + while not res.done(): + for tasks_per_level in stealable.values(): + for tasks in tasks_per_level: + for ts in tasks: + assert ts.prefix.name not in fast_tasks + assert "split" not in ts.prefix.name + await asyncio.sleep(0.001) + await res From b9d2e3bc151c21ed96975f5c71b8c1d5ece6fc8a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 29 Jun 2021 06:58:54 -0700 Subject: [PATCH 1340/1550] Add maximum shard size to config (#4986) --- distributed/distributed-schema.yaml | 9 +++++++++ distributed/distributed.yaml | 1 + distributed/protocol/utils.py | 4 +++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 80f7adce25f..440a39fb2f0 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -629,6 +629,15 @@ properties: This is useful if you want to include serialization in profiling data, or if you have data types that are particularly sensitive to deserialization + shard: + type: string + description: | + The maximum size of a frame to send through a comm + + Some network infrastructure doesn't like sending through very large messages. + Dask comms will cut up these large messages into many small ones. + This attribute determines the maximum size of such a shard. 
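For intuition about what the new distributed.comm.shard setting controls, the sketch below reads a human-readable size limit and cuts an oversized frame into shards no larger than it. This is an illustration only, not the comm implementation; split_into_shards and the 100-byte frame are invented here, and a "64MiB" default is passed in case the config key is not registered:

    import dask
    from dask.utils import parse_bytes

    limit = parse_bytes(dask.config.get("distributed.comm.shard", "64MiB"))

    def split_into_shards(frame, limit=limit):
        # Cut one large frame into memoryview shards of at most `limit` bytes.
        view = memoryview(frame)
        return [view[i : i + limit] for i in range(0, len(view), limit)]

    shards = split_into_shards(b"x" * 100, limit=30)
    assert [len(s) for s in shards] == [30, 30, 30, 10]
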
+ socket-backlog: type: integer description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index f948c45fd99..9ea6360b49f 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -165,6 +165,7 @@ distributed: min: 1s # the first non-zero delay between re-tries max: 20s # the maximum delay between re-tries compression: auto + shard: 64MiB offload: 10MiB # Size after which we choose to offload serialization to another thread default-scheme: tcp socket-backlog: 2048 diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 25ccce7c9f6..3f5a2f8f500 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -1,8 +1,10 @@ import struct +import dask + from ..utils import nbytes -BIG_BYTES_SHARD_SIZE = 2 ** 26 +BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard")) msgpack_opts = { From dbb13ecf3c78f6ad301c8c40b18cebdef71789bf Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 29 Jun 2021 16:02:56 +0200 Subject: [PATCH 1341/1550] No longer hold dependencies of erred tasks in memory #4918 This is a follow up to #4784 and reduces complexity of Worker.release_key significantly. There is one non-trivial behavioural change regarding erred tasks. Current main branch holds on to dependencies of an erred task on a worker and implements a release mechanism once that erred task is released. I implemented this recently trying to capture status quo but I'm not convinced any longer that this is the correct behaviour. It treats the erred case specially which introduces a lot of complexity. The only place where this might be of interest is if an erred task wants to be recomputed locally. Not forgetting the data keys until the erred task was released would speed up this process. However, we'd still need to potentially compute some keys and I'm inclined to strike this feature in favour of reduced complexity. --- distributed/tests/test_worker.py | 51 ++++++++++++++---------- distributed/worker.py | 66 ++++++-------------------------- 2 files changed, 41 insertions(+), 76 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2f3a7f58ede..903241f7225 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -48,6 +48,7 @@ mul, nodebug, slowinc, + slowsum, ) from distributed.worker import Worker, error_message, logger, parse_memory_limit @@ -2087,11 +2088,6 @@ def raise_exc(*args): await asyncio.sleep(0.01) expected_states = { - # We currently don't have a good way to actually release this memory as - # long as the tasks still have a dependent. We'll need to live with this - # memory for now - f.key: "memory", - g.key: "memory", res.key: "error", } @@ -2159,6 +2155,7 @@ def raise_exc(*args): expected_states = { f.key: "memory", + g.key: "memory", } assert_task_states_on_worker(expected_states, a) @@ -2166,7 +2163,6 @@ def raise_exc(*args): f.release() g.release() - # This is not happening for server in [s, a, b]: while server.tasks: await asyncio.sleep(0.01) @@ -2220,13 +2216,14 @@ def raise_exc(*args): res.release() # We no longer hold any refs to f or g and B didn't have any erros. 
It # releases everything as expected - while a.tasks: + while len(a.tasks) > 1: await asyncio.sleep(0.01) expected_states = { g.key: "memory", } + assert_task_states_on_worker(expected_states, a) assert_task_states_on_worker(expected_states, b) g.release() @@ -2283,7 +2280,6 @@ def raise_exc(*args): assert_task_states_on_worker(expected_states_A, a) expected_states_B = { - f.key: "memory", g.key: "memory", h.key: "memory", res.key: "error", @@ -2301,15 +2297,6 @@ def raise_exc(*args): # B must not forget a task since all have a still valid dependent expected_states_B = { - f.key: "memory", - # We actually cannot hold on to G even though the graph would suggest - # otherwise. This is because H was only introduced as a dependency and - # the scheduler never told the worker how H fits into the big picture. - # Therefore, it thinks that G does not have any dependents anymore and - # releases it. Too bad. Once we have speculative task assignments this - # should be more exact since we should always tell the worker what's - # going on - # g.key: released, h.key: "memory", res.key: "error", } @@ -2320,10 +2307,6 @@ def raise_exc(*args): expected_states_A = {} assert_task_states_on_worker(expected_states_A, a) expected_states_B = { - f.key: "memory", - # See above - # g.key: released, - h.key: "memory", res.key: "error", } @@ -2334,3 +2317,29 @@ def raise_exc(*args): for server in [s, a, b]: while server.tasks: await asyncio.sleep(0.01) + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", x) for x in range(4)], timeout=None) +async def test_hold_on_to_replicas(c, s, *workers): + f1 = c.submit(inc, 1, workers=[workers[0].address], key="f1") + f2 = c.submit(inc, 2, workers=[workers[1].address], key="f2") + + sum_1 = c.submit( + slowsum, [f1, f2], delay=0.1, workers=[workers[2].address], key="sum" + ) + sum_2 = c.submit( + slowsum, [f1, sum_1], delay=0.2, workers=[workers[3].address], key="sum_2" + ) + f1.release() + f2.release() + + while sum_2.key not in workers[3].tasks: + await asyncio.sleep(0.01) + + while not workers[3].tasks[sum_2.key].state == "memory": + assert len(s.tasks[f1.key].who_has) >= 2 + assert s.tasks[f2.key].state == "released" + await asyncio.sleep(0.01) + + while len(workers[2].tasks) > 1: + await asyncio.sleep(0.01) diff --git a/distributed/worker.py b/distributed/worker.py index 3d5146f09e1..1f04b4cbde5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1658,8 +1658,7 @@ def add_task( ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) - if dep_ts.state in ("fetch", "flight"): - # if we _need_ to grab data or are in the process + if dep_ts.state not in ("memory",): ts.waiting_for_data.add(dep_ts.key) self.update_who_has(who_has=who_has) @@ -1762,9 +1761,6 @@ def transition_fetch_waiting(self, ts, runspec): # clear `who_has` of stale info ts.who_has.clear() - # remove entry from dependents to avoid a spurious `gather_dep` call`` - for dependent in ts.dependents: - dependent.waiting_for_data.discard(ts.key) except Exception as e: logger.exception(e) if LOG_PDB: @@ -1794,9 +1790,6 @@ def transition_flight_waiting(self, ts, runspec): # clear `who_has` of stale info ts.who_has.clear() - # remove entry from dependents to avoid a spurious `gather_dep` call`` - for dependent in ts.dependents: - dependent.waiting_for_data.discard(ts.key) except Exception as e: logger.exception(e) if LOG_PDB: @@ -1991,6 +1984,8 @@ def transition_executing_done(self, ts, value=no_value, report=True): ts.traceback = msg["traceback"] ts.state = "error" out = "error" + for d in 
ts.dependents: + d.waiting_for_data.add(ts.key) # Don't release the dependency keys, but do remove them from `dependents` for dependency in ts.dependencies: @@ -2621,12 +2616,12 @@ def release_key( if self.validate: assert isinstance(key, str) - ts = self.tasks.get(key, TaskState(key=key)) + ts = self.tasks.get(key, None) # If the scheduler holds a reference which is usually the # case when it instructed the task to be computed here or if # data was scattered we must not release it unless the # scheduler allow us to. See also handle_delete_data and - if ts and ts.scheduler_holds_ref: + if ts is None or ts.scheduler_holds_ref: return logger.debug( "Release key %s", @@ -2640,28 +2635,14 @@ def release_key( self.log.append((key, "release-key", {"cause": cause}, reason)) else: self.log.append((key, "release-key", reason)) - if key in self.data and not ts.dependents: + if key in self.data: try: del self.data[key] except FileNotFoundError: logger.error("Tried to delete %s but no file found", exc_info=True) - if key in self.actors and not ts.dependents: + if key in self.actors: del self.actors[key] - # for any dependencies of key we are releasing remove task as dependent - for dependency in ts.dependencies: - dependency.dependents.discard(ts) - - if not dependency.dependents and dependency.state not in ( - # don't boot keys that are in flight - # we don't know if they're already queued up for transit - # in a gather_dep callback - "flight", - # The same is true for already executing keys. - "executing", - ): - self.release_key(dependency.key, reason=f"Dependent {ts} released") - for worker in ts.who_has: self.has_what[worker].discard(ts.key) ts.who_has.clear() @@ -2681,8 +2662,10 @@ def release_key( # Inform the scheduler of keys which will have gone missing # We are releasing them before they have completed if ts.state in PROCESSING: + # This path is only hit with work stealing msg = {"op": "release", "key": key, "cause": cause} else: + # This path is only hit when calling release_key manually msg = { "op": "release-worker-data", "keys": [key], @@ -2691,9 +2674,8 @@ def release_key( self.batched_stream.send(msg) self._notify_plugins("release_key", key, ts.state, cause, reason, report) - if key in self.tasks and not ts.dependents: - self.tasks.pop(key) - del ts + del self.tasks[key] + except CommClosedError: pass except Exception as e: @@ -2704,32 +2686,6 @@ def release_key( pdb.set_trace() raise - def rescind_key(self, key): - try: - if self.tasks[key].state not in PENDING: - return - - ts = self.tasks.pop(key) - - # Task has been rescinded - # For every task that it required - for dependency in ts.dependencies: - # Remove it as a dependent - dependency.dependents.remove(key) - # If the dependent is now without purpose (no dependencies), remove it - if not dependency.dependents: - self.release_key( - dependency.key, reason="All dependent keys rescinded" - ) - - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise - ################ # Execute Task # ################ From 661728267bed7d8ecdc039a134b247ad404e5291 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 29 Jun 2021 15:04:46 +0100 Subject: [PATCH 1342/1550] Misc Sphinx tweaks (#4988) --- distributed/diagnostics/__init__.py | 9 --------- docs/source/api.rst | 12 ++++++------ docs/source/develop.rst | 20 ++++++++------------ docs/source/examples/word-count.rst | 3 ++- docs/source/killed.rst | 2 +- docs/source/worker.rst | 5 ++++- 6 files changed, 21 insertions(+), 30 deletions(-) diff --git 
a/distributed/diagnostics/__init__.py b/distributed/diagnostics/__init__.py index 390a7b94f39..b286654974c 100644 --- a/distributed/diagnostics/__init__.py +++ b/distributed/diagnostics/__init__.py @@ -1,11 +1,2 @@ -from contextlib import suppress - from .graph_layout import GraphLayout from .plugin import SchedulerPlugin - -with suppress(ImportError): - from .progressbar import progress -with suppress(ImportError): - from .resource_monitor import Occupancy -with suppress(ImportError): - from .scheduler_widgets import scheduler_status diff --git a/docs/source/api.rst b/docs/source/api.rst index dc49a2d477f..b5d66b759c4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -23,11 +23,11 @@ API rejoin Reschedule -.. currentmodule:: distributed.recreate_exceptions +.. currentmodule:: distributed.recreate_tasks .. autosummary:: - ReplayExceptionClient.get_futures_error - ReplayExceptionClient.recreate_error_locally + ReplayTaskClient.recreate_task_locally + ReplayTaskClient.recreate_error_locally .. currentmodule:: distributed @@ -56,7 +56,7 @@ API .. autosummary:: as_completed - distributed.diagnostics.progress + distributed.diagnostics.progressbar.progress wait fire_and_forget futures_of @@ -112,7 +112,7 @@ Client .. autoclass:: Client :members: -.. autoclass:: distributed.recreate_exceptions.ReplayExceptionClient +.. autoclass:: distributed.recreate_tasks.ReplayTaskClient :members: @@ -151,7 +151,7 @@ Other .. autoclass:: as_completed :members: -.. autofunction:: distributed.diagnostics.progress +.. autofunction:: distributed.diagnostics.progressbar.progress .. autofunction:: wait .. autofunction:: fire_and_forget .. autofunction:: futures_of diff --git a/docs/source/develop.rst b/docs/source/develop.rst index 4d30f038b17..857008f1cbf 100644 --- a/docs/source/develop.rst +++ b/docs/source/develop.rst @@ -12,21 +12,17 @@ guidelines`_ in the main documentation. Install ------- -Clone this repository with git:: +1. Clone this repository with git:: - git clone git@github.com:dask/distributed.git - cd distributed - -Install all dependencies: - -All OS:: + git clone git@github.com:dask/distributed.git + cd distributed -1. Install anaconda or miniconda -2. :: +2. Install anaconda or miniconda (OS-dependent) +3. :: - conda env create --file continuous_integration/environment-3.8.yaml - conda activate dask-distributed - python -m pip install -e . + conda env create --file continuous_integration/environment-3.8.yaml + conda activate dask-distributed + python -m pip install -e . 
To keep a fork in sync with the upstream source:: diff --git a/docs/source/examples/word-count.rst b/docs/source/examples/word-count.rst index ad81a45028a..50535dbf2ac 100644 --- a/docs/source/examples/word-count.rst +++ b/docs/source/examples/word-count.rst @@ -237,7 +237,8 @@ The complete Python script for this example is shown below: import hdfs3 from collections import defaultdict, Counter - from distributed import Client, progress + from distributed import Client + from distributed.diagnostics.progressbar import progress hdfs = hdfs3.HDFileSystem('NAMENODE_HOSTNAME', port=NAMENODE_PORT) client = Client('SCHEDULER_IP:SCHEDULER:PORT') diff --git a/docs/source/killed.rst b/docs/source/killed.rst index 837ccd944b4..707adc1bec8 100644 --- a/docs/source/killed.rst +++ b/docs/source/killed.rst @@ -80,7 +80,7 @@ of distributed may do this automatically) For other errors, you might want to run the computation in your local client, if possible, or try grabbing just the task that errored and using -:func:`recreate_error_locally `, +:meth:`~distributed.recreate_tasks.ReplayTaskClient.recreate_error_locally`, as you would for ordinary exceptions happening during task execution. Specifically for connectivity problems (e.g., timeout exceptions in the worker diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 20b4ab0067e..faa392d132a 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -353,7 +353,10 @@ To aggressively and automatically trim the memory in a production environment, y should instead set the environment variable ``MALLOC_TRIM_THRESHOLD_`` (note the final underscore) to 0 or a low number; see the `mallopt`_ man page for details. Reducing this value will increase the number of syscalls, and as a consequence may degrade -performance. **The variable must be set before starting the ``dask-worker`` process.** +performance. + +.. note:: + The variable must be set before starting the ``dask-worker`` process. 
jemalloc ~~~~~~~~ From 5dc591bbdd4427fe49fe90338a34fc85ee35f2c9 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 29 Jun 2021 17:51:22 +0200 Subject: [PATCH 1343/1550] Pyupgrade (#4741) * Add pyupgrade to pre-commit * Add minimal py37 to black formatter * Apply pyupgrade * Remove pyupgrade from pre-commit config --- .pre-commit-config.yaml | 2 ++ distributed/actor.py | 2 +- distributed/cli/dask_scheduler.py | 4 ++-- distributed/cli/dask_ssh.py | 6 ++--- distributed/cli/dask_worker.py | 4 ++-- distributed/client.py | 21 +++++++--------- distributed/comm/addressing.py | 10 ++++---- distributed/comm/core.py | 8 +++---- distributed/comm/inproc.py | 6 ++--- distributed/comm/tcp.py | 20 +++++++--------- distributed/comm/tests/test_comms.py | 6 ++--- distributed/comm/tests/test_ucx.py | 2 +- distributed/comm/tests/test_ucx_config.py | 2 +- distributed/comm/ucx.py | 4 ++-- distributed/comm/utils.py | 2 +- distributed/core.py | 16 ++++++------- distributed/dashboard/components/nvml.py | 6 ++--- distributed/dashboard/components/scheduler.py | 12 +++++----- distributed/dashboard/components/shared.py | 14 +++++------ distributed/dashboard/components/worker.py | 10 ++++---- distributed/deploy/adaptive.py | 4 ++-- distributed/deploy/adaptive_core.py | 2 +- distributed/deploy/old_ssh.py | 16 +++++-------- distributed/deploy/spec.py | 2 +- distributed/deploy/ssh.py | 2 +- distributed/deploy/tests/test_local.py | 2 +- distributed/diagnostics/progress.py | 6 ++--- distributed/diskutils.py | 6 ++--- distributed/http/proxy.py | 10 ++++---- distributed/http/tests/test_core.py | 4 +--- distributed/locket.py | 4 ++-- distributed/nanny.py | 2 +- distributed/node.py | 2 +- distributed/process.py | 10 ++++---- distributed/profile.py | 4 ++-- distributed/protocol/core.py | 2 +- distributed/publish.py | 3 +-- distributed/pubsub.py | 4 ++-- distributed/pytest_resourceleaks.py | 9 ++++--- distributed/queues.py | 2 +- distributed/scheduler.py | 12 +++++----- distributed/security.py | 4 ++-- distributed/tests/make_tls_certs.py | 6 ++--- distributed/tests/test_asyncprocess.py | 6 ++--- distributed/tests/test_client.py | 18 +++++++------- distributed/tests/test_diskutils.py | 2 +- distributed/tests/test_failed_workers.py | 2 +- distributed/tests/test_scheduler.py | 6 ++--- distributed/tests/test_security.py | 2 +- distributed/tests/test_semaphore.py | 4 ++-- distributed/tests/test_steal.py | 2 +- distributed/tests/test_utils.py | 4 ++-- distributed/tests/test_utils_comm.py | 4 ++-- distributed/utils.py | 24 +++++++++---------- distributed/utils_comm.py | 4 ++-- distributed/utils_test.py | 14 +++++------ distributed/versions.py | 9 +++---- distributed/worker.py | 18 +++++++------- docs/source/conf.py | 7 ++---- 59 files changed, 189 insertions(+), 212 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b9075a55c0..7b497037a1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,8 @@ repos: - id: black language_version: python3 exclude: versioneer.py + args: + - --target-version=py37 - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 hooks: diff --git a/distributed/actor.py b/distributed/actor.py index 77b2cda67de..2ebbba53a1c 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -69,7 +69,7 @@ def __init__(self, cls, address, key, worker=None): self._client = None def __repr__(self): - return "" % (self._cls.__name__, self.key) + return f"" def __reduce__(self): return (Actor, (self._cls, self._address, self.key)) diff --git 
a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index c6297eda5a3..acb4d04198f 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -130,7 +130,7 @@ def main( tls_cert, tls_key, dashboard_address, - **kwargs + **kwargs, ): g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 gc.set_threshold(g0 * 3, g1 * 3, g2 * 3) @@ -194,7 +194,7 @@ def del_pid_file(): dashboard=dashboard, dashboard_address=dashboard_address, http_prefix=dashboard_prefix, - **kwargs + **kwargs, ) logger.info("-" * 47) diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index f81cd73d495..8619949b588 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -175,10 +175,10 @@ def main( version=distributed.__version__ ) ) - print("Worker nodes: {n}".format(n=len(hostnames))) + print(f"Worker nodes: {len(hostnames)}") for i, host in enumerate(hostnames): - print(" {num}: {host}".format(num=i, host=host)) - print("\nscheduler node: {addr}:{port}".format(addr=scheduler, port=scheduler_port)) + print(f" {i}: {host}") + print(f"\nscheduler node: {scheduler}:{scheduler_port}") print("---------------------------------------------------------------\n\n") # Monitor the output of remote processes. This blocks until the user issues a KeyboardInterrupt. diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index cc004baf631..d297ef57923 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -265,7 +265,7 @@ def main( dashboard_address, worker_class, preload_nanny, - **kwargs + **kwargs, ): g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653 gc.set_threshold(g0 * 3, g1 * 3, g2 * 3) @@ -419,7 +419,7 @@ def del_pid_file(): name=name if nprocs == 1 or name is None or name == "" else str(name) + "-" + str(i), - **kwargs + **kwargs, ) for i in range(nprocs) ] diff --git a/distributed/client.py b/distributed/client.py index 51047dc34c3..6a81fd6da6d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -7,7 +7,6 @@ import json import logging import os -import socket import sys import threading import uuid @@ -399,9 +398,9 @@ def __repr__(self): typ = self.type.__module__.split(".")[0] + "." 
+ self.type.__name__ except AttributeError: typ = str(self.type) - return "" % (self.status, typ, self.key) + return f"" else: - return "" % (self.status, self.key) + return f"" def _repr_html_(self): text = "Future: %s " % html.escape(key_split(self.key)) @@ -494,7 +493,7 @@ async def wait(self, timeout=None): await asyncio.wait_for(self._get_event().wait(), timeout) def __repr__(self): - return "<%s: %s>" % (self.__class__.__name__, self.status) + return f"<{self.__class__.__name__}: {self.status}>" async def done_callback(future, callback): @@ -660,9 +659,7 @@ def __init__( logger.info("Config value `scheduler-address` found: %s", address) if address is not None and kwargs: - raise ValueError( - "Unexpected keyword arguments: {}".format(str(sorted(kwargs))) - ) + raise ValueError(f"Unexpected keyword arguments: {str(sorted(kwargs))}") if isinstance(address, (rpc, PooledRPCCall)): self.scheduler = address @@ -907,12 +904,12 @@ def __repr__(self): return text elif self.scheduler is not None: - return "<%s: scheduler=%r>" % ( + return "<{}: scheduler={!r}>".format( self.__class__.__name__, self.scheduler.address, ) else: - return "<%s: No scheduler connected>" % (self.__class__.__name__,) + return f"<{self.__class__.__name__}: No scheduler connected>" def _repr_html_(self): scheduler, info = self._get_scheduler_info() @@ -1073,7 +1070,7 @@ async def _start(self, timeout=no_default, **kwargs): asynchronous=self._asynchronous, **self._startup_kwargs, ) - except (OSError, socket.error) as e: + except OSError as e: if e.errno != errno.EADDRINUSE: raise # The default port was taken, use a random one @@ -1123,7 +1120,7 @@ async def _reconnect(self): try: await self._ensure_connected(timeout=timeout) break - except EnvironmentError: + except OSError: # Wait a bit before retrying await asyncio.sleep(0.1) timeout = deadline - self.loop.time() @@ -1203,7 +1200,7 @@ async def _update_scheduler_info(self): return try: self._scheduler_identity = SchedulerInfo(await self.scheduler.identity()) - except EnvironmentError: + except OSError: logger.debug("Not able to query scheduler for identity") async def _wait_for_workers(self, n_workers=0, timeout=None): diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py index 3612b284d0d..e51f2dfd71e 100644 --- a/distributed/comm/addressing.py +++ b/distributed/comm/addressing.py @@ -37,7 +37,7 @@ def unparse_address(scheme, loc): >>> unparse_address('tcp', '127.0.0.1') 'tcp://127.0.0.1' """ - return "%s://%s" % (scheme, loc) + return f"{scheme}://{loc}" def normalize_address(addr): @@ -60,11 +60,11 @@ def parse_host_port(address, default_port=None): return address def _fail(): - raise ValueError("invalid address %r" % (address,)) + raise ValueError(f"invalid address {address!r}") def _default(): if default_port is None: - raise ValueError("missing port number in address %r" % (address,)) + raise ValueError(f"missing port number in address {address!r}") return default_port if "://" in address: @@ -99,7 +99,7 @@ def unparse_host_port(host, port=None): if ":" in host and not host.startswith("["): host = "[%s]" % host if port is not None: - return "%s:%s" % (host, port) + return f"{host}:{port}" else: return host @@ -120,7 +120,7 @@ def get_address_host_port(addr, strict=False): return backend.get_address_host_port(loc) except NotImplementedError: raise ValueError( - "don't know how to extract host and port for address %r" % (addr,) + f"don't know how to extract host and port for address {addr!r}" ) diff --git a/distributed/comm/core.py 
b/distributed/comm/core.py index ccdcbb99c20..a80863155e4 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -157,9 +157,9 @@ def handshake_configuration(local, remote): def __repr__(self): clsname = self.__class__.__name__ if self.closed(): - return "" % (clsname,) + return f"" else: - return "<%s %s local=%s remote=%s>" % ( + return "<{} {} local={} remote={}>".format( clsname, self.name or "", self.local_address, @@ -307,7 +307,7 @@ def time_left(): ) await asyncio.sleep(backoff) else: - raise IOError( + raise OSError( f"Timed out trying to connect to {addr} after {timeout} s" ) from active_exception @@ -323,7 +323,7 @@ def time_left(): except Exception as exc: with suppress(Exception): await comm.close() - raise IOError( + raise OSError( f"Timed out during handshake while connecting to {addr} after {timeout} s" ) from exc diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 7374fba188d..bc812540a5e 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -43,7 +43,7 @@ def __init__(self): def add_listener(self, addr, listener): with self.lock: if addr in self.listeners: - raise RuntimeError("already listening on %r" % (addr,)) + raise RuntimeError(f"already listening on {addr!r}") self.listeners[addr] = listener def remove_listener(self, addr): @@ -170,7 +170,7 @@ def __init__( def _get_finalizer(self): def finalize(write_q=self._write_q, write_loop=self._write_loop, r=repr(self)): - logger.warning("Closing dangling queue in %s" % (r,)) + logger.warning(f"Closing dangling queue in {r}") write_loop.add_callback(write_q.put_nowait, _EOF) return finalize @@ -296,7 +296,7 @@ def __init__(self, manager): async def connect(self, address, deserialize=True, **connection_args): listener = self.manager.get_listener_for(address) if listener is None: - raise IOError("no endpoint for inproc address %r" % (address,)) + raise OSError(f"no endpoint for inproc address {address!r}") conn_req = ConnectionRequest( c2s_q=Queue(), diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 91e6af308e4..b938bd3752a 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -92,7 +92,7 @@ def set_tcp_timeout(comm): logger.debug("Setting TCP user timeout: %d ms", timeout * 1000) TCP_USER_TIMEOUT = 18 # since Linux 2.6.37 sock.setsockopt(socket.SOL_TCP, TCP_USER_TIMEOUT, timeout * 1000) - except EnvironmentError as e: + except OSError as e: logger.warning("Could not set timeout on TCP stream: %s", e) @@ -105,7 +105,7 @@ def get_stream_address(comm): try: return unparse_host_port(*comm.socket.getsockname()[:2]) - except EnvironmentError: + except OSError: # Probably EBADF return "" @@ -119,14 +119,10 @@ def convert_stream_closed_error(obj, exc): exc = exc.real_error if ssl and isinstance(exc, ssl.SSLError): if "UNKNOWN_CA" in exc.reason: - raise FatalCommClosedError( - "in %s: %s: %s" % (obj, exc.__class__.__name__, exc) - ) - raise CommClosedError( - "in %s: %s: %s" % (obj, exc.__class__.__name__, exc) - ) from exc + raise FatalCommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") + raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc else: - raise CommClosedError("in %s: %s" % (obj, exc)) from exc + raise CommClosedError(f"in {obj}: {exc}") from exc def _close_comm(ref): @@ -304,7 +300,7 @@ def close(self): if stream.writing(): yield stream.write(b"") stream.socket.shutdown(socket.SHUT_RDWR) - except EnvironmentError: + except OSError: pass finally: self._finalizer.detach() @@ -452,7 +448,7 @@ 
async def start(self): sockets = netutil.bind_sockets( self.port, address=self.ip, backlog=backlog ) - except EnvironmentError as e: + except OSError as e: # EADDRINUSE can happen sporadically when trying to bind # to an ephemeral port if self.port != 0 or e.errno != errno.EADDRINUSE: @@ -545,7 +541,7 @@ def _get_server_args(self, **connection_args): async def _prepare_stream(self, stream, address): try: await stream.wait_for_handshake() - except EnvironmentError as e: + except OSError as e: # The handshake went wrong, log and ignore logger.warning( "Listener on %r: TLS handshake failed with remote %r: %s", diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 277bb05916c..a9539e83621 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -862,7 +862,7 @@ async def connect(self, address, deserialize=True, **connection_args): return await super().connect(address, deserialize, **connection_args) else: self.failures += 1 - raise IOError() + raise OSError() class UnreliableBackend(TCPBackend): _connector_class = UnreliableConnector @@ -950,8 +950,8 @@ async def handle_comm(comm): listener = await listen(addr, handle_comm) listeners.append(listener) - assert len(set(l.listen_address for l in listeners)) == N - assert len(set(l.contact_address for l in listeners)) == N + assert len({l.listen_address for l in listeners}) == N + assert len({l.contact_address for l in listeners}) == N for listener in listeners: listener.stop() diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 5daaf7e8693..ecb6b471114 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -79,7 +79,7 @@ def test_ucx_specific(): # 3. Test peer_address # 4. Test cleanup async def f(): - address = "ucx://{}:{}".format(HOST, 0) + address = f"ucx://{HOST}:{0}" async def handle_comm(comm): msg = await comm.read() diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py index d9eeabbe3ed..09e71acbb06 100644 --- a/distributed/comm/tests/test_ucx_config.py +++ b/distributed/comm/tests/test_ucx_config.py @@ -76,7 +76,7 @@ def test_ucx_config_w_env_var(cleanup, loop, monkeypatch): dask.config.refresh() port = "13339" - sched_addr = "ucx://%s:%s" % (HOST, port) + sched_addr = f"ucx://{HOST}:{port}" with popen( ["dask-scheduler", "--no-dashboard", "--protocol", "ucx", "--port", port] diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 647ed8313a3..457a5c96f5c 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -373,7 +373,7 @@ def __init__( comm_handler: None, deserialize=False, allow_offload=True, - **connection_args + **connection_args, ): if not address.startswith("ucx"): address = "ucx://" + address @@ -525,7 +525,7 @@ def _scrub_ucx_config(): for k, v in options.items(): if k not in valid_ucx_vars: logger.debug( - "Key: %s with value: %s not a valid UCX configuration option" % (k, v) + f"Key: {k} with value: {v} not a valid UCX configuration option" ) return options diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 35f3c33ef3d..5301265caf5 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -87,7 +87,7 @@ def get_tcp_server_addresses(tcp_server): """ sockets = list(tcp_server._sockets.values()) if not sockets: - raise RuntimeError("TCP Server %r not started yet?" 
% (tcp_server,)) + raise RuntimeError(f"TCP Server {tcp_server!r} not started yet?") def _look_for_family(fam): socks = [] diff --git a/distributed/core.py b/distributed/core.py index 8369bab1f5f..b8c2ba46e20 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -428,7 +428,7 @@ async def handle_comm(self, comm): try: msg = await comm.read() logger.debug("Message from %r: %s", address, msg) - except EnvironmentError as e: + except OSError as e: if not sys.is_finalizing(): logger.debug( "Lost connection to %r while reading message: %s." @@ -517,7 +517,7 @@ async def handle_comm(self, comm): if reply and not is_dont_reply: try: await comm.write(result, serializers=serializers) - except (EnvironmentError, TypeError) as e: + except (OSError, TypeError) as e: logger.debug( "Lost connection to %r while sending result for op %r: %s", address, @@ -579,7 +579,7 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): else: func() - except (CommClosedError, EnvironmentError): + except (CommClosedError, OSError): # FIXME: This is silently ignored, is this intentional? pass except Exception as e: @@ -647,7 +647,7 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw response = await comm.read(deserializers=deserializers) else: response = None - except (EnvironmentError, CommClosedError): + except (OSError, CommClosedError): # On communication errors, we should simply close the communication force_close = True raise @@ -763,7 +763,7 @@ async def _close_comm(comm): if not comm.closed(): await comm.write({"op": "close", "reply": False}) await comm.close() - except EnvironmentError: + except OSError: comm.abort() tasks = [] @@ -792,9 +792,7 @@ async def send_recv_from_rpc(**kwargs): comm.name = "rpc." + key result = await send_recv(comm=comm, op=key, **kwargs) except (RPCClosed, CommClosedError) as e: - raise e.__class__( - "%s: while trying to call remote method %r" % (e, key) - ) + raise e.__class__(f"{e}: while trying to call remote method {key!r}") self.comms[comm] = True # mark as open return result @@ -881,7 +879,7 @@ def __exit__(self, *args): pass def __repr__(self): - return "" % (self.addr,) + return f"" class ConnectionPool: diff --git a/distributed/dashboard/components/nvml.py b/distributed/dashboard/components/nvml.py index b5c5547d9bd..cdb331016c6 100644 --- a/distributed/dashboard/components/nvml.py +++ b/distributed/dashboard/components/nvml.py @@ -49,7 +49,7 @@ def __init__(self, scheduler, width=600, **kwargs): id="bk-gpu-memory-worker-plot", width=int(width / 2), name="gpu_memory_histogram", - **kwargs + **kwargs, ) rect = memory.rect( source=self.source, @@ -67,7 +67,7 @@ def __init__(self, scheduler, width=600, **kwargs): id="bk-gpu-utilization-worker-plot", width=int(width / 2), name="gpu_utilization_histogram", - **kwargs + **kwargs, ) rect = utilization.rect( source=self.source, @@ -159,7 +159,7 @@ def update(self): "escaped_worker": [escape.url_escape(w) for w in worker], } - self.memory_figure.title.text = "GPU Memory: %s / %s" % ( + self.memory_figure.title.text = "GPU Memory: {} / {}".format( format_bytes(sum(memory)), format_bytes(memory_total), ) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index f5631f78989..fe72c4e98e0 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -164,9 +164,11 @@ def update(self): color.append("blue") if total: - self.root.title.text = "Occupancy -- total time: %s wall time: %s" 
% ( - format_time(total), - format_time(total / self.scheduler.total_nthreads), + self.root.title.text = ( + "Occupancy -- total time: {} wall time: {}".format( + format_time(total), + format_time(total / self.scheduler.total_nthreads), + ) ) else: self.root.title.text = "Occupancy" @@ -2640,9 +2642,7 @@ def update(self): for name in self.names + self.extra_names: if name == "name": - data[name].insert( - 0, "Total ({nworkers})".format(nworkers=len(data[name])) - ) + data[name].insert(0, f"Total ({len(data[name])})") continue try: if len(self.scheduler.workers) == 0: diff --git a/distributed/dashboard/components/shared.py b/distributed/dashboard/components/shared.py index 360f4433aa7..6ea83576c2e 100644 --- a/distributed/dashboard/components/shared.py +++ b/distributed/dashboard/components/shared.py @@ -57,7 +57,7 @@ def __init__(self, **kwargs): tools="", x_range=x_range, id="bk-processing-stacks-plot", - **kwargs + **kwargs, ) fig.quad( source=self.source, @@ -297,7 +297,7 @@ def select_cb(attr, old, new): ), self.profile_plot, self.ts_plot, - **kwargs + **kwargs, ) @without_property_validation @@ -434,7 +434,7 @@ def ts_change(attr, old, new): row(self.reset_button, self.update_button, sizing_mode="scale_width"), self.profile_plot, self.ts_plot, - **kwargs + **kwargs, ) @without_property_validation @@ -486,7 +486,7 @@ def __init__(self, worker, height=150, last_count=None, **kwargs): height=height, tools=tools, x_range=x_range, - **kwargs + **kwargs, ) self.cpu.line(source=self.source, x="time", y="cpu") self.cpu.yaxis.axis_label = "Percentage" @@ -508,7 +508,7 @@ def __init__(self, worker, height=150, last_count=None, **kwargs): height=height, tools=tools, x_range=x_range, - **kwargs + **kwargs, ) self.mem.line(source=self.source, x="time", y="memory") self.mem.yaxis.axis_label = "Bytes" @@ -530,7 +530,7 @@ def __init__(self, worker, height=150, last_count=None, **kwargs): height=height, x_range=x_range, tools=tools, - **kwargs + **kwargs, ) self.bandwidth.line(source=self.source, x="time", y="read_bytes", color="red") self.bandwidth.line(source=self.source, x="time", y="write_bytes", color="blue") @@ -549,7 +549,7 @@ def __init__(self, worker, height=150, last_count=None, **kwargs): height=height, x_range=x_range, tools=tools, - **kwargs + **kwargs, ) self.num_fds.line(source=self.source, x="time", y="num_fds") diff --git a/distributed/dashboard/components/worker.py b/distributed/dashboard/components/worker.py index e664cd4cae5..5ede5529afc 100644 --- a/distributed/dashboard/components/worker.py +++ b/distributed/dashboard/components/worker.py @@ -110,7 +110,7 @@ def __init__(self, worker, height=300, **kwargs): y_range=y_range, height=height, tools="", - **kwargs + **kwargs, ) fig.rect( @@ -178,7 +178,7 @@ def update(self): self.who[msg["who"]] = len(self.who) msg["y"] = self.who[msg["who"]] - msg["hover"] = "%s / %s = %s/s" % ( + msg["hover"] = "{} / {} = {}/s".format( format_bytes(msg["total"]), format_time(msg["duration"]), format_bytes(msg["total"] / msg["duration"]), @@ -212,7 +212,7 @@ def __init__(self, worker, **kwargs): height=150, tools="", x_range=x_range, - **kwargs + **kwargs, ) fig.line(source=self.source, x="x", y="in", color="red") fig.line(source=self.source, x="x", y="out", color="blue") @@ -250,7 +250,7 @@ def __init__(self, worker, **kwargs): height=150, tools="", x_range=x_range, - **kwargs + **kwargs, ) fig.line(source=self.source, x="x", y="y") @@ -440,7 +440,7 @@ def __init__(self, server, sizing_mode="stretch_both", **kwargs): row(*pair, 
sizing_mode=sizing_mode) for pair in partition_all(2, figures) ], - sizing_mode=sizing_mode + sizing_mode=sizing_mode, ) def add_digest_figure(self, name): diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 8ffb8555d83..2f966a3f578 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -86,7 +86,7 @@ def __init__( wait_count=None, target_duration=None, worker_key=None, - **kwargs + **kwargs, ): self.cluster = cluster self.worker_key = worker_key @@ -178,7 +178,7 @@ async def workers_to_close(self, target: int): target=target, key=pickle.dumps(self.worker_key) if self.worker_key else None, attribute="name", - **self._workers_to_close_kwargs + **self._workers_to_close_kwargs, ) async def scale_down(self, workers): diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index b077261dc1d..64c9dd03a61 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -171,7 +171,7 @@ async def recommendations(self, target: int) -> dict: not_yet_arrived = requested - observed to_close = set() if not_yet_arrived: - to_close.update((toolz.take(len(plan) - target, not_yet_arrived))) + to_close.update(toolz.take(len(plan) - target, not_yet_arrived)) if target < len(plan) - len(to_close): L = await self.workers_to_close(target=target) diff --git a/distributed/deploy/old_ssh.py b/distributed/deploy/old_ssh.py index 77b01e2388f..6d158e27f05 100644 --- a/distributed/deploy/old_ssh.py +++ b/distributed/deploy/old_ssh.py @@ -100,7 +100,7 @@ def async_ssh(cmd_dict): print( " " + bcolors.FAIL - + "Retrying... (attempt {n}/{total})".format(n=retries, total=3) + + f"Retrying... (attempt {retries}/{3})" + bcolors.ENDC ) @@ -152,7 +152,7 @@ def read_from_stderr(): cmd_dict["output_queue"].put( "[ {label} ] : ".format(label=cmd_dict["label"]) + bcolors.FAIL - + "{output}".format(output=line) + + f"{line}" + bcolors.ENDC ) line = stderr.readline() @@ -215,18 +215,14 @@ def start_scheduler( # Optionally re-direct stdout and stderr to a logfile if logdir is not None: - cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd + cmd = f"mkdir -p {logdir} && " + cmd cmd += "&> {logdir}/dask_scheduler_{addr}:{port}.log".format( addr=addr, port=port, logdir=logdir ) # Format output labels we can prepend to each line of output, and create # a 'status' key to keep track of jobs that terminate prematurely. - label = ( - bcolors.BOLD - + "scheduler {addr}:{port}".format(addr=addr, port=port) - + bcolors.ENDC - ) + label = bcolors.BOLD + f"scheduler {addr}:{port}" + bcolors.ENDC # Create a command dictionary, which contains everything we need to run and # interact with this command. @@ -309,12 +305,12 @@ def start_worker( # Optionally redirect stdout and stderr to a logfile if logdir is not None: - cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd + cmd = f"mkdir -p {logdir} && " + cmd cmd += "&> {logdir}/dask_scheduler_{addr}.log".format( addr=worker_addr, logdir=logdir ) - label = "worker {addr}".format(addr=worker_addr) + label = f"worker {worker_addr}" # Create a command dictionary, which contains everything we need to run and # interact with this command. 
diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 9c2f1292649..aa3dc3f7b84 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -97,7 +97,7 @@ async def finished(self): await self._event_finished.wait() def __repr__(self): - return "<%s: status=%s>" % (type(self).__name__, self.status) + return f"<{type(self).__name__}: status={self.status}>" async def __aenter__(self): await self diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 64452c31721..492c4ce05ef 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -43,7 +43,7 @@ async def close(self): await super().close() def __repr__(self): - return "" % (type(self).__name__, self.status) + return f"" class Worker(Process): diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index ee1b12913dc..13a3a16d928 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -558,7 +558,7 @@ async def test_bokeh_kwargs(cleanup): ) as c: client = AsyncHTTPClient() response = await client.fetch( - "http://localhost:{}/foo/status".format(c.scheduler.http_server.port) + f"http://localhost:{c.scheduler.http_server.port}/foo/status" ) assert "bokeh" in response.body.decode() diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 28500fd8077..40e67471259 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -229,11 +229,11 @@ def format_time(t): m, s = divmod(t, 60) h, m = divmod(m, 60) if h: - return "{0:2.0f}hr {1:2.0f}min {2:4.1f}s".format(h, m, s) + return f"{h:2.0f}hr {m:2.0f}min {s:4.1f}s" elif m: - return "{0:2.0f}min {1:4.1f}s".format(m, s) + return f"{m:2.0f}min {s:4.1f}s" else: - return "{0:4.1f}s".format(s) + return f"{s:4.1f}s" class AllProgress(SchedulerPlugin): diff --git a/distributed/diskutils.py b/distributed/diskutils.py index fe393872531..8d35d6a1e73 100644 --- a/distributed/diskutils.py +++ b/distributed/diskutils.py @@ -23,7 +23,7 @@ def is_locking_enabled(): def safe_unlink(path): try: os.unlink(path) - except EnvironmentError as e: + except OSError as e: # Perhaps it was removed by someone else? 
if e.errno != errno.ENOENT: logger.error("Failed to remove %r", str(e)) @@ -121,7 +121,7 @@ def __init__(self, base_dir): def _init_workspace(self): try: os.mkdir(self.base_dir) - except EnvironmentError as e: + except OSError as e: if e.errno != errno.EEXIST: raise @@ -174,7 +174,7 @@ def _list_unknown_locks(self): for p in glob.glob(os.path.join(self.base_dir, "*" + DIR_LOCK_EXT)): try: st = os.stat(p) - except EnvironmentError: + except OSError: # May have been removed in the meantime pass else: diff --git a/distributed/http/proxy.py b/distributed/http/proxy.py index 6e39a999990..73e1f3d42a8 100644 --- a/distributed/http/proxy.py +++ b/distributed/http/proxy.py @@ -24,13 +24,13 @@ async def http_get(self, port, host, proxied_path): self.host = host # rewrite uri for jupyter-server-proxy handling - uri = "/proxy/%s/%s" % (str(port), proxied_path) + uri = f"/proxy/{str(port)}/{proxied_path}" self.request.uri = uri # slash is removed during regex in handler proxied_path = "/%s" % proxied_path - worker = "%s:%s" % (self.host, str(port)) + worker = f"{self.host}:{str(port)}" if not check_worker_dashboard_exits(self.scheduler, worker): msg = "Worker <%s> does not exist" % worker self.set_status(400) @@ -81,9 +81,9 @@ def initialize(self, dask_server=None, extra=None): self.extra = extra or {} def get(self, port, host, proxied_path): - worker_url = "%s:%s/%s" % (host, str(port), proxied_path) + worker_url = f"{host}:{str(port)}/{proxied_path}" msg = """ -

        <p> Try navigating to <a href=http://%s>%s</a> for your worker dashboard </p>
+        <p> Try navigating to <a href=http://{}>{}</a> for your worker dashboard </p>

        Dask tried to proxy you to that page through your
@@ -101,7 +101,7 @@ def get(self, port, host, proxied_path):
        but less common in production clusters. Your IT administrators will know more

        - """ % ( + """.format( worker_url, worker_url, ) diff --git a/distributed/http/tests/test_core.py b/distributed/http/tests/test_core.py index 61cb713fcf2..ea3a313525b 100644 --- a/distributed/http/tests/test_core.py +++ b/distributed/http/tests/test_core.py @@ -6,7 +6,5 @@ @gen_cluster(client=True) async def test_scheduler(c, s, a, b): client = AsyncHTTPClient() - response = await client.fetch( - "http://localhost:{}/health".format(s.http_server.port) - ) + response = await client.fetch(f"http://localhost:{s.http_server.port}/health") assert response.code == 200 diff --git a/distributed/locket.py b/distributed/locket.py index 906938e6085..bb383345b9a 100644 --- a/distributed/locket.py +++ b/distributed/locket.py @@ -63,7 +63,7 @@ def _lock_file_non_blocking(file_): try: fcntl.flock(file_.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) return True - except IOError as error: + except OSError as error: if error.errno in [errno.EACCES, errno.EAGAIN]: return False else: @@ -109,7 +109,7 @@ def _acquire_non_blocking(acquire, timeout, retry_period, path): if success: return elif timeout is not None and time.time() - start_time > timeout: - raise LockError("Couldn't lock {0}".format(path)) + raise LockError(f"Couldn't lock {path}") else: time.sleep(retry_period) diff --git a/distributed/nanny.py b/distributed/nanny.py index bcff9eee69c..9cc4c30c6e2 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -451,7 +451,7 @@ async def _on_exit(self, exitcode): ): try: await self._unregister() - except (EnvironmentError, CommClosedError): + except (OSError, CommClosedError): if not self.reconnect: await self.close() return diff --git a/distributed/node.py b/distributed/node.py index e21713dc85d..a7f9b8d31ac 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -63,7 +63,7 @@ def start_services(self, default_listen_ip): self.services[k] = service except Exception as e: warnings.warn( - "\nCould not launch service '%s' on port %s. " % (k, port) + f"\nCould not launch service '{k}' on port {port}. 
" + "Got the following message:\n\n" + str(e), stacklevel=3, diff --git a/distributed/process.py b/distributed/process.py index 1540bf3752c..9be72b07566 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -52,7 +52,7 @@ class AsyncProcess: def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}): if not callable(target): - raise TypeError("`target` needs to be callable, not %r" % (type(target),)) + raise TypeError(f"`target` needs to be callable, not {type(target)!r}") self._state = _ProcessState() self._loop = loop or IOLoop.current(instance=False) @@ -91,7 +91,7 @@ def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}): self._start_threads() def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, self._name) + return f"<{self.__class__.__name__} {self._name}>" def _check_closed(self): if self._closed: @@ -211,11 +211,11 @@ def _start(): state.is_alive = True state.pid = process.pid - logger.debug("[%s] created process with pid %r" % (r, state.pid)) + logger.debug(f"[{r}] created process with pid {state.pid!r}") while True: msg = q.get() - logger.debug("[%s] got message %r" % (r, msg)) + logger.debug(f"[{r}] got message {msg!r}") op = msg["op"] if op == "start": _call_and_set_future(loop, msg["future"], _start) @@ -338,7 +338,7 @@ def daemon(self, value): def _asyncprocess_finalizer(proc): if proc.is_alive(): try: - logger.info("reaping stray process %s" % (proc,)) + logger.info(f"reaping stray process {proc}") proc.terminate() except OSError: pass diff --git a/distributed/profile.py b/distributed/profile.py index 160fe5a7b62..958e342754a 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -59,7 +59,7 @@ def identifier(frame): def repr_frame(frame): """Render a frame as a line for inclusion into a text traceback""" co = frame.f_code - text = ' File "%s", line %s, in %s' % (co.co_filename, frame.f_lineno, co.co_name) + text = f' File "{co.co_filename}", line {frame.f_lineno}, in {co.co_name}' line = linecache.getline(co.co_filename, frame.f_lineno, frame.f_globals).lstrip() return text + "\n\t" + line @@ -230,7 +230,7 @@ def traverse(state, start, stop, height): x += width traverse(state, 0, 1, 0) - percentages = ["{:.1f}%".format(100 * w) for w in widths] + percentages = [f"{100 * w:.1f}%" for w in widths] return { "left": starts, "right": stops, diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 871e7e4df56..c4bd909ecb3 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -87,7 +87,7 @@ def _decode_default(obj): frames[offset], object_hook=msgpack_decode_default, use_list=False, - **msgpack_opts + **msgpack_opts, ) offset += 1 sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] diff --git a/distributed/publish.py b/distributed/publish.py index 485b874d5f3..85150eecdb7 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -96,8 +96,7 @@ def __iter__(self): "Can't invoke iter() or 'for' on client.datasets when client is " "asynchronous; use 'async for' instead" ) - for key in self._client.list_datasets(): - yield key + yield from self._client.list_datasets() def __aiter__(self): if not self._client.asynchronous: diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 91b006423ba..20822e145ca 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -350,7 +350,7 @@ def put(self, msg): self.loop.add_callback(self._put, msg) def __repr__(self): - return "".format(self.name) + return f"" __str__ = __repr__ @@ 
-462,6 +462,6 @@ async def _put(self, msg): self.condition.notify() def __repr__(self): - return "".format(self.name) + return f"" __str__ = __repr__ diff --git a/distributed/pytest_resourceleaks.py b/distributed/pytest_resourceleaks.py index 185f649761c..55d94762f4c 100644 --- a/distributed/pytest_resourceleaks.py +++ b/distributed/pytest_resourceleaks.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ A pytest plugin to trace resource leaks. """ @@ -67,7 +66,7 @@ def pytest_configure(config): leaks = leaks.split(",") unknown = sorted(set(leaks) - set(all_checkers)) if unknown: - raise ValueError("unknown resources: %r" % (unknown,)) + raise ValueError(f"unknown resources: {unknown!r}") checkers = [all_checkers[leak]() for leak in leaks] checker = LeakChecker( @@ -389,7 +388,7 @@ def pytest_runtest_protocol(self, item, nextitem): unknown = sorted(set(leaking.args) - set(all_checkers)) if unknown: raise ValueError( - "pytest.mark.leaking: unknown resources %r" % (unknown,) + f"pytest.mark.leaking: unknown resources {unknown!r}" ) classes = tuple(all_checkers[a] for a in leaking.args) self.skip_checkers[nodeid] = { @@ -428,7 +427,7 @@ def pytest_report_teststatus(self, report): report.outcome = "failed" report.longrepr = "\n".join( [ - "%s %s" % (nodeid, checker.format(before, after)) + f"{nodeid} {checker.format(before, after)}" for checker, before, after in leaks ] ) @@ -447,4 +446,4 @@ def pytest_terminal_summary(self, terminalreporter, exitstatus): for rep in leaked: nodeid = rep.nodeid for checker, before, after in self.leaks[nodeid]: - tr.line("%s %s" % (rep.nodeid, checker.format(before, after))) + tr.line(f"{rep.nodeid} {checker.format(before, after)}") diff --git a/distributed/queues.py b/distributed/queues.py index 481f497373c..5c81d25b848 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -46,7 +46,7 @@ def __init__(self, scheduler): self.scheduler.extensions["queues"] = self def create(self, comm=None, name=None, client=None, maxsize=0): - logger.debug("Queue name: {}".format(name)) + logger.debug(f"Queue name: {name}") if name not in self.queues: self.queues[name] = asyncio.Queue(maxsize=maxsize) self.client_refcount[name] = 1 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index df7a43510b8..243f86d577d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1546,7 +1546,7 @@ def set_nbytes(self, nbytes: Py_ssize_t): self._nbytes = nbytes def __repr__(self): - return "" % (self._key, self._state) + return f"" def _repr_html_(self): color = ( @@ -1606,7 +1606,7 @@ def __getitem__(self, key): return self._accessor(self._states[key]) def __repr__(self): - return "%s(%s)" % (self.__class__, dict(self)) + return f"{self.__class__}({dict(self)})" class _OptionalStateLegacyMapping(_StateLegacyMapping): @@ -1658,7 +1658,7 @@ def __contains__(self, k): return st is not None and bool(self._accessor(st)) def __repr__(self): - return "%s(%s)" % (self.__class__, set(self)) + return f"{self.__class__}({set(self)})" def _legacy_task_key_set(tasks): @@ -3746,7 +3746,7 @@ def del_scheduler_file(): self.start_periodic_callbacks() - setproctitle("dask-scheduler [%s]" % (self.address,)) + setproctitle(f"dask-scheduler [{self.address}]") return self async def close(self, comm=None, fast=False, close_workers=False): @@ -6420,7 +6420,7 @@ async def feed( response = function(self, state) await comm.write(response) await asyncio.sleep(interval) - except (EnvironmentError, CommClosedError): + except (OSError, CommClosedError): pass finally: if teardown: @@ 
-6759,7 +6759,7 @@ def coerce_address(self, addr, resolve=True): if isinstance(addr, tuple): addr = unparse_host_port(*addr) if not isinstance(addr, str): - raise TypeError("addresses should be strings or tuples, got %r" % (addr,)) + raise TypeError(f"addresses should be strings or tuples, got {addr!r}") if resolve: addr = resolve_address(addr) diff --git a/distributed/security.py b/distributed/security.py index d6a211571f1..4078590cf6b 100644 --- a/distributed/security.py +++ b/distributed/security.py @@ -164,14 +164,14 @@ def __repr__(self): items.append((k, "...")) else: items.append((k, repr(val))) - return "Security(" + ", ".join("%s=%s" % (k, v) for k, v in items) + ")" + return "Security(" + ", ".join(f"{k}={v}" for k, v in items) + ")" def get_tls_config_for_role(self, role): """ Return the TLS configuration for the given role, as a flat dict. """ if role not in {"client", "scheduler", "worker"}: - raise ValueError("unknown role %r" % (role,)) + raise ValueError(f"unknown role {role!r}") return { "ca_file": self.tls_ca_file, "ciphers": self.tls_ciphers, diff --git a/distributed/tests/make_tls_certs.py b/distributed/tests/make_tls_certs.py index 7286b780449..ac4616d7c8f 100644 --- a/distributed/tests/make_tls_certs.py +++ b/distributed/tests/make_tls_certs.py @@ -120,9 +120,9 @@ def make_cert_key(hostname, sign=False): ] subprocess.check_call(["openssl"] + args) - with open(cert_file, "r") as f: + with open(cert_file) as f: cert = f.read() - with open(key_file, "r") as f: + with open(key_file) as f: key = f.read() return cert, key finally: @@ -203,7 +203,7 @@ def make_ca(): # For certificate matching tests make_ca() - with open("tls-ca-cert.pem", "r") as f: + with open("tls-ca-cert.pem") as f: ca_cert = f.read() cert, key = make_cert_key("localhost", sign=True) diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 26c695933bb..695f23cc4aa 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -406,7 +406,7 @@ def test_asyncprocess_child_teardown_on_parent_exit(): # test failure. try: readable = children_alive.poll(short_timeout) - except EnvironmentError: + except OSError: # Windows can raise BrokenPipeError. EnvironmentError is caught for # Python2/3 portability. assert sys.platform.startswith("win"), "should only raise on windows" @@ -423,7 +423,7 @@ def test_asyncprocess_child_teardown_on_parent_exit(): result = children_alive.recv() except EOFError: pass # Test passes. - except EnvironmentError: + except OSError: # Windows can raise BrokenPipeError. EnvironmentError is caught for # Python2/3 portability. assert sys.platform.startswith("win"), "should only raise on windows" @@ -432,7 +432,7 @@ def test_asyncprocess_child_teardown_on_parent_exit(): # Oops, children_alive read something. It should be closed. If # something was read, it's a message from the child telling us they # are still alive! - raise RuntimeError("unreachable: {}".format(result)) + raise RuntimeError(f"unreachable: {result}") finally: # Cleanup. 
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 93314062a70..d40b2548ff1 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1549,7 +1549,7 @@ def g(): with save_sys_modules(): for value in [123, 456]: - with tmp_text("myfile.py", "def f():\n return {}".format(value)) as fn: + with tmp_text("myfile.py", f"def f():\n return {value}") as fn: await c.upload_file(fn) x = c.submit(g, pure=False) @@ -1561,7 +1561,7 @@ def g(): async def test_upload_file_refresh_delayed(c, s, a, b): with save_sys_modules(): for value in [123, 456]: - with tmp_text("myfile.py", "def f():\n return {}".format(value)) as fn: + with tmp_text("myfile.py", f"def f():\n return {value}") as fn: await c.upload_file(fn) sys.path.append(os.path.dirname(fn)) @@ -1590,7 +1590,7 @@ def g(): try: for value in [123, 456]: with tmp_text( - "myfile.py", "def f():\n return {}".format(value) + "myfile.py", f"def f():\n return {value}" ) as fn_my_file: with zipfile.ZipFile("myfile.zip", "w") as z: z.write(fn_my_file, arcname=os.path.basename(fn_my_file)) @@ -1635,13 +1635,13 @@ def g(): package_1 = os.path.join(dirname, "package_1") os.mkdir(package_1) with open(os.path.join(package_1, "__init__.py"), "w") as f: - f.write("a = {}\n".format(value)) + f.write(f"a = {value}\n") # test multiple top-level packages package_2 = os.path.join(dirname, "package_2") os.mkdir(package_2) with open(os.path.join(package_2, "__init__.py"), "w") as f: - f.write("b = {}\n".format(value)) + f.write(f"b = {value}\n") # compile these into an egg subprocess.check_call( @@ -1887,12 +1887,12 @@ async def test_allow_restrictions(c, s, a, b): def test_bad_address(): try: Client("123.123.123.123:1234", timeout=0.1) - except (IOError, TimeoutError) as e: + except (OSError, TimeoutError) as e: assert "connect" in str(e).lower() try: Client("127.0.0.1:1234", timeout=0.1) - except (IOError, TimeoutError) as e: + except (OSError, TimeoutError) as e: assert "connect" in str(e).lower() @@ -4637,7 +4637,7 @@ async def test_client_timeout(): await asyncio.sleep(4) try: await s - except EnvironmentError: # port in use + except OSError: # port in use await c.close() return @@ -5155,7 +5155,7 @@ def test_get_client_no_cluster(): Worker._instances.clear() msg = "No global client found and no address provided" - with pytest.raises(ValueError, match=r"^{}$".format(msg)): + with pytest.raises(ValueError, match=fr"^{msg}$"): get_client() diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index 8a9d99e5844..a62fe1ca4e4 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -162,7 +162,7 @@ def test_workspace_rmtree_failure(tmpdir): # shutil.rmtree() may call its onerror callback several times assert lines for line in lines: - assert line.startswith("Failed to remove %r" % (a.dir_path,)) + assert line.startswith(f"Failed to remove {a.dir_path!r}") def test_locking_disabled(tmpdir): diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index fb3922a8b99..da23d0c48ed 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -344,7 +344,7 @@ async def test_broken_worker_during_computation(c, s, a, b): L = c.map( slowadd, *zip(*partition_all(2, L)), - key=["add-%d-%d" % (i, j) for j in range(len(L) // 2)] + key=["add-%d-%d" % (i, j) for j in range(len(L) // 2)], ) await asyncio.sleep(random.random() / 20) diff --git 
a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 418ab7e7f5e..88b3f2de4d8 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -786,7 +786,7 @@ async def test_workers_to_close_grouped(c, s, *workers): def key(ws): return groups[ws.address] - assert set(s.workers_to_close(key=key)) == set(w.address for w in workers) + assert set(s.workers_to_close(key=key)) == {w.address for w in workers} # Assert that job in one worker blocks closure of group future = c.submit(slowinc, 1, delay=0.2, workers=workers[0].address) @@ -2016,10 +2016,10 @@ def abort(self): pass def read(self, deserializers=None): - raise EnvironmentError + raise OSError def write(self, msg, serializers=None, on_error=None): - raise EnvironmentError + raise OSError class FlakyConnectionPool(ConnectionPool): diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 28702650a49..305b8983752 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -395,7 +395,7 @@ def test_temporary_credentials(): sec_repr = repr(sec) fields = ["tls_ca_file"] fields.extend( - "tls_%s_%s" % (role, kind) + f"tls_{role}_{kind}" for role in ["client", "scheduler", "worker"] for kind in ["key", "cert"] ) diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 5a5b9d02597..47678696651 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -283,10 +283,10 @@ def abort(self): pass def read(self, deserializers=None): - raise EnvironmentError + raise OSError def write(self, msg, serializers=None, on_error=None): - raise EnvironmentError + raise OSError class FlakyConnectionPool(ConnectionPool): diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index fbabd2a6086..03ee701fc50 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -553,7 +553,7 @@ async def assert_balanced(inp, expected, c, s, *workers): if result2 == expected2: return - raise Exception("Expected: {}; got: {}".format(str(expected2), str(result2))) + raise Exception(f"Expected: {str(expected2)}; got: {str(result2)}") @pytest.mark.parametrize( diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 6f94ca3d506..be0e9847f78 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -146,10 +146,10 @@ def test_get_ip_interface(): elif sys.platform.startswith("linux"): assert get_ip_interface("lo") == "127.0.0.1" else: - pytest.skip("test needs to be enhanced for platform %r" % (sys.platform,)) + pytest.skip(f"test needs to be enhanced for platform {sys.platform!r}") non_existent_interface = "__non-existent-interface" - expected_error_message = "{!r}.+network interface.+".format(non_existent_interface) + expected_error_message = f"{non_existent_interface!r}.+network interface.+" if sys.platform == "darwin": expected_error_message += "'lo0'" diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index d17e892ebf0..ff0d6b09da6 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -56,10 +56,10 @@ def abort(self): pass def read(self, deserializers=None): - raise EnvironmentError + raise OSError def write(self, msg, serializers=None, on_error=None): - raise EnvironmentError + raise OSError class BrokenConnectionPool(ConnectionPool): diff --git a/distributed/utils.py b/distributed/utils.py index 
dcea26e628a..c691f3f8804 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -134,7 +134,7 @@ def _get_ip(host, port, family): sock.connect((host, port)) ip = sock.getsockname()[0] return ip - except EnvironmentError as e: + except OSError as e: warnings.warn( "Couldn't detect a suitable IP address for " "reaching %r, defaulting to hostname: %s" % (host, e), @@ -187,7 +187,7 @@ def get_ip_interface(ifname): for info in net_if_addrs[ifname]: if info.family == socket.AF_INET: return info.address - raise ValueError("interface %r doesn't have an IPv4 address" % (ifname,)) + raise ValueError(f"interface {ifname!r} doesn't have an IPv4 address") async def All(args, quiet_exceptions=()): @@ -317,7 +317,7 @@ def f(): loop.add_callback(f) if callback_timeout is not None: if not e.wait(callback_timeout): - raise TimeoutError("timed out after %s s." % (callback_timeout,)) + raise TimeoutError(f"timed out after {callback_timeout} s.") else: while not e.is_set(): e.wait(10) @@ -726,7 +726,7 @@ def validate_key(k): """Validate a key as received on a stream.""" typ = type(k) if typ is not str and typ is not bytes: - raise TypeError("Unexpected key type %s (value: %r)" % (typ, k)) + raise TypeError(f"Unexpected key type {typ} (value: {k!r})") def _maybe_complex(task): @@ -1085,13 +1085,11 @@ def command_has_keyword(cmd, k): if isinstance(getattr(cmd, "main"), click.core.Command): cmd = cmd.main if isinstance(cmd, click.core.Command): - cmd_params = set( - [ - p.human_readable_name - for p in cmd.params - if isinstance(p, click.core.Option) - ] - ) + cmd_params = { + p.human_readable_name + for p in cmd.params + if isinstance(p, click.core.Option) + } return k in cmd_params return False @@ -1308,11 +1306,11 @@ def cli_keywords(d: dict, cls=None, cmd=None): ) elif cls: raise ValueError( - "Class %s does not support keyword %s" % (typename(cls), k) + f"Class {typename(cls)} does not support keyword {k}" ) else: raise ValueError( - "Module %s does not support keyword %s" % (typename(cmd), k) + f"Module {typename(cmd)} does not support keyword {k}" ) def convert_value(v): diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 80e6e0b8ae4..728b4b4c144 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -77,7 +77,7 @@ async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=No for worker, c in coroutines.items(): try: r = await c - except EnvironmentError: + except OSError: missing_workers.add(worker) except ValueError as e: logger.info( @@ -112,7 +112,7 @@ def __init__(self, key): self.key = key def __repr__(self): - return "%s('%s')" % (type(self).__name__, self.key) + return f"{type(self).__name__}('{self.key}')" _round_robin_counter = [0] diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 0779d7de285..6f810354681 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -981,7 +981,7 @@ def get_unclosed(): if getattr(w, "data", None): try: w.data.clear() - except EnvironmentError: + except OSError: # zict backends can fail if their storage directory # was already removed pass @@ -1075,10 +1075,10 @@ def wait_for_port(address, timeout=5): while True: timeout = deadline - time() if timeout < 0: - raise RuntimeError("Failed to connect to %s" % (address,)) + raise RuntimeError(f"Failed to connect to {address}") try: sock = socket.create_connection(address, timeout=timeout) - except EnvironmentError: + except OSError: pass else: sock.close() @@ -1092,7 +1092,7 @@ def wait_for(predicate, timeout, 
fail_func=None, period=0.001): if time() > deadline: if fail_func is not None: fail_func() - pytest.fail("condition not reached until %s seconds" % (timeout,)) + pytest.fail(f"condition not reached until {timeout} seconds") async def async_wait_for(predicate, timeout, fail_func=None, period=0.001): @@ -1102,7 +1102,7 @@ async def async_wait_for(predicate, timeout, fail_func=None, period=0.001): if time() > deadline: if fail_func is not None: fail_func() - pytest.fail("condition not reached until %s seconds" % (timeout,)) + pytest.fail(f"condition not reached until {timeout} seconds") @memoize @@ -1120,7 +1120,7 @@ def has_ipv6(): serv.bind(("::", 0)) serv.listen(5) cli = socket.create_connection(serv.getsockname()[:2]) - except EnvironmentError: + except OSError: return False else: return True @@ -1422,7 +1422,7 @@ def bump_rlimit(limit, desired): if soft < desired: resource.setrlimit(limit, (desired, max(hard, desired))) except Exception as e: - pytest.skip("rlimit too low (%s) and can't be increased: %s" % (soft, e)) + pytest.skip(f"rlimit too low ({soft}) and can't be increased: {e}") def gen_tls_cluster(**kwargs): diff --git a/distributed/versions.py b/distributed/versions.py index b8d0a49a80f..13a282977bc 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -1,6 +1,5 @@ """ utilities for package version introspection """ -from __future__ import absolute_import, division, print_function import importlib import os @@ -25,9 +24,7 @@ # only these scheduler packages will be checked for version mismatch -scheduler_relevant_packages = set(pkg for pkg, _ in required_packages) | set( - ["lz4", "blosc"] -) +scheduler_relevant_packages = {pkg for pkg, _ in required_packages} | {"lz4", "blosc"} # notes to be displayed for mismatch packages @@ -135,12 +132,12 @@ def error_message(scheduler, workers, client, client_name="client"): ) versions.add(client_version) - worker_versions = set( + worker_versions = { workers[w].get(pkg, "MISSING") if isinstance(workers[w], dict) else workers[w] for w in workers - ) + } versions |= worker_versions if len(versions) <= 1: diff --git a/distributed/worker.py b/distributed/worker.py index 1f04b4cbde5..fd0d32c8885 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -186,7 +186,7 @@ def __init__(self, key, runspec=None): self.scheduler_holds_ref = False def __repr__(self): - return "" % (self.key, self.state) + return f"" def get_nbytes(self) -> int: nbytes = self.nbytes @@ -931,13 +931,13 @@ async def _register_with_scheduler(self): self.scheduler_delay = response["time"] - middle self.status = Status.running break - except EnvironmentError: + except OSError: logger.info("Waiting to connect to: %26s", self.scheduler.address) await asyncio.sleep(0.1) except TimeoutError: logger.info("Timed out when connecting to scheduler") if response["status"] != "OK": - raise ValueError("Unexpected response from register: %r" % (response,)) + raise ValueError(f"Unexpected response from register: {response!r}") else: await asyncio.gather( *[ @@ -1005,7 +1005,7 @@ async def heartbeat(self): logger.warning("Heartbeat to scheduler failed") if not self.reconnect: await self.close(report=False) - except IOError as e: + except OSError as e: # Scheduler is gone. 
Respect distributed.comm.timeouts.connect if "Timed out trying to connect" in str(e): await self.close(report=False) @@ -1187,12 +1187,12 @@ async def start(self): try: listening_address = "%s%s:%d" % (self.listener.prefix, self.ip, self.port) except Exception: - listening_address = "%s%s" % (self.listener.prefix, self.ip) + listening_address = f"{self.listener.prefix}{self.ip}" logger.info(" Start worker at: %26s", self.address) logger.info(" Listening to: %26s", listening_address) for k, v in self.service_ports.items(): - logger.info(" %16s at: %26s" % (k, self.ip + ":" + str(v))) + logger.info(" {:>16} at: {:>26}".format(k, self.ip + ":" + str(v))) logger.info("Waiting to connect to: %26s", self.scheduler.address) logger.info("-" * 49) logger.info(" Threads: %26d", self.nthreads) @@ -1427,7 +1427,7 @@ async def get_data( compressed = await comm.write(msg, serializers=serializers) response = await comm.read(deserializers=serializers) assert response == "OK", response - except EnvironmentError: + except OSError: logger.exception( "failed during get data with %s -> %s", self.address, who, exc_info=True ) @@ -1999,7 +1999,7 @@ def transition_executing_done(self, ts, value=no_value, report=True): return out - except EnvironmentError: + except OSError: logger.info("Comm closed") except Exception as e: logger.exception(e) @@ -2387,7 +2387,7 @@ async def gather_dep( self.incoming_count += 1 self.log.append(("receive-dep", worker, list(response["data"]))) - except EnvironmentError: + except OSError: logger.exception("Worker stream died during communication: %s", worker) has_what = self.has_what.pop(worker) self.pending_data_per_worker.pop(worker) diff --git a/docs/source/conf.py b/docs/source/conf.py index 201cd76b00b..ac7e07fdff9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Dask.distributed documentation build configuration file, created by # sphinx-quickstart on Tue Oct 6 14:42:44 2015. @@ -471,16 +470,14 @@ def run(self): if "methods" in self.options: _, methods = self.get_members(app, c, ["method"], ["__init__"]) self.content = [ - "%s.%s" % (class_name, method) + f"{class_name}.{method}" for method in methods if not method.startswith("_") ] if "attributes" in self.options: _, attribs = self.get_members(app, c, ["attribute", "property"]) self.content = [ - "~%s.%s" % (clazz, attrib) - for attrib in attribs - if not attrib.startswith("_") + f"~{clazz}.{attrib}" for attrib in attribs if not attrib.startswith("_") ] return super().run() From 5688e503d9411e69d0a8278d2b9692d453814c73 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 29 Jun 2021 16:08:27 -0500 Subject: [PATCH 1344/1550] Rename plot dropdown (#4992) --- distributed/http/templates/base.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/http/templates/base.html b/distributed/http/templates/base.html index b0d428fac37..73e252f08e1 100644 --- a/distributed/http/templates/base.html +++ b/distributed/http/templates/base.html @@ -34,7 +34,7 @@ {% endfor %}