From 92441f045669330f7be7f3ca718243be7a87803c Mon Sep 17 00:00:00 2001
From: Jim Crist
Date: Fri, 15 Nov 2019 11:56:48 -0600
Subject: [PATCH] Remove `gen.coroutine` usage in scheduler

Use `async`/`await` and `asyncio` idioms throughout:

- `@gen.coroutine` methods become `async def` methods.
- `raise gen.Return(value)` becomes a plain `return value`.
- Yielding a dict of futures (a Tornado-only idiom that resolved to a
  dict of results) becomes `await asyncio.gather(...)`, which returns a
  list; call sites that still need the mapping zip the original keys
  against the results.
---
 distributed/scheduler.py | 115 ++++++++++++++++++++-------------------
 1 file changed, 58 insertions(+), 57 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 6724524344..9b0a84e8e8 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -24,7 +24,6 @@
 from toolz import frequencies, merge, pluck, merge_sorted, first
 from toolz import valmap, second, compose, groupby
 from tornado import gen
-from tornado.gen import Return
 from tornado.ioloop import IOLoop
 
 import dask
@@ -2008,7 +2007,8 @@ def cancel_key(self, key, client, retries=5, force=False):
         if ts is None or not ts.who_wants:  # no key yet, lets try again in a moment
             if retries:
                 self.loop.add_future(
-                    gen.sleep(0.2), lambda _: self.cancel_key(key, client, retries - 1)
+                    asyncio.ensure_future(asyncio.sleep(0.2)),
+                    lambda _: self.cancel_key(key, client, retries - 1),
                 )
             return
         if force or ts.who_wants == {cs}:  # no one else wants this key
@@ -2700,8 +2700,7 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None):
         )
         return d[worker]
 
-    @gen.coroutine
-    def rebalance(self, comm=None, keys=None, workers=None):
+    async def rebalance(self, comm=None, keys=None, workers=None):
         """ Rebalance keys so that each worker stores roughly equal bytes
 
         **Policy**
@@ -2777,9 +2776,9 @@ def rebalance(self, comm=None, keys=None, workers=None):
                 to_recipients[recipient.address][ts.key].append(sender.address)
                 to_senders[sender.address].append(ts.key)
 
-            result = yield {
-                r: self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items()
-            }
+            result = await asyncio.gather(
+                *(self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items())
+            )
 
             for r, v in to_recipients.items():
                 self.log_event(r, {"action": "rebalance", "who_has": v})
@@ -2794,13 +2793,11 @@ def rebalance(self, comm=None, keys=None, workers=None):
                 },
             )
 
-            if not all(r["status"] == "OK" for r in result.values()):
-                raise Return(
-                    {
-                        "status": "missing-data",
-                        "keys": sum([r["keys"] for r in result if "keys" in r], []),
-                    }
-                )
+            if not all(r["status"] == "OK" for r in result):
+                return {
+                    "status": "missing-data",
+                    "keys": sum([r["keys"] for r in result if "keys" in r], []),
+                }
 
             for sender, recipient, ts in msgs:
                 assert ts.state == "memory"
@@ -2811,20 +2808,21 @@ def rebalance(self, comm=None, keys=None, workers=None):
                     ("rebalance", ts.key, time(), sender.address, recipient.address)
                 )
 
-            result = yield {
-                r: self.rpc(addr=r).delete_data(keys=v, report=False)
-                for r, v in to_senders.items()
-            }
+            await asyncio.gather(
+                *(
+                    self.rpc(addr=r).delete_data(keys=v, report=False)
+                    for r, v in to_senders.items()
+                )
+            )
 
             for sender, recipient, ts in msgs:
                 ts.who_has.remove(sender)
                 sender.has_what.remove(ts)
                 sender.nbytes -= ts.get_nbytes()
 
-            raise Return({"status": "OK"})
+            return {"status": "OK"}
 
-    @gen.coroutine
-    def replicate(
+    async def replicate(
         self,
         comm=None,
         keys=None,
@@ -2867,7 +2865,7 @@ def replicate(
             tasks = {self.tasks[k] for k in keys}
             missing_data = [ts.key for ts in tasks if not ts.who_has]
             if missing_data:
-                raise Return({"status": "missing-data", "keys": missing_data})
+                return {"status": "missing-data", "keys": missing_data}
 
             # Delete extraneous data
             if delete:
@@ -2878,12 +2876,14 @@ def replicate(
                         for ws in random.sample(del_candidates, len(del_candidates) - n):
                             del_worker_tasks[ws].add(ts)
 
-            yield [
-                self.rpc(addr=ws.address).delete_data(
-                    keys=[ts.key for ts in tasks], report=False
+            await asyncio.gather(
+                *(
+                    self.rpc(addr=ws.address).delete_data(
+                        keys=[ts.key for ts in tasks], report=False
+                    )
+                    for ws, tasks in del_worker_tasks.items()
                 )
-                for ws, tasks in del_worker_tasks.items()
-            ]
+            )
 
             for ws, tasks in del_worker_tasks.items():
                 ws.has_what -= tasks
@@ -2911,11 +2911,13 @@ def replicate(
                 for ws in random.sample(workers - ts.who_has, count):
                     gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has]
 
-            results = yield {
-                w: self.rpc(addr=w).gather(who_has=who_has)
-                for w, who_has in gathers.items()
-            }
-            for w, v in results.items():
+            results = await asyncio.gather(
+                *(
+                    self.rpc(addr=w).gather(who_has=who_has)
+                    for w, who_has in gathers.items()
+                )
+            )
+            for w, v in zip(gathers, results):
                 if v["status"] == "OK":
                     self.add_keys(worker=w, keys=list(gathers[w]))
                 else:
@@ -3348,8 +3350,7 @@ def get_ncores(self, comm=None, workers=None):
         else:
             return {w: ws.nthreads for w, ws in self.workers.items()}
 
-    @gen.coroutine
-    def get_call_stack(self, comm=None, keys=None):
+    async def get_call_stack(self, comm=None, keys=None):
         if keys is not None:
             stack = list(keys)
             processing = set()
@@ -3369,14 +3370,13 @@ def get_call_stack(self, comm=None, keys=None):
             workers = {w: None for w in self.workers}
 
         if not workers:
-            raise gen.Return({})
+            return {}
 
-        else:
-            response = yield {
-                w: self.rpc(w).call_stack(keys=v) for w, v in workers.items()
-            }
-            response = {k: v for k, v in response.items() if v}
-            raise gen.Return(response)
+        results = await asyncio.gather(
+            *(self.rpc(w).call_stack(keys=v) for w, v in workers.items())
+        )
+        response = {w: r for w, r in zip(workers, results) if r}
+        return response
 
     def get_nbytes(self, comm=None, keys=None, summary=True):
         with log_errors():
@@ -4613,8 +4613,7 @@ def worker_objective(self, ts, ws):
         else:
             return (start_time, ws.nbytes)
 
-    @gen.coroutine
-    def get_profile(
+    async def get_profile(
         self,
         comm=None,
         workers=None,
@@ -4627,15 +4626,17 @@ def get_profile(
             workers = self.workers
         else:
             workers = set(self.workers) & set(workers)
-        result = yield {
-            w: self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers
-        }
+        results = await asyncio.gather(
+            *(self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers)
+        )
+
         if merge_workers:
-            result = profile.merge(*result.values())
-        raise gen.Return(result)
+            response = profile.merge(*results)
+        else:
+            response = dict(zip(workers, results))
+        return response
 
-    @gen.coroutine
-    def get_profile_metadata(
+    async def get_profile_metadata(
         self,
         comm=None,
         workers=None,
@@ -4653,22 +4654,22 @@ def get_profile_metadata(
             workers = self.workers
         else:
             workers = set(self.workers) & set(workers)
-        result = yield {
-            w: self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers
-        }
+        results = await asyncio.gather(
+            *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers)
+        )
 
-        counts = [v["counts"] for v in result.values()]
+        counts = [v["counts"] for v in results]
         counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt)
         counts = [(time, sum(pluck(1, group))) for time, group in counts]
 
         keys = set()
-        for v in result.values():
+        for v in results:
             for t, d in v["keys"]:
                 for k in d:
                     keys.add(k)
 
         keys = {k: [] for k in keys}
 
-        groups1 = [v["keys"] for v in result.values()]
+        groups1 = [v["keys"] for v in results]
         groups2 = list(merge_sorted(*groups1, key=first))
 
         last = 0
@@ -4681,7 +4682,7 @@ def get_profile_metadata(
             for k, v in d.items():
                 keys[k][-1][1] += v
 
-        raise gen.Return({"counts": counts, "keys": keys})
+        return {"counts": counts, "keys": keys}
 
     async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
         results = await self.broadcast(
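
Note on the converted gather pattern: under `gen.coroutine`, yielding a
dict of futures resolved to a dict of results under the same keys.
`asyncio.gather` instead takes positional awaitables and returns a list
in the order they were passed, so the call sites above that still need
a mapping rebuild it by zipping the original keys against the results
(as `get_call_stack` and `get_profile` do, and as `replicate` does with
`zip(gathers, results)`). This relies on iterating the same unmodified
dict (or set) twice yielding a stable order, which CPython guarantees.
A minimal standalone sketch of the idiom; the `fetch` coroutine and the
worker addresses are invented for illustration and are not part of this
patch:

    import asyncio

    async def fetch(addr):
        # Stand-in for an RPC call such as self.rpc(w).profile(...)
        await asyncio.sleep(0.01)
        return {"status": "OK", "addr": addr}

    async def main():
        workers = ["tcp://10.0.0.1:8786", "tcp://10.0.0.2:8786"]
        # gather() preserves argument order, so zipping `workers`
        # against `results` rebuilds the old dict-of-results shape.
        results = await asyncio.gather(*(fetch(w) for w in workers))
        response = dict(zip(workers, results))
        print(response)

    asyncio.run(main())

One behavioral difference worth keeping in mind: `asyncio.gather`
propagates the first exception by default (like the old Tornado multi-
yield), but the remaining tasks keep running in the background rather
than being awaited.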