Reinstate: AMM ReduceReplicas to iterate only on replicated tasks #5341

Merged · 6 commits · Sep 28, 2021

8 changes: 1 addition & 7 deletions distributed/active_memory_manager.py
@@ -303,13 +303,7 @@ class ReduceReplicas(ActiveMemoryManagerPolicy):
     """

     def run(self):
-        # TODO this is O(n) to the total number of in-memory tasks on the cluster; it
-        # could be made faster by automatically attaching it to a TaskState when it
-        # goes above one replica and detaching it when it drops below two.
-        for ts in self.manager.scheduler.tasks.values():
-            if len(ts.who_has) < 2:
-                continue
-
+        for ts in self.manager.scheduler.replicated_tasks:
             desired_replicas = 1  # TODO have a marker on TaskState

             # If a dependent task has not been assigned to a worker yet, err on the side
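
The deleted TODO describes exactly the optimization this PR implements: rather than scanning every in-memory task on each AMM cycle, the scheduler now maintains a set of replicated tasks incrementally. A minimal, self-contained sketch of the pattern (toy classes, not the actual dask/distributed API):

```python
# Toy sketch of the pattern this PR adopts (hypothetical classes, not the
# real dask/distributed API). Instead of scanning every task and skipping
# those with fewer than two replicas, the scheduler keeps a set of
# replicated tasks that is updated whenever a replica is added or removed.

class TaskState:
    def __init__(self, key: str):
        self.key = key
        self.who_has: set[str] = set()  # addresses of workers holding a replica

class SchedulerStateSketch:
    def __init__(self):
        self.tasks: dict[str, TaskState] = {}
        self.replicated_tasks: set[TaskState] = set()  # tasks with 2+ replicas

    def add_replica(self, ts: TaskState, worker: str) -> None:
        ts.who_has.add(worker)
        if len(ts.who_has) == 2:  # just crossed the one-replica threshold
            self.replicated_tasks.add(ts)

    def remove_replica(self, ts: TaskState, worker: str) -> None:
        ts.who_has.remove(worker)
        if len(ts.who_has) == 1:  # dropped back to a single replica
            self.replicated_tasks.discard(ts)

def reduce_replicas_candidates(state: SchedulerStateSketch) -> list[TaskState]:
    # O(len(replicated_tasks)) instead of O(len(tasks)), like the new run()
    return list(state.replicated_tasks)
```
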
118 changes: 69 additions & 49 deletions distributed/scheduler.py
@@ -1858,6 +1858,7 @@ class SchedulerState:
     _task_groups: dict
     _task_prefixes: dict
     _task_metadata: dict
+    _replicated_tasks: set
     _total_nthreads: Py_ssize_t
     _total_occupancy: double
     _transitions_table: dict
@@ -1917,6 +1918,9 @@ def __init__(
             self._tasks = tasks
         else:
             self._tasks = dict()
+        self._replicated_tasks = {
+            ts for ts in self._tasks.values() if len(ts._who_has) > 1
+        }
         self._computations = deque(
             maxlen=dask.config.get("distributed.diagnostics.computations.max-history")
         )
@@ -2034,6 +2038,10 @@ def task_prefixes(self):
     def task_metadata(self):
         return self._task_metadata

+    @property
+    def replicated_tasks(self):
+        return self._replicated_tasks
+
     @property
     def total_nthreads(self):
         return self._total_nthreads
@@ -2819,18 +2827,14 @@ def transition_memory_released(self, key, safe: bint = False):
                     dts._waiting_on.add(ts)

             # XXX factor this out?
-            ts_nbytes: Py_ssize_t = ts.get_nbytes()
             worker_msg = {
                 "op": "free-keys",
                 "keys": [key],
                 "reason": f"Memory->Released {key}",
             }
             for ws in ts._who_has:
-                del ws._has_what[ts]
-                ws._nbytes -= ts_nbytes
                 worker_msgs[ws._address] = [worker_msg]

-            ts._who_has.clear()
+            self.remove_all_replicas(ts)

             ts.state = "released"

@@ -3428,6 +3432,40 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
         else:
             return (start_time, ws._nbytes)

+    @ccall
+    def add_replica(self, ts: TaskState, ws: WorkerState):
+        """Note that a worker holds a replica of a task with state='memory'"""
+        if self._validate:
+            assert ws not in ts._who_has
+            assert ts not in ws._has_what
+
+        ws._nbytes += ts.get_nbytes()
+        ws._has_what[ts] = None
+        ts._who_has.add(ws)
+        if len(ts._who_has) == 2:
+            self._replicated_tasks.add(ts)
+
+    @ccall
+    def remove_replica(self, ts: TaskState, ws: WorkerState):
+        """Note that a worker no longer holds a replica of a task"""
+        ws._nbytes -= ts.get_nbytes()
+        del ws._has_what[ts]
+        ts._who_has.remove(ws)
+        if len(ts._who_has) == 1:
+            self._replicated_tasks.remove(ts)
+
+    @ccall
+    def remove_all_replicas(self, ts: TaskState):
+        """Remove all replicas of a task from all workers"""
+        ws: WorkerState
+        nbytes: Py_ssize_t = ts.get_nbytes()
+        for ws in ts._who_has:
+            ws._nbytes -= nbytes
+            del ws._has_what[ts]
+        if len(ts._who_has) > 1:
+            self._replicated_tasks.remove(ts)
+        ts._who_has.clear()
+

 class Scheduler(SchedulerState, ServerNode):
     """Dynamic distributed task scheduler
@@ -4917,14 +4955,13 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True):
                             self.allowed_failures,
                         )

-            for ts in ws._has_what:
-                ts._who_has.remove(ws)
+            for ts in list(ws._has_what):
+                parent.remove_replica(ts, ws)
                 if not ts._who_has:
                     if ts._run_spec:
                         recommendations[ts._key] = "released"
                     else:  # pure data
                         recommendations[ts._key] = "forgotten"
-            ws._has_what.clear()

             self.transitions(recommendations)

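The `len(ts._who_has) == 2` check on add and `== 1` check on remove are what keep `_replicated_tasks` exactly equal to the set of tasks with more than one replica, with no rescanning. A toy trace against the sketch classes above (again hypothetical stand-ins, not the real scheduler):

```python
state = SchedulerStateSketch()
ts = TaskState("x")
state.tasks["x"] = ts

state.add_replica(ts, "tcp://worker-a")     # one replica: not replicated
assert ts not in state.replicated_tasks

state.add_replica(ts, "tcp://worker-b")     # second replica: enters the set
assert ts in state.replicated_tasks

state.remove_replica(ts, "tcp://worker-a")  # back to one: leaves the set
assert ts not in state.replicated_tasks
```

This also shows why `remove_all_replicas` tests `len(ts._who_has) > 1` once up front instead of calling `remove_replica` per worker: the set membership changes at most once, and the byte count only needs to be fetched a single time.
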
@@ -5074,6 +5111,7 @@ def validate_memory(self, key):
         ts: TaskState = parent._tasks[key]
         dts: TaskState
         assert ts._who_has
+        assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1)
         assert not ts._processing_on
         assert not ts._waiting_on
         assert ts not in parent._unrunnable
@@ -5144,8 +5182,13 @@ def validate_state(self, allow_overlap=False):
         for k, ts in parent._tasks.items():
             assert isinstance(ts, TaskState), (type(ts), ts)
             assert ts._key == k
+            assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1)
             self.validate_key(k, ts)

+        for ts in parent._replicated_tasks:
+            assert ts._state == "memory"
+            assert ts._key in parent._tasks
+
         c: str
         cs: ClientState
         for c, cs in parent._clients.items():
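
Taken together, the assertions added to `validate_memory` and `validate_state` pin down the new invariant. A condensed restatement as a hypothetical standalone checker (the attribute names mirror the real ones, but this helper is not part of the PR):

```python
def check_replica_invariants(parent) -> None:
    # parent is assumed to expose _tasks, _replicated_tasks, and per-task
    # _who_has sets, as SchedulerState does (hypothetical helper).
    for ts in parent._tasks.values():
        # a task is tracked as replicated iff it has two or more replicas
        assert (ts in parent._replicated_tasks) == (len(ts._who_has) > 1)
    for ts in parent._replicated_tasks:
        assert ts._state == "memory"     # only in-memory tasks hold replicas
        assert ts._key in parent._tasks  # no stale entries for forgotten tasks
```
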
@@ -5375,9 +5418,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
             return
         ws: WorkerState = parent._workers_dv.get(errant_worker)
         if ws is not None and ws in ts._who_has:
-            ts._who_has.remove(ws)
-            del ws._has_what[ts]
-            ws._nbytes -= ts.get_nbytes()
+            parent.remove_replica(ts, ws)
             if not ts._who_has:
                 if ts._run_spec:
                     self.transitions({key: "released"})
@@ -5391,12 +5432,9 @@ def release_worker_data(self, comm=None, key=None, worker=None):
         if not ws or not ts:
             return
         recommendations: dict = {}
-        if ts in ws._has_what:
-            del ws._has_what[ts]
-            ws._nbytes -= ts.get_nbytes()
-            wh: set = ts._who_has
-            wh.remove(ws)
-            if not wh:
+        if ws in ts._who_has:
+            parent.remove_replica(ts, ws)
+            if not ts._who_has:
                 recommendations[ts._key] = "released"
         if recommendations:
             self.transitions(recommendations)
@@ -5716,14 +5754,11 @@ async def gather(self, comm=None, keys=None, serializers=None):
                         )
                         if not workers or ts is None:
                             continue
-                        ts_nbytes: Py_ssize_t = ts.get_nbytes()
                         recommendations: dict = {key: "released"}
                         for worker in workers:
                             ws = parent._workers_dv.get(worker)
-                            if ws is not None and ts in ws._has_what:
-                                del ws._has_what[ts]
-                                ts._who_has.remove(ws)
-                                ws._nbytes -= ts_nbytes
+                            if ws is not None and ws in ts._who_has:
+                                parent.remove_replica(ts, ws)
                             parent._transitions(
                                 recommendations, client_msgs, worker_msgs
                             )
@@ -5922,10 +5957,8 @@ async def gather_on_worker(
             if ts is None or ts._state != "memory":
                 logger.warning(f"Key lost during replication: {key}")
                 continue
-            if ts not in ws._has_what:
-                ws._nbytes += ts.get_nbytes()
-                ws._has_what[ts] = None
-                ts._who_has.add(ws)
+            if ws not in ts._who_has:
+                parent.add_replica(ts, ws)

         return keys_failed

@@ -5962,11 +5995,9 @@ async def delete_worker_data(self, worker_address: str, keys: "list[str]") -> None:

         for key in keys:
             ts: TaskState = parent._tasks.get(key)
-            if ts is not None and ts in ws._has_what:
+            if ts is not None and ws in ts._who_has:
                 assert ts._state == "memory"
-                del ws._has_what[ts]
-                ts._who_has.remove(ws)
-                ws._nbytes -= ts.get_nbytes()
+                parent.remove_replica(ts, ws)
                 if not ts._who_has:
                     # Last copy deleted
                     self.transitions({key: "released"})
@@ -6714,10 +6745,8 @@ def add_keys(self, comm=None, worker=None, keys=()):
         for key in keys:
             ts: TaskState = parent._tasks.get(key)
             if ts is not None and ts._state == "memory":
-                if ts not in ws._has_what:
-                    ws._nbytes += ts.get_nbytes()
-                    ws._has_what[ts] = None
-                    ts._who_has.add(ws)
+                if ws not in ts._who_has:
+                    parent.add_replica(ts, ws)
             else:
                 redundant_replicas.append(key)

@@ -6760,17 +6789,14 @@ def update_data(
             if ts is None:
                 ts: TaskState = parent.new_task(key, None, "memory")
             ts.state = "memory"
-            ts_nbytes: Py_ssize_t = nbytes.get(key, -1)
+            ts_nbytes = nbytes.get(key, -1)
             if ts_nbytes >= 0:
                 ts.set_nbytes(ts_nbytes)
-            else:
-                ts_nbytes = ts.get_nbytes()

             for w in workers:
                 ws: WorkerState = parent._workers_dv[w]
-                if ts not in ws._has_what:
-                    ws._nbytes += ts_nbytes
-                    ws._has_what[ts] = None
-                    ts._who_has.add(ws)
+                if ws not in ts._who_has:
+                    parent.add_replica(ts, ws)
             self.report(
                 {"op": "key-in-memory", "key": key, "workers": list(workers)}
             )
@@ -7737,9 +7763,7 @@ def _add_to_memory(
     if state._validate:
         assert ts not in ws._has_what

-    ts._who_has.add(ws)
-    ws._has_what[ts] = None
-    ws._nbytes += ts.get_nbytes()
+    state.add_replica(ts, ws)

    deps: list = list(ts._dependents)
    if len(deps) > 1:
@@ -7815,12 +7839,8 @@ def _propagate_forgotten(
     ts._dependencies.clear()
     ts._waiting_on.clear()

-    ts_nbytes: Py_ssize_t = ts.get_nbytes()
-
     ws: WorkerState
     for ws in ts._who_has:
-        del ws._has_what[ts]
-        ws._nbytes -= ts_nbytes
         w: str = ws._address
         if w in state._workers_dv:  # in case worker has died
             worker_msgs[w] = [
@@ -7830,7 +7850,7 @@
                     "reason": f"propagate-forgotten {ts.key}",
                 }
             ]
-    ts._who_has.clear()
+    state.remove_all_replicas(ts)


 @cfunc