Remove gen.coroutine usage in scheduler #3242
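For readers unfamiliar with the migration, here is a minimal before/after sketch of the pattern this PR applies throughout the scheduler (hypothetical names and RPC callables, not the scheduler code itself):

```python
import asyncio
from tornado import gen

# Before: a tornado generator coroutine. Yielding a dict of futures resolves
# to a dict of results, and values are returned via gen.Return.
@gen.coroutine
def get_statuses_old(rpcs):
    responses = yield {name: rpc() for name, rpc in rpcs.items()}
    raise gen.Return(responses)

# After: a native coroutine. Fan-out goes through asyncio.gather, which
# returns a list in argument order, and a plain `return` replaces gen.Return.
async def get_statuses_new(rpcs):
    results = await asyncio.gather(*(rpc() for rpc in rpcs.values()))
    return dict(zip(rpcs, results))
```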
@@ -24,7 +24,6 @@
 from toolz import frequencies, merge, pluck, merge_sorted, first
 from toolz import valmap, second, compose, groupby
 from tornado import gen
-from tornado.gen import Return
 from tornado.ioloop import IOLoop

 import dask
@@ -2008,7 +2007,8 @@ def cancel_key(self, key, client, retries=5, force=False):
         if ts is None or not ts.who_wants:  # no key yet, lets try again in a moment
             if retries:
                 self.loop.add_future(
-                    gen.sleep(0.2), lambda _: self.cancel_key(key, client, retries - 1)
+                    asyncio.sleep(0.2),
+                    lambda _: self.cancel_key(key, client, retries - 1),
                 )
             return
         if force or ts.who_wants == {cs}:  # no one else wants this key
@@ -2700,8 +2700,7 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None):
         )
         return d[worker]

-    @gen.coroutine
-    def rebalance(self, comm=None, keys=None, workers=None):
+    async def rebalance(self, comm=None, keys=None, workers=None):
         """ Rebalance keys so that each worker stores roughly equal bytes

         **Policy**
@@ -2777,9 +2776,9 @@ def rebalance(self, comm=None, keys=None, workers=None):
                 to_recipients[recipient.address][ts.key].append(sender.address)
                 to_senders[sender.address].append(ts.key)

-            result = yield {
-                r: self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items()
-            }
+            result = await asyncio.gather(
+                *(self.rpc(addr=r).gather(who_has=v) for r, v in to_recipients.items())
+            )
             for r, v in to_recipients.items():
                 self.log_event(r, {"action": "rebalance", "who_has": v})

@@ -2794,13 +2793,11 @@ def rebalance(self, comm=None, keys=None, workers=None):
                 },
             )

-            if not all(r["status"] == "OK" for r in result.values()):
-                raise Return(
-                    {
-                        "status": "missing-data",
-                        "keys": sum([r["keys"] for r in result if "keys" in r], []),
[Review thread on the line above]
> Was this wrong before? I would have expected this to iterate over result.values
Yeah, this was wrong before.
> It's really good to have a second pair of eyes here
-                    }
-                )
+            if not all(r["status"] == "OK" for r in result):
+                return {
+                    "status": "missing-data",
+                    "keys": sum([r["keys"] for r in result if "keys" in r], []),
+                }

             for sender, recipient, ts in msgs:
                 assert ts.state == "memory"
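The thread above points at a pre-existing bug: with the old tornado idiom the result was a dict keyed by worker address, so iterating it walked the address strings rather than the responses. A small standalone illustration with hypothetical data (not the scheduler's actual responses):

```python
# Old shape: `yield {addr: future, ...}` resolved to a dict keyed by address.
old_result = {"tcp://10.0.0.1:8786": {"status": "missing-data", "keys": ["x"]}}
# Iterating a dict yields its keys, so `r` is an address string here and
# `"keys" in r` is only a substring test; the responses are never inspected.
assert [r for r in old_result if "keys" in r] == []

# New shape: asyncio.gather returns the responses as a list, so the same
# expression in the updated code really does collect the missing keys.
new_result = [{"status": "missing-data", "keys": ["x"]}]
assert sum([r["keys"] for r in new_result if "keys" in r], []) == ["x"]
```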
@@ -2811,20 +2808,21 @@ def rebalance(self, comm=None, keys=None, workers=None):
                     ("rebalance", ts.key, time(), sender.address, recipient.address)
                 )

-            result = yield {
-                r: self.rpc(addr=r).delete_data(keys=v, report=False)
-                for r, v in to_senders.items()
-            }
+            await asyncio.gather(
+                *(
+                    self.rpc(addr=r).delete_data(keys=v, report=False)
+                    for r, v in to_senders.items()
+                )
+            )

             for sender, recipient, ts in msgs:
                 ts.who_has.remove(sender)
                 sender.has_what.remove(ts)
                 sender.nbytes -= ts.get_nbytes()

-            raise Return({"status": "OK"})
+            return {"status": "OK"}

-    @gen.coroutine
-    def replicate(
+    async def replicate(
         self,
         comm=None,
         keys=None,
@@ -2867,7 +2865,7 @@ def replicate(
             tasks = {self.tasks[k] for k in keys}
             missing_data = [ts.key for ts in tasks if not ts.who_has]
             if missing_data:
-                raise Return({"status": "missing-data", "keys": missing_data})
+                return {"status": "missing-data", "keys": missing_data}

             # Delete extraneous data
             if delete:
@@ -2878,12 +2876,14 @@ def replicate(
                     for ws in random.sample(del_candidates, len(del_candidates) - n):
                         del_worker_tasks[ws].add(ts)

-            yield [
-                self.rpc(addr=ws.address).delete_data(
-                    keys=[ts.key for ts in tasks], report=False
-                )
-                for ws, tasks in del_worker_tasks.items()
-            ]
+            await asyncio.gather(
+                *(
+                    self.rpc(addr=ws.address).delete_data(
+                        keys=[ts.key for ts in tasks], report=False
+                    )
+                    for ws, tasks in del_worker_tasks.items()
+                )
+            )

             for ws, tasks in del_worker_tasks.items():
                 ws.has_what -= tasks
@@ -2911,11 +2911,13 @@ def replicate(
                 for ws in random.sample(workers - ts.who_has, count):
                     gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has]

-            results = yield {
-                w: self.rpc(addr=w).gather(who_has=who_has)
-                for w, who_has in gathers.items()
-            }
-            for w, v in results.items():
+            results = await asyncio.gather(
+                *(
+                    self.rpc(addr=w).gather(who_has=who_has)
+                    for w, who_has in gathers.items()
+                )
+            )
+            for w, v in zip(gathers, results):
                 if v["status"] == "OK":
                     self.add_keys(worker=w, keys=list(gathers[w]))
                 else:
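One point worth spelling out about these conversions: tornado's yield-a-dict idiom handed back a dict keyed by worker, whereas asyncio.gather returns an ordered list, so the association with workers is recovered by zipping against the same dict the requests were built from (as the hunk above, and get_call_stack below, now do). A runnable sketch under that assumption, with a stub in place of the real RPC:

```python
import asyncio

async def fake_gather(addr):
    # Hypothetical stand-in for self.rpc(addr=addr).gather(...).
    return {"status": "OK", "worker": addr}

async def main():
    gathers = {"tcp://a:1234": ["x"], "tcp://b:1234": ["y"]}
    # gather() preserves argument order, and iterating `gathers` twice gives
    # the same insertion order, so zip restores the worker -> response pairing.
    results = await asyncio.gather(*(fake_gather(w) for w in gathers))
    for w, v in zip(gathers, results):
        assert v["worker"] == w

asyncio.run(main())
```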
@@ -3348,8 +3350,7 @@ def get_ncores(self, comm=None, workers=None):
         else:
             return {w: ws.nthreads for w, ws in self.workers.items()}

-    @gen.coroutine
-    def get_call_stack(self, comm=None, keys=None):
+    async def get_call_stack(self, comm=None, keys=None):
         if keys is not None:
             stack = list(keys)
             processing = set()
@@ -3369,14 +3370,13 @@ def get_call_stack(self, comm=None, keys=None):
             workers = {w: None for w in self.workers}

         if not workers:
-            raise gen.Return({})
+            return {}

         else:
-            response = yield {
-                w: self.rpc(w).call_stack(keys=v) for w, v in workers.items()
-            }
-            response = {k: v for k, v in response.items() if v}
-            raise gen.Return(response)
+            results = await asyncio.gather(
+                *(self.rpc(w).call_stack(keys=v) for w, v in workers.items())
+            )
+            response = {w: r for w, r in zip(workers, results) if r}
+            return response

     def get_nbytes(self, comm=None, keys=None, summary=True):
         with log_errors():
@@ -4613,8 +4613,7 @@ def worker_objective(self, ts, ws):
         else:
             return (start_time, ws.nbytes)

-    @gen.coroutine
-    def get_profile(
+    async def get_profile(
         self,
         comm=None,
         workers=None,
@@ -4627,15 +4626,17 @@ def get_profile(
             workers = self.workers
         else:
             workers = set(self.workers) & set(workers)
-        result = yield {
-            w: self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers
-        }
+        results = await asyncio.gather(
+            *(self.rpc(w).profile(start=start, stop=stop, key=key) for w in workers)
+        )

         if merge_workers:
-            result = profile.merge(*result.values())
-        raise gen.Return(result)
+            response = profile.merge(*results)
+        else:
+            response = dict(zip(workers, results))
+        return response

-    @gen.coroutine
-    def get_profile_metadata(
+    async def get_profile_metadata(
         self,
         comm=None,
         workers=None,
@@ -4653,22 +4654,22 @@ def get_profile_metadata(
             workers = self.workers
         else:
             workers = set(self.workers) & set(workers)
-        result = yield {
-            w: self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers
-        }
+        results = await asyncio.gather(
+            *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers)
+        )

-        counts = [v["counts"] for v in result.values()]
+        counts = [v["counts"] for v in results]
         counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt)
         counts = [(time, sum(pluck(1, group))) for time, group in counts]

         keys = set()
-        for v in result.values():
+        for v in results:
             for t, d in v["keys"]:
                 for k in d:
                     keys.add(k)
         keys = {k: [] for k in keys}

-        groups1 = [v["keys"] for v in result.values()]
+        groups1 = [v["keys"] for v in results]
         groups2 = list(merge_sorted(*groups1, key=first))

         last = 0
@@ -4681,7 +4682,7 @@ def get_profile_metadata(
                 for k, v in d.items():
                     keys[k][-1][1] += v

-        raise gen.Return({"counts": counts, "keys": keys})
+        return {"counts": counts, "keys": keys}

     async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
         results = await self.broadcast(
[Review comment]
Fixed. There's no reason to manually call gen.sleep here; using IOLoop.call_later works just as well.
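For completeness, a minimal sketch of the IOLoop.call_later variant mentioned above, using a hypothetical class in place of the scheduler (this is not the code in the PR; it only illustrates the suggested pattern):

```python
from tornado.ioloop import IOLoop

class RetryingCanceller:
    """Hypothetical stand-in for the scheduler's retry-on-cancel logic."""

    def __init__(self):
        self.loop = IOLoop.current()

    def cancel_key(self, key, client, retries=5):
        if retries:
            # call_later(delay, callback, *args) reschedules cancel_key after
            # 0.2 s without building a sleep future just to hook a callback onto it.
            self.loop.call_later(0.2, self.cancel_key, key, client, retries - 1)
```

The trade-off is small: add_future(asyncio.sleep(...), cb) and call_later(0.2, cb) both run the callback on the IOLoop after the delay; call_later simply skips the intermediate future.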