Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove gen.coroutine usage in scheduler #3242

Merged
merged 2 commits into from
Nov 15, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 58 additions & 57 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from toolz import frequencies, merge, pluck, merge_sorted, first
from toolz import valmap, second, compose, groupby
from tornado import gen
from tornado.gen import Return
from tornado.ioloop import IOLoop

import dask
Expand Down Expand Up @@ -2008,7 +2007,8 @@ def cancel_key(self, key, client, retries=5, force=False):
if ts is None or not ts.who_wants: # no key yet, lets try again in a moment
if retries:
self.loop.add_future(
gen.sleep(0.2), lambda _: self.cancel_key(key, client, retries - 1)
asyncio.sleep(0.2),
lambda _: self.cancel_key(key, client, retries - 1),
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. There's no reason to manually call gen.sleep here, using IOLoop.call_later works just as well.

return
if force or ts.who_wants == {cs}: # no one else wants this key
Expand Down Expand Up @@ -2700,8 +2700,7 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None):
)
return d[worker]

@gen.coroutine
def rebalance(self, comm=None, keys=None, workers=None):
async def rebalance(self, comm=None, keys=None, workers=None):
""" Rebalance keys so that each worker stores roughly equal bytes

**Policy**
Expand Down Expand Up @@ -2777,9 +2776,9 @@ def rebalance(self, comm=None, keys=None, workers=None):
to_recipients[recipient.address][ts.key].append(sender.address)
to_senders[sender.address].append(ts.key)

result = yield {
r: self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items()
}
result = await asyncio.gather(
*(self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items())
)
for r, v in to_recipients.items():
self.log_event(r, {"action": "rebalance", "who_has": v})

Expand All @@ -2794,13 +2793,11 @@ def rebalance(self, comm=None, keys=None, workers=None):
},
)

if not all(r["status"] == "OK" for r in result.values()):
raise Return(
{
"status": "missing-data",
"keys": sum([r["keys"] for r in result if "keys" in r], []),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this wrong before? I would have expected this to iterate over result.values

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this was wrong before.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's really good to have a second pair of eyes here

}
)
if not all(r["status"] == "OK" for r in result):
return {
"status": "missing-data",
"keys": sum([r["keys"] for r in result if "keys" in r], []),
}

for sender, recipient, ts in msgs:
assert ts.state == "memory"
Expand All @@ -2811,20 +2808,21 @@ def rebalance(self, comm=None, keys=None, workers=None):
("rebalance", ts.key, time(), sender.address, recipient.address)
)

result = yield {
r: self.rpc(addr=r).delete_data(keys=v, report=False)
for r, v in to_senders.items()
}
await asyncio.gather(
*(
self.rpc(addr=r).delete_data(keys=v, report=False)
for r, v in to_senders.items()
)
)

for sender, recipient, ts in msgs:
ts.who_has.remove(sender)
sender.has_what.remove(ts)
sender.nbytes -= ts.get_nbytes()

raise Return({"status": "OK"})
return {"status": "OK"}

@gen.coroutine
def replicate(
async def replicate(
self,
comm=None,
keys=None,
Expand Down Expand Up @@ -2867,7 +2865,7 @@ def replicate(
tasks = {self.tasks[k] for k in keys}
missing_data = [ts.key for ts in tasks if not ts.who_has]
if missing_data:
raise Return({"status": "missing-data", "keys": missing_data})
return {"status": "missing-data", "keys": missing_data}

# Delete extraneous data
if delete:
Expand All @@ -2878,12 +2876,14 @@ def replicate(
for ws in random.sample(del_candidates, len(del_candidates) - n):
del_worker_tasks[ws].add(ts)

yield [
self.rpc(addr=ws.address).delete_data(
keys=[ts.key for ts in tasks], report=False
await asyncio.gather(
*(
self.rpc(addr=ws.address).delete_data(
keys=[ts.key for ts in tasks], report=False
)
for ws, tasks in del_worker_tasks.items()
)
for ws, tasks in del_worker_tasks.items()
]
)

for ws, tasks in del_worker_tasks.items():
ws.has_what -= tasks
Expand Down Expand Up @@ -2911,11 +2911,13 @@ def replicate(
for ws in random.sample(workers - ts.who_has, count):
gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has]

results = yield {
w: self.rpc(addr=w).gather(who_has=who_has)
for w, who_has in gathers.items()
}
for w, v in results.items():
results = await asyncio.gather(
*(
self.rpc(addr=w).gather(who_has=who_has)
for w, who_has in gathers.items()
)
)
for w, v in zip(gathers, results):
if v["status"] == "OK":
self.add_keys(worker=w, keys=list(gathers[w]))
else:
Expand Down Expand Up @@ -3348,8 +3350,7 @@ def get_ncores(self, comm=None, workers=None):
else:
return {w: ws.nthreads for w, ws in self.workers.items()}

@gen.coroutine
def get_call_stack(self, comm=None, keys=None):
async def get_call_stack(self, comm=None, keys=None):
if keys is not None:
stack = list(keys)
processing = set()
Expand All @@ -3369,14 +3370,13 @@ def get_call_stack(self, comm=None, keys=None):
workers = {w: None for w in self.workers}

if not workers:
raise gen.Return({})
return {}

else:
response = yield {
w: self.rpc(w).call_stack(keys=v) for w, v in workers.items()
}
response = {k: v for k, v in response.items() if v}
raise gen.Return(response)
results = await asyncio.gather(
*(self.rpc(w).call_stack(keys=v) for w, v in workers.items())
)
response = {w: r for w, r in zip(workers, results) if r}
return response

def get_nbytes(self, comm=None, keys=None, summary=True):
with log_errors():
Expand Down Expand Up @@ -4613,8 +4613,7 @@ def worker_objective(self, ts, ws):
else:
return (start_time, ws.nbytes)

@gen.coroutine
def get_profile(
async def get_profile(
self,
comm=None,
workers=None,
Expand All @@ -4627,15 +4626,17 @@ def get_profile(
workers = self.workers
else:
workers = set(self.workers) & set(workers)
result = yield {
w: self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers
}
results = await asyncio.gather(
*(self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers)
)

if merge_workers:
result = profile.merge(*result.values())
raise gen.Return(result)
response = profile.merge(*results)
else:
response = dict(zip(workers, results))
return response

@gen.coroutine
def get_profile_metadata(
async def get_profile_metadata(
self,
comm=None,
workers=None,
Expand All @@ -4653,22 +4654,22 @@ def get_profile_metadata(
workers = self.workers
else:
workers = set(self.workers) & set(workers)
result = yield {
w: self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers
}
results = await asyncio.gather(
*(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers)
)

counts = [v["counts"] for v in result.values()]
counts = [v["counts"] for v in results]
counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt)
counts = [(time, sum(pluck(1, group))) for time, group in counts]

keys = set()
for v in result.values():
for v in results:
for t, d in v["keys"]:
for k in d:
keys.add(k)
keys = {k: [] for k in keys}

groups1 = [v["keys"] for v in result.values()]
groups1 = [v["keys"] for v in results]
groups2 = list(merge_sorted(*groups1, key=first))

last = 0
Expand All @@ -4681,7 +4682,7 @@ def get_profile_metadata(
for k, v in d.items():
keys[k][-1][1] += v

raise gen.Return({"counts": counts, "keys": keys})
return {"counts": counts, "keys": keys}

async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
results = await self.broadcast(
Expand Down