Skip to content

Commit 5e1c229

Browse files
committed
Code review
1 parent 74640b8 commit 5e1c229

File tree

4 files changed

+63
-33
lines changed

4 files changed

+63
-33
lines changed

distributed/collections.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import heapq
44
import weakref
55
from collections import OrderedDict, UserDict
6-
from collections.abc import Callable, Iterator
6+
from collections.abc import Callable, Hashable, Iterator
77
from typing import MutableSet # TODO move to collections.abc (requires Python >=3.9)
88
from typing import Any, TypeVar, cast
99

10-
T = TypeVar("T")
10+
T = TypeVar("T", bound=Hashable)
1111

1212

1313
# TODO change to UserDict[K, V] (requires Python >=3.9)
@@ -44,6 +44,7 @@ class HeapSet(MutableSet[T]):
4444
_heap: list[tuple[Any, int, weakref.ref[T]]]
4545

4646
def __init__(self, *, key: Callable[[T], Any]):
47+
# FIXME https://github.com/python/mypy/issues/708
4748
self.key = key # type: ignore
4849
self._data = set()
4950
self._inc = 0
@@ -55,9 +56,6 @@ def __repr__(self) -> str:
5556
def __contains__(self, value: object) -> bool:
5657
return value in self._data
5758

58-
def __iter__(self) -> Iterator[T]:
59-
return iter(self._data)
60-
6159
def __len__(self) -> int:
6260
return len(self._data)
6361

@@ -72,6 +70,8 @@ def add(self, value: T) -> None:
7270

7371
def discard(self, value: T) -> None:
7472
self._data.discard(value)
73+
if not self._data:
74+
self._heap.clear()
7575

7676
def peek(self) -> T:
7777
"""Get the smallest element without removing it"""
@@ -93,13 +93,22 @@ def pop(self) -> T:
9393
self._data.remove(value)
9494
return value
9595

