Skip to content

Commit

Permalink
Remove gen.coroutine usage in scheduler
Browse files Browse the repository at this point in the history
Use `async`/`await` and `asyncio` idioms throughout.
  • Loading branch information
jcrist committed Nov 15, 2019
1 parent 4d0d58a commit 92441f0
Showing 1 changed file with 58 additions and 57 deletions.
115 changes: 58 additions & 57 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from toolz import frequencies, merge, pluck, merge_sorted, first
from toolz import valmap, second, compose, groupby
from tornado import gen
from tornado.gen import Return
from tornado.ioloop import IOLoop

import dask
Expand Down Expand Up @@ -2008,7 +2007,8 @@ def cancel_key(self, key, client, retries=5, force=False):
if ts is None or not ts.who_wants: # no key yet, let's try again in a moment
if retries:
self.loop.add_future(
gen.sleep(0.2), lambda _: self.cancel_key(key, client, retries - 1)
asyncio.sleep(0.2),
lambda _: self.cancel_key(key, client, retries - 1),
)
return
if force or ts.who_wants == {cs}: # no one else wants this key
Expand Down Expand Up @@ -2700,8 +2700,7 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None):
)
return d[worker]

@gen.coroutine
def rebalance(self, comm=None, keys=None, workers=None):
async def rebalance(self, comm=None, keys=None, workers=None):
""" Rebalance keys so that each worker stores roughly equal bytes
**Policy**
Expand Down Expand Up @@ -2777,9 +2776,9 @@ def rebalance(self, comm=None, keys=None, workers=None):
to_recipients[recipient.address][ts.key].append(sender.address)
to_senders[sender.address].append(ts.key)

result = yield {
r: self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items()
}
result = await asyncio.gather(
*(self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items())
)
for r, v in to_recipients.items():
self.log_event(r, {"action": "rebalance", "who_has": v})

Expand All @@ -2794,13 +2793,11 @@ def rebalance(self, comm=None, keys=None, workers=None):
},
)

if not all(r["status"] == "OK" for r in result.values()):
raise Return(
{
"status": "missing-data",
"keys": sum([r["keys"] for r in result if "keys" in r], []),
}
)
if not all(r["status"] == "OK" for r in result):
return {
"status": "missing-data",
"keys": sum([r["keys"] for r in result if "keys" in r], []),
}

for sender, recipient, ts in msgs:
assert ts.state == "memory"
Expand All @@ -2811,20 +2808,21 @@ def rebalance(self, comm=None, keys=None, workers=None):
("rebalance", ts.key, time(), sender.address, recipient.address)
)

result = yield {
r: self.rpc(addr=r).delete_data(keys=v, report=False)
for r, v in to_senders.items()
}
await asyncio.gather(
*(
self.rpc(addr=r).delete_data(keys=v, report=False)
for r, v in to_senders.items()
)
)

for sender, recipient, ts in msgs:
ts.who_has.remove(sender)
sender.has_what.remove(ts)
sender.nbytes -= ts.get_nbytes()

raise Return({"status": "OK"})
return {"status": "OK"}

@gen.coroutine
def replicate(
async def replicate(
self,
comm=None,
keys=None,
Expand Down Expand Up @@ -2867,7 +2865,7 @@ def replicate(
tasks = {self.tasks[k] for k in keys}
missing_data = [ts.key for ts in tasks if not ts.who_has]
if missing_data:
raise Return({"status": "missing-data", "keys": missing_data})
return {"status": "missing-data", "keys": missing_data}

# Delete extraneous data
if delete:
Expand All @@ -2878,12 +2876,14 @@ def replicate(
for ws in random.sample(del_candidates, len(del_candidates) - n):
del_worker_tasks[ws].add(ts)

yield [
self.rpc(addr=ws.address).delete_data(
keys=[ts.key for ts in tasks], report=False
await asyncio.gather(
*(
self.rpc(addr=ws.address).delete_data(
keys=[ts.key for ts in tasks], report=False
)
for ws, tasks in del_worker_tasks.items()
)
for ws, tasks in del_worker_tasks.items()
]
)

for ws, tasks in del_worker_tasks.items():
ws.has_what -= tasks
Expand Down Expand Up @@ -2911,11 +2911,13 @@ def replicate(
for ws in random.sample(workers - ts.who_has, count):
gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has]

results = yield {
w: self.rpc(addr=w).gather(who_has=who_has)
for w, who_has in gathers.items()
}
for w, v in results.items():
results = await asyncio.gather(
*(
self.rpc(addr=w).gather(who_has=who_has)
for w, who_has in gathers.items()
)
)
for w, v in zip(gathers, results):
if v["status"] == "OK":
self.add_keys(worker=w, keys=list(gathers[w]))
else:
Expand Down Expand Up @@ -3348,8 +3350,7 @@ def get_ncores(self, comm=None, workers=None):
else:
return {w: ws.nthreads for w, ws in self.workers.items()}

@gen.coroutine
def get_call_stack(self, comm=None, keys=None):
async def get_call_stack(self, comm=None, keys=None):
if keys is not None:
stack = list(keys)
processing = set()
Expand All @@ -3369,14 +3370,13 @@ def get_call_stack(self, comm=None, keys=None):
workers = {w: None for w in self.workers}

if not workers:
raise gen.Return({})
return {}

else:
response = yield {
w: self.rpc(w).call_stack(keys=v) for w, v in workers.items()
}
response = {k: v for k, v in response.items() if v}
raise gen.Return(response)
results = await asyncio.gather(
*(self.rpc(w).call_stack(keys=v) for w, v in workers.items())
)
response = {w: r for w, r in zip(workers, results) if r}
return response

def get_nbytes(self, comm=None, keys=None, summary=True):
with log_errors():
Expand Down Expand Up @@ -4613,8 +4613,7 @@ def worker_objective(self, ts, ws):
else:
return (start_time, ws.nbytes)

@gen.coroutine
def get_profile(
async def get_profile(
self,
comm=None,
workers=None,
Expand All @@ -4627,15 +4626,17 @@ def get_profile(
workers = self.workers
else:
workers = set(self.workers) & set(workers)
result = yield {
w: self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers
}
results = await asyncio.gather(
*(self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers)
)

if merge_workers:
result = profile.merge(*result.values())
raise gen.Return(result)
response = profile.merge(*results)
else:
response = dict(zip(workers, results))
return response

@gen.coroutine
def get_profile_metadata(
async def get_profile_metadata(
self,
comm=None,
workers=None,
Expand All @@ -4653,22 +4654,22 @@ def get_profile_metadata(
workers = self.workers
else:
workers = set(self.workers) & set(workers)
result = yield {
w: self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers
}
results = await asyncio.gather(
*(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers)
)

counts = [v["counts"] for v in result.values()]
counts = [v["counts"] for v in results]
counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt)
counts = [(time, sum(pluck(1, group))) for time, group in counts]

keys = set()
for v in result.values():
for v in results:
for t, d in v["keys"]:
for k in d:
keys.add(k)
keys = {k: [] for k in keys}

groups1 = [v["keys"] for v in result.values()]
groups1 = [v["keys"] for v in results]
groups2 = list(merge_sorted(*groups1, key=first))

last = 0
Expand All @@ -4681,7 +4682,7 @@ def get_profile_metadata(
for k, v in d.items():
keys[k][-1][1] += v

raise gen.Return({"counts": counts, "keys": keys})
return {"counts": counts, "keys": keys}

async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
results = await self.broadcast(
Expand Down

0 comments on commit 92441f0

Please sign in to comment.