
Commit 98ed10e

initial integration
1 parent 7833601 commit 98ed10e

File tree

10 files changed: +282 additions, -64 deletions

distributed/core.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 from collections.abc import Callable, Container, Coroutine, Generator
 from enum import Enum
 from functools import partial
-from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, TypeVar, final
+from typing import TYPE_CHECKING, Any, ClassVar, Hashable, TypedDict, TypeVar, final

 import tblib
 from tlz import merge
@@ -943,7 +943,7 @@ async def close(self, timeout=None):
         finally:
             self._event_finished.set()

-    def digest_metric(self, name: str, value: float) -> None:
+    def digest_metric(self, name: Hashable, value: float) -> None:
         # Granular data (requires crick)
         if self.digests is not None:
             self.digests[name].add(value)

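Note: widening `name` from `str` to `Hashable` lets digest names be structured tuples rather than flat strings, which is what the new tests later in this commit expect. A minimal stand-in to illustrate the idea (DemoServer is hypothetical; the real Server feeds `self.digests`, a crick-backed store, rather than a plain dict):

from collections import defaultdict
from collections.abc import Hashable


class DemoServer:
    """Hypothetical stand-in for Server's digest bookkeeping."""

    def __init__(self) -> None:
        self.digests_total: defaultdict[Hashable, float] = defaultdict(float)

    def digest_metric(self, name: Hashable, value: float) -> None:
        # Tuple keys and plain strings are both Hashable
        self.digests_total[name] += value


s = DemoServer()
s.digest_metric(("execute", "x", "thread"), 0.5)  # composite tuple key
s.digest_metric("latency", 0.003)                 # plain strings still work
assert ("execute", "x", "thread") in s.digests_total
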
distributed/protocol/compression.py

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,7 @@

 import dask

+from distributed.span import meter
 from distributed.utils import ensure_memoryview, nbytes, no_default

 compressions: dict[
@@ -150,6 +151,7 @@ def byte_sample(b, size, n):
     return memoryview(b"".join(parts))


+@meter("compress")
 def maybe_compress(
     payload,
     min_size=10_000,
@@ -195,6 +197,7 @@ def maybe_compress(
     return None, payload


+@meter("decompress")
 def decompress(header, frames):
     """Decompress frames according to information in the header"""
     return [

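Note: `meter` is imported from the new `distributed.span` module, whose implementation is not part of this diff. The commit uses it both as a decorator (`@meter("compress")` above) and as a context manager (`with meter("disk-read"):` in distributed/spill.py below), so it plausibly builds on `contextlib.ContextDecorator`. A minimal sketch under that assumption, with a print standing in for the real span bookkeeping:

import time
from contextlib import ContextDecorator


class meter(ContextDecorator):
    """Illustrative stand-in: time a block or function under a label."""

    def __init__(self, label: str):
        self.label = label
        self._t0 = 0.0

    def __enter__(self) -> "meter":
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, *exc) -> None:
        elapsed = time.perf_counter() - self._t0
        # The real implementation presumably records `elapsed` on the
        # currently active Span; printing stands in for that here.
        print(f"{self.label}: {elapsed:.6f}s")

Subclassing ContextDecorator yields the `with meter(...)` and `@meter(...)` forms from a single class, which matches how this commit calls it.
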
distributed/protocol/serialize.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@
     pack_frames_prelude,
     unpack_frames,
 )
+from distributed.span import meter
 from distributed.utils import ensure_memoryview, has_keyword

 dask_serialize = dask.utils.Dispatch("dask_serialize")
@@ -426,6 +427,7 @@ def deserialize(header, frames, deserializers=None):
     return loads(header, frames)


+@meter("serialize")
 def serialize_and_split(
     x, serializers=None, on_error="message", context=None, size=None
 ):
@@ -470,6 +472,7 @@ def serialize_and_split(
     return header, out_frames


+@meter("deserialize")
 def merge_and_deserialize(header, frames, deserializers=None):
     """Merge and deserialize frames

distributed/span.py

Lines changed: 17 additions & 4 deletions
@@ -47,6 +47,14 @@ def own_time(self) -> float:
         assert self.stop_time is not None
         return self.total_time - self.other_time

+    @property
+    def done(self) -> bool:
+        return self.start_time is not None and self.stop_time is not None
+
+    @property
+    def running(self) -> bool:
+        return self.start_time is not None and self.stop_time is None
+
     def start(self) -> None:
         assert self.start_time is None
         assert self.stop_time is None
@@ -83,14 +91,19 @@ def _unset_current(self) -> None:
         _current_span.reset(self._token)
         self._token = None

-    def _subspan(self, label: str | tuple[str, ...]) -> Span:
+    def _subspan(
+        self,
+        label: str | tuple[str, ...],
+        metric: Callable[[], float] = time.perf_counter,
+    ) -> Span:
         assert (
             self.start_time is not None
         ), "Cannot create sub-span for a span that has not started"
         assert (
             self.stop_time is None
         ), "Cannot create sub-span for a span that has already stopped"
-        span = Span(label, self.metric)
+        # TODO allow different metrics, or always use `self.metric`?
+        span = Span(label, metric)
         self.subspans.append(span)
         return span

@@ -130,7 +143,7 @@ def __repr__(self) -> str:
         return (
             f"{type(self).__name__}<"
             f"{self.label!r}, "
-            f"total_time={self.total_time}, "
+            f"total_time={self.total_time if self.done else '...'}, "
             f"start_time={self.start_time}, "
             f"stop_time={self.stop_time}, "
             ">"
@@ -155,7 +168,7 @@ def get_span(
     except LookupError:
         span = Span(label, metric)
     else:
-        assert metric is parent.metric, (metric, parent.metric)
+        # assert metric is parent.metric, (metric, parent.metric)
         span = parent._subspan(label)

     return span

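Note: the new `done` and `running` properties split a span's lifecycle into three states: never started, started but open, and stopped. A quick illustration (this assumes the `Span(label, metric)` constructor seen above and a `stop()` counterpart to `start()`, which this diff does not show):

import time

from distributed.span import Span

span = Span("example", time.perf_counter)
assert not span.running and not span.done  # created, never started
span.start()
assert span.running and not span.done      # open; repr prints total_time=...
span.stop()
assert span.done and not span.running      # stopped; total_time is valid

The `__repr__` change relies on the same distinction: `total_time` is only read once the span is `done`, so printing a still-running span presumably no longer trips the internal assertion that `stop_time` is set.
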
distributed/spill.py

Lines changed: 8 additions & 4 deletions
@@ -14,6 +14,7 @@

 from distributed.protocol import deserialize_bytes, serialize_bytelist
 from distributed.sizeof import safe_sizeof
+from distributed.span import meter
 from distributed.utils import RateLimiterFilter

 logger = logging.getLogger(__name__)
@@ -239,8 +240,9 @@ def __getitem__(self, key: str) -> Any:
         if key in self.fast:
             # Note: don't log from self.fast.__getitem__, because that's called every
             # time a key is evicted, and we don't want to count those events here.
-            nbytes = cast(int, self.fast.weights[key])
-            self.fast_metrics.log_read(nbytes)
+            with meter("memory-read"):
+                nbytes = cast(int, self.fast.weights[key])
+                self.fast_metrics.log_read(nbytes)

         return super().__getitem__(key)

@@ -362,7 +364,8 @@ def __init__(self, spill_directory: str, max_weight: int | Literal[False] = Fals

     def __getitem__(self, key: str) -> Any:
         t0 = perf_counter()
-        pickled = self.d[key]
+        with meter("disk-read"):
+            pickled = self.d[key]
         assert isinstance(pickled, bytearray if has_zict_230 else bytes)
         t1 = perf_counter()
         out = self.load(pickled)
@@ -408,7 +411,8 @@ def __setitem__(self, key: str, value: Any) -> None:

         # Store to disk through File.
         # This may raise OSError, which is caught by SpillBuffer above.
-        self.d[key] = pickled
+        with meter("disk-write"):
+            self.d[key] = pickled

         t2 = perf_counter()

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+from collections.abc import Hashable
+
+from distributed import Event, Worker
+from distributed.utils_test import async_wait_for, gen_cluster, inc, wait_for_state
+
+
+def get_digests(w: Worker, allow: str | None = None) -> dict[Hashable, float]:
+    # import pprint; pprint.pprint(dict(w.digests_total))
+    digests = {
+        k: v
+        for k, v in w.digests_total.items()
+        if k
+        not in {
+            "latency",
+            "tick-duration",
+            "transfer-bandwidth",
+            "transfer-duration",
+            "compute-duration",
+        }
+        and (allow is None or allow in k)
+    }
+    assert all(v >= 0 for v in digests.values()), digests
+    return digests
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_basic_execute(c, s, a):
+    await c.submit(inc, 1, key="x")
+    assert list(get_digests(a)) == [
+        ("execute", "x", "thread", "thread-cpu"),
+        ("execute", "x", "thread"),
+        ("execute", "x"),
+    ]
+
+
+# @gen_cluster(client=True, nthreads=[("", 1)])
+# async def test_run_spec_deserialization(c, s, a):
+#     """Test that deserialization of run_spec is metered"""
+#     await c.submit(inc, 1, key="x")
+#     assert 0 < a.digests_total["execute", "x", "deserialize", "seconds"] < 1
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_cancelled_execute(c, s, a):
+    """cancelled(execute) tasks are metered as a separate lump total"""
+    ev = await Event()
+    x = c.submit(lambda ev: ev.wait(), ev, key="x")
+    await wait_for_state("x", "executing", a)
+    del x
+    await wait_for_state("x", "cancelled", a)
+    await ev.set()
+    await async_wait_for(lambda: not a.state.tasks, timeout=5)
+
+    print(list(get_digests(a)))
+    assert list(get_digests(a)) == [("execute", "x", "cancelled")]

distributed/utils.py

Lines changed: 24 additions & 4 deletions
@@ -21,7 +21,11 @@
 from asyncio import TimeoutError
 from collections import deque
 from collections.abc import Callable, Collection, Container, KeysView, ValuesView
-from concurrent.futures import CancelledError, ThreadPoolExecutor  # noqa: F401
+from concurrent.futures import (  # noqa: F401
+    CancelledError,
+    Executor,
+    ThreadPoolExecutor,
+)
 from contextlib import contextmanager, suppress
 from contextvars import ContextVar
 from datetime import timedelta
@@ -1410,12 +1414,28 @@ def import_term(name: str) -> AnyType:
     return getattr(module, attr_name)


-async def offload(fn, *args, **kwargs):
+async def offload(  # type: ignore[valid-type]
+    fn: Callable[P, T],  # FIXME improper use of `ParamSpec`?
+    *args: P.args,
+    executor: Executor | None = None,
+    **kwargs: P.kwargs,
+) -> T:
+    """Wrapper around :meth:`~asyncio.AbstractEventLoop.run_in_executor`, which
+    propagates contextvars.
+    By default, it offloads to a thread pool with a single worker.
+    See also
+    --------
+    https://bugs.python.org/issue34014
+    """
+    if executor is None:
+        # Not the same as defaulting to _offload_executor in the parameters, as this
+        # allows monkey-patching the _offload_executor during unit tests
+        executor = _offload_executor
+
     loop = asyncio.get_running_loop()
-    # Retain context vars while deserializing; see https://bugs.python.org/issue34014
     context = contextvars.copy_context()
     return await loop.run_in_executor(
-        _offload_executor, lambda: context.run(fn, *args, **kwargs)
+        executor, lambda: context.run(fn, *args, **kwargs)
     )

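Note: the `offload` signature change is backwards-compatible. Existing call sites keep the module-level single-worker pool, while callers and tests can now inject an executor explicitly (the inline comment explains why the fallback happens inside the body rather than as a parameter default). A usage sketch:

import asyncio
from concurrent.futures import ThreadPoolExecutor

from distributed.utils import offload


async def main() -> None:
    # Default: runs in the module-level single-worker _offload_executor
    assert await offload(sum, [1, 2, 3]) == 6
    # New: run in an explicitly supplied executor instead
    with ThreadPoolExecutor(max_workers=4) as pool:
        assert await offload(pow, 2, 10, executor=pool) == 1024


asyncio.run(main())
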
distributed/utils_test.py

Lines changed: 4 additions & 3 deletions
@@ -63,6 +63,7 @@
 from distributed.protocol import deserialize
 from distributed.scheduler import TaskState as SchedulerTaskState
 from distributed.security import Security
+from distributed.span import Span
 from distributed.utils import (
     DequeHandler,
     _offload_executor,
@@ -82,7 +83,7 @@
     StateMachineEvent,
 )
 from distributed.worker_state_machine import TaskState as WorkerTaskState
-from distributed.worker_state_machine import WorkerState
+from distributed.worker_state_machine import TracedEvent, WorkerState

 try:
     import dask.array  # register config
@@ -2267,11 +2268,11 @@ def __init__(self, *args, **kwargs):

         super().__init__(*args, **kwargs)

-    async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
+    async def execute(self, key: str, *, stimulus_id: str, span: Span) -> TracedEvent:
         self.in_execute.set()
         await self.block_execute.wait()
         try:
-            return await super().execute(key, stimulus_id=stimulus_id)
+            return await super().execute(key, stimulus_id=stimulus_id, span=span)
         finally:
             self.in_execute_exit.set()
             await self.block_execute_exit.wait()