96-
def sorted(self) -> list[T]:
97-
"""Return a list containing all elements, from smallest to largest according to
98-
the key and insertion order.
96+
def __iter__(self) -> Iterator[T]:
97+
"""Iterate over all elements. This is an O(n) operation which returns the
98+
elements in pseudo-random order.
99+
"""
100+
return iter(self._data)
101+
102+
def sorted(self) -> Iterator[T]:
103+
"""Iterate over all elements. This is an O(n*logn) operation which returns the
104+
elements in order, from smallest to largest according to the key and insertion
105+
order.
99106
"""
100-
out = []
101107
for _, _, vref in sorted(self._heap):
102108
value = vref()
103109
if value in self._data:
104-
out.append(value)
105-
return out
110+
yield value
111+
112+
def clear(self) -> None:
113+
self._data.clear()
114+
self._heap.clear()

distributed/tests/test_collections.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,15 @@ def __eq__(self, other):
5151
assert cz in heap
5252
assert cw in heap
5353

54-
heap_list = heap.sorted()
54+
heap_sorted = heap.sorted()
5555
# iteration does not empty heap
5656
assert len(heap) == 4
57-
assert len(heap_list) == 4
58-
assert heap_list[0] is cy
59-
assert heap_list[1] is cx
60-
assert heap_list[2] is cz
61-
assert heap_list[3] is cw
57+
assert next(heap_sorted) is cy
58+
assert next(heap_sorted) is cx
59+
assert next(heap_sorted) is cz
60+
assert next(heap_sorted) is cw
61+
with pytest.raises(StopIteration):
62+
next(heap_sorted)
6263

6364
assert set(heap) == {cx, cy, cz, cw}
6465

@@ -96,7 +97,7 @@ def __eq__(self, other):
9697
heap.discard(cw)
9798

9899
assert len(heap) == 2
99-
assert heap.sorted() == [cx, cz]
100+
assert list(heap.sorted()) == [cx, cz]
100101
# cy is at the top of heap._heap, but is skipped
101102
assert heap.peek() is cx
102103
assert heap.pop() is cx
@@ -108,4 +109,16 @@ def __eq__(self, other):
108109
heap.peek()
109110
with pytest.raises(KeyError):
110111
heap.pop()
111-
assert heap.sorted() == []
112+
assert list(heap.sorted()) == []
113+
114+
# Test clear()
115+
heap.add(cx)
116+
heap.clear()
117+
assert not heap
118+
heap.add(cx)
119+
assert cx in heap
120+
# Test discard last element
121+
heap.discard(cx)
122+
assert not heap
123+
heap.add(cx)
124+
assert cx in heap

distributed/worker.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,9 +1078,9 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
10781078
"status": self.status,
10791079
"ready": self.ready,
10801080
"constrained": self.constrained,
1081-
"data_needed": self.data_needed.sorted(),
1081+
"data_needed": list(self.data_needed.sorted()),
10821082
"data_needed_per_worker": {
1083-
w: v.sorted() for w, v in self.data_needed_per_worker.items()
1083+
w: list(v.sorted()) for w, v in self.data_needed_per_worker.items()
10841084
},
10851085
"long_running": self.long_running,
10861086
"executing_count": self.executing_count,
@@ -3032,16 +3032,16 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
30323032
local = [w for w in workers if get_address_host(w) == host]
30333033
worker = random.choice(local or workers)
30343034

3035-
to_gather, total_nbytes = self._select_keys_for_gather(worker, ts)
3035+
to_gather_tasks, total_nbytes = self._select_keys_for_gather(worker, ts)
3036+
to_gather_keys = {ts.key for ts in to_gather_tasks}
30363037

30373038
self.log.append(
3038-
("gather-dependencies", worker, to_gather, stimulus_id, time())
3039+
("gather-dependencies", worker, to_gather_keys, stimulus_id, time())
30393040
)
30403041

30413042
self.comm_nbytes += total_nbytes
3042-
self.in_flight_workers[worker] = to_gather
3043-
for d_key in to_gather:
3044-
d_ts = self.tasks[d_key]
3043+
self.in_flight_workers[worker] = to_gather_keys
3044+
for d_ts in to_gather_tasks:
30453045
if self.validate:
30463046
assert d_ts.state == "fetch"
30473047
assert d_ts not in recommendations
@@ -3055,7 +3055,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
30553055
instructions.append(
30563056
GatherDep(
30573057
worker=worker,
3058-
to_gather=to_gather,
3058+
to_gather=to_gather_keys,
30593059
total_nbytes=total_nbytes,
30603060
stimulus_id=stimulus_id,
30613061
)
@@ -3152,14 +3152,14 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
31523152

31533153
def _select_keys_for_gather(
31543154
self, worker: str, ts: TaskState
3155-
) -> tuple[set[str], int]:
3155+
) -> tuple[set[TaskState], int]:
31563156
"""``_ensure_communicating`` decided to fetch a single task from a worker,
31573157
following priority. In order to minimise overhead, request fetching other tasks
31583158
from the same worker within the message, following priority for the single
31593159
worker but ignoring higher priority tasks from other workers, up to
31603160
``target_message_size``.
31613161
"""
3162-
keys = {ts.key}
3162+
tss = {ts}
31633163
total_bytes = ts.get_nbytes()
31643164
tasks = self.data_needed_per_worker[worker]
31653165

@@ -3176,10 +3176,10 @@ def _select_keys_for_gather(
31763176
if other_worker != worker:
31773177
self.data_needed_per_worker[other_worker].remove(ts)
31783178

3179-
keys.add(ts.key)
3179+
tss.add(ts)
31803180
total_bytes += ts.get_nbytes()
31813181

3182-
return keys, total_bytes
3182+
return tss, total_bytes
31833183

31843184
@property
31853185
def total_comm_bytes(self):
@@ -4363,10 +4363,12 @@ def validate_state(self):
43634363

43644364
for ts in self.data_needed:
43654365
assert ts.state == "fetch"
4366+
assert self.tasks[ts.key] is ts
43664367
for worker, tss in self.data_needed_per_worker.items():
43674368
for ts in tss:
4368-
assert ts in self.data_needed
43694369
assert ts.state == "fetch"
4370+
assert self.tasks[ts.key] is ts
4371+
assert ts in self.data_needed
43704372
assert worker in ts.who_has
43714373

43724374
for ts in self.tasks.values():

distributed/worker_state_machine.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,13 @@ def __repr__(self) -> str:
174174
return f"<TaskState {self.key!r} {self.state}>"
175175

176176
def __eq__(self, other: object) -> bool:
177-
return isinstance(other, TaskState) and other.key == self.key
177+
if not isinstance(other, TaskState) or other.key != self.key:
178+
return False
179+
# When a task transitions to forgotten and exits Worker.tasks, it should be
180+
# immediately dereferenced. If the same task is later recreated on the
181+
# worker, we should not have to deal with its previous incarnation lingering.
182+
assert other is self
183+
return True
178184

179185
def __hash__(self) -> int:
180186
return hash(self.key)

0 commit comments

Comments
 (0)