Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions distributed/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import heapq
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator
from typing import MutableSet # TODO move to collections.abc (requires Python >=3.9)
from typing import Any, TypeVar, cast

T = TypeVar("T", bound=Hashable)


# TODO change to UserDict[K, V] (requires Python >=3.9)
class LRU(UserDict):
    """Limited size mapping, evicting the least recently looked-up key when full.

    Entries are kept in an :class:`~collections.OrderedDict`, ordered from the
    least to the most recently looked-up key. Reading a key with ``lru[key]``
    moves it to the most-recent end. Inserting a *new* key into a full mapping
    first drops the entry at the least-recent end. Overwriting the value of a key
    that is already present neither evicts anything nor changes the key's
    position.

    Parameters
    ----------
    maxsize
        Number of entries at which inserting a new key triggers an eviction.
    """

    def __init__(self, maxsize: float):
        super().__init__()
        # Replace the plain dict created by UserDict so that ordering operations
        # (move_to_end / popitem(last=False)) are available.
        self.data = OrderedDict()
        self.maxsize = maxsize

    def __getitem__(self, key):
        """Return the value for ``key`` and mark it as the most recently used.

        Raises
        ------
        KeyError
            If ``key`` is not in the mapping.
        """
        value = super().__getitem__(key)
        cast(OrderedDict, self.data).move_to_end(key)
        return value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, evicting the oldest entry if needed."""
        # Overwriting an existing key does not grow the mapping, so there is no
        # reason to evict another entry to make room for it.
        if key not in self.data and len(self) >= self.maxsize:
            cast(OrderedDict, self.data).popitem(last=False)
        super().__setitem__(key, value)


class HeapSet(MutableSet[T]):
    """A set-like where the `pop` method returns the smallest item, as sorted by an
    arbitrary key function. Ties are broken by oldest first.

    Values must be compatible with :mod:`weakref`.

    Notes
    -----
    The key of a value is computed once, when the value is added.

    ``discard`` is lazy: it only removes the value from the membership set and
    leaves its entry in the internal heap (unless the set becomes empty, in which
    case the heap is cleared). ``peek``, ``pop`` and ``sorted`` skip such stale
    entries when they come across them, so a single ``peek`` or ``pop`` call is
    not a constant-time operation: it may remove several stale entries from the
    heap before it reaches a live one.

    A value that is discarded and then added again while other values remain
    still has its stale entry in the heap, so it may be returned at the position
    of its earlier insertion.
    """

    __slots__ = ("key", "_data", "_heap", "_inc")
    # Function mapping a value to its sort key
    key: Callable[[T], Any]
    # Values currently in the set; the source of truth for membership
    _data: set[T]
    # Monotonic insertion counter, used to break ties between equal keys
    _inc: int
    # Heap of (key, insertion counter, weak reference to value); may contain
    # stale entries for values that have since been discarded
    _heap: list[tuple[Any, int, weakref.ref[T]]]

    def __init__(self, *, key: Callable[[T], Any]):
        # FIXME https://github.com/python/mypy/issues/708
        self.key = key  # type: ignore
        self._data = set()
        self._inc = 0
        self._heap = []

    def __repr__(self) -> str:
        return f"<{type(self).__name__}: {len(self)} items>"

    def __contains__(self, value: object) -> bool:
        return value in self._data

    def __len__(self) -> int:
        return len(self._data)

    def add(self, value: T) -> None:
        """Add ``value`` to the set; do nothing if an equal value is already in it.

        The set is only modified after the key function and ``weakref.ref`` have
        both succeeded, so an exception raised by either leaves it unchanged.
        """
        if value in self._data:
            return
        k = self.key(value)  # type: ignore
        vref = weakref.ref(value)
        heapq.heappush(self._heap, (k, self._inc, vref))
        self._data.add(value)
        self._inc += 1

    def discard(self, value: T) -> None:
        """Remove ``value`` from the set if present.

        The matching heap entry is not removed here; it is skipped later by
        ``peek``, ``pop`` and ``sorted``.
        """
        self._data.discard(value)
        if not self._data:
            # Nothing is left alive, so every heap entry is stale
            self._heap.clear()

    def peek(self) -> T:
        """Get the smallest element without removing it.

        Stale entries found at the top of the heap are dropped along the way.

        Raises
        ------
        KeyError
            If the set is empty.
        """
        if not self._data:
            raise KeyError("peek into empty set")
        while True:
            value = self._heap[0][2]()
            if value in self._data:
                return value
            # Top of the heap refers to a discarded value: drop it and retry
            heapq.heappop(self._heap)

    def pop(self) -> T:
        """Remove and return the smallest element.

        Stale entries found at the top of the heap are dropped along the way.

        Raises
        ------
        KeyError
            If the set is empty.
        """
        if not self._data:
            raise KeyError("pop from an empty set")
        while True:
            _, _, vref = heapq.heappop(self._heap)
            value = vref()
            if value in self._data:
                self._data.discard(value)
                return value

    def __iter__(self) -> Iterator[T]:
        """Iterate over all elements. This is a O(n) operation which returns the
        elements in pseudo-random order.
        """
        return iter(self._data)

    def sorted(self) -> Iterator[T]:
        """Iterate over all elements. This is a O(n*logn) operation which returns the
        elements in order, from smallest to largest according to the key and insertion
        order.

        Each element is yielded exactly once, even if the heap holds more than
        one entry for it.
        """
        # A value that was discarded and re-added has both a stale and a live
        # entry in the heap, and both pass the membership test below.
        seen: set[T] = set()
        for _, _, vref in sorted(self._heap):
            value = vref()
            if value in self._data and value not in seen:
                seen.add(value)
                yield value

    def clear(self) -> None:
        """Remove all elements."""
        self._data.clear()
        self._heap.clear()
150 changes: 150 additions & 0 deletions distributed/tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

from distributed.collections import LRU, HeapSet


def test_lru():
    """LRU refreshes keys on lookup and evicts the stalest key once full."""
    cache = LRU(maxsize=3)
    for name, number in (("a", 1), ("b", 2), ("c", 3)):
        cache[name] = number
    assert list(cache.keys()) == ["a", "b", "c"]

    # Looking up "a" must move it to the most-recently-used end
    cache["a"]
    assert list(cache.keys()) == ["b", "c", "a"]

    # A fourth key must push out the least recently looked-up one ("b")
    cache["d"] = 4
    assert len(cache) == 3
    assert list(cache.keys()) == ["c", "a", "d"]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved from test_utils.py



def test_heapset():
    """Exercise HeapSet: ordering, membership, out-of-order removal, clear(), and
    the state of the set after add() fails.
    """

    # Element type: hashed and compared by ``k`` only; ``i`` is the sort key
    class C:
        def __init__(self, k, i):
            self.k = k
            self.i = i

        def __hash__(self):
            return hash(self.k)

        def __eq__(self, other):
            return isinstance(other, C) and other.k == self.k

    heap = HeapSet(key=lambda c: c.i)

    cx = C("x", 2)
    cy = C("y", 1)
    cz = C("z", 3)
    cw = C("w", 4)
    heap.add(cx)
    heap.add(cy)
    heap.add(cz)
    heap.add(cw)
    heap.add(C("x", 0))  # Ignored; x already in heap
    assert len(heap) == 4
    assert repr(heap) == "<HeapSet: 4 items>"

    assert cx in heap
    assert cy in heap
    assert cz in heap
    assert cw in heap

    heap_sorted = heap.sorted()
    # iteration does not empty heap
    assert len(heap) == 4
    assert next(heap_sorted) is cy
    assert next(heap_sorted) is cx
    assert next(heap_sorted) is cz
    assert next(heap_sorted) is cw
    with pytest.raises(StopIteration):
        next(heap_sorted)

    assert set(heap) == {cx, cy, cz, cw}

    assert heap.peek() is cy
    assert heap.pop() is cy
    assert cx in heap
    assert cy not in heap
    assert cz in heap
    assert cw in heap

    assert heap.peek() is cx
    assert heap.pop() is cx
    assert heap.pop() is cz
    assert heap.pop() is cw
    assert not heap
    with pytest.raises(KeyError):
        heap.pop()
    with pytest.raises(KeyError):
        heap.peek()

    # Test out-of-order discard
    heap.add(cx)
    heap.add(cy)
    heap.add(cz)
    heap.add(cw)
    assert heap.peek() is cy

    heap.remove(cy)
    assert cy not in heap
    with pytest.raises(KeyError):
        heap.remove(cy)

    heap.discard(cw)
    assert cw not in heap
    heap.discard(cw)

    assert len(heap) == 2
    assert list(heap.sorted()) == [cx, cz]
    # cy is at the top of heap._heap, but is skipped
    assert heap.peek() is cx
    assert heap.pop() is cx
    assert heap.peek() is cz
    assert heap.pop() is cz
    # heap._heap is not empty
    assert not heap
    with pytest.raises(KeyError):
        heap.peek()
    with pytest.raises(KeyError):
        heap.pop()
    assert list(heap.sorted()) == []

    # Test clear()
    heap.add(cx)
    heap.clear()
    assert not heap
    heap.add(cx)
    assert cx in heap
    # Test discard last element
    heap.discard(cx)
    assert not heap
    heap.add(cx)
    assert cx in heap

    # Test resilience to failure in key()
    bad_key = C("bad_key", 0)
    del bad_key.i
    with pytest.raises(AttributeError):
        heap.add(bad_key)
    assert len(heap) == 1
    assert set(heap) == {cx}

    # Test resilience to failure in weakref.ref()
    # D declares __slots__ without "__weakref__", so it cannot be weakly referenced
    class D:
        __slots__ = ("i",)

        def __init__(self, i):
            self.i = i

    # Build the instance outside of pytest.raises, so that the TypeError can only
    # come from heap.add() and not from the constructor
    bad_weakref = D(2)
    with pytest.raises(TypeError):
        heap.add(bad_weakref)
    assert len(heap) == 1
    assert set(heap) == {cx}

    # Test resilience to key() returning non-sortable output
    with pytest.raises(TypeError):
        heap.add(C("unsortable_key", None))
    assert len(heap) == 1
    assert set(heap) == {cx}
19 changes: 0 additions & 19 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from distributed.compatibility import MACOS, WINDOWS
from distributed.metrics import time
from distributed.utils import (
LRU,
All,
Log,
Logs,
Expand Down Expand Up @@ -594,24 +593,6 @@ def test_parse_ports():
parse_ports("100.5")


def test_lru():
    """LRU refreshes keys on lookup and evicts the stalest key once full."""
    lru = LRU(maxsize=3)
    for letter, number in zip("abc", (1, 2, 3)):
        lru[letter] = number
    assert list(lru.keys()) == ["a", "b", "c"]

    # Reading "a" promotes it to the most recently used slot
    lru["a"]
    assert list(lru.keys()) == ["b", "c", "a"]

    # Adding one more key than maxsize drops the stalest key ("b")
    lru["d"] = 4
    assert len(lru) == 3
    assert list(lru.keys()) == ["c", "a", "d"]


@gen_test()
async def test_offload():
assert (await offload(inc, 1)) == 2
Expand Down
75 changes: 0 additions & 75 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2589,81 +2589,6 @@ def __reduce__(self):
assert "return lambda: 1 / 0, ()" in logvalue


@gen_cluster(client=True)
async def test_gather_dep_exception_one_task(c, s, a, b):
    """Ensure an exception in a single task does not tear down an entire batch of gather_dep

    NOTE(review): in the pull request that removed this test, the author called
    the summary above misleading. According to that comment, the test exercised
    resilience to an exception in the transitions of a single task after
    gather_dep (which should be dealt with through @fail_hard), whereas a
    legitimate exception for a single key of the bundle, namely a task that
    fails to unpickle, makes the whole gather_dep fail for all tasks, with no
    code to deal with that use case.

    See also https://github.com/dask/distributed/issues/5152
    See also test_gather_dep_exception_one_task_2
    """
    # Three tasks pinned to worker a
    fut = c.submit(inc, 1, workers=[a.address], key="f1")
    fut2 = c.submit(inc, 2, workers=[a.address], key="f2")
    fut3 = c.submit(inc, 3, workers=[a.address], key="f3")

    import asyncio

    # Wrap worker b's comm pool; the event and queue are handed to
    # _LockedCommPool as write_event / write_queue (see its definition for the
    # exact semantics), and the queue is read below to observe outgoing messages
    event = asyncio.Event()
    write_queue = asyncio.Queue()
    b.rpc = _LockedCommPool(b.rpc, write_event=event, write_queue=write_queue)
    b.rpc.remove(a.address)

    # Only the first two arguments contribute to the result; extras are ignored
    def sink(a, b, *args):
        return a + b

    # Both dependents run on worker b; res1 additionally depends on fut3
    res1 = c.submit(sink, fut, fut2, fut3, workers=[b.address])
    res2 = c.submit(sink, fut, fut2, workers=[b.address])

    # Wait until we're sure the worker is attempting to fetch the data
    while True:
        peer_addr, msg = await write_queue.get()
        if peer_addr == a.address and msg["op"] == "get_data":
            break

    # Provoke an "impossible transition exception"
    # By choosing a state which doesn't exist we're not running into validation
    # errors and the state machine should raise if we want to transition from
    # fetch to memory

    b.validate = False
    # Directly overwrite the task state on worker b
    b.tasks[fut3.key].state = "fetch"
    event.set()

    # Both dependents are still expected to complete
    assert await res1 == 5
    assert await res2 == 5

    del res1, res2, fut, fut2
    fut3.release()

    # Poll until at least one of the two workers holds no tasks
    while a.tasks and b.tasks:
        await asyncio.sleep(0.1)


@gen_cluster(client=True)
async def test_gather_dep_exception_one_task_2(c, s, a, b):
    """Ensure an exception in a single task does not tear down an entire batch of gather_dep

    The below triggers a fetch->memory transition

    See also https://github.com/dask/distributed/issues/5152
    See also test_gather_dep_exception_one_task
    """
    # This test does not trigger the condition reliably but is a very easy case
    # which should function correctly regardless

    # f1 runs on worker a; f2 depends on it and runs on worker b
    fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
    fut2 = c.submit(inc, fut1, workers=[b.address], key="f2")

    # Yield to the event loop until worker b knows about f1 and f1 is no longer
    # in the "flight" state there
    while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight":
        await asyncio.sleep(0)

    # Report to the scheduler that worker b could not get f1 from worker a
    s.handle_missing_data(
        key="f1", worker=b.address, errant_worker=a.address, stimulus_id="test"
    )

    # f2 must still complete
    await fut2


Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests don't make sense anymore. Also, one of them directly tampers with the state which is a big no-no.

@gen_cluster(client=True)
async def test_acquire_replicas(c, s, a, b):
fut = c.submit(inc, 1, workers=[a.address])
Expand Down
Loading