Skip to content

Commit 212e00a

Browse files
committed
Unique spans
1 parent 669429d commit 212e00a

File tree

5 files changed

+337
-192
lines changed

5 files changed

+337
-192
lines changed

distributed/scheduler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,11 +1063,12 @@ class TaskGroup:
10631063
#: Cumulative duration of all completed actions, by action
10641064
all_durations: defaultdict[str, float]
10651065

1066-
#: Span ID (see distributed.spans).
1066+
#: Span ID (see ``distributed.spans``).
1067+
#: Matches ``distributed.worker_state_machine.TaskState.span_id``.
10671068
#: It is possible to end up in situation where different tasks of the same TaskGroup
10681069
#: belong to different spans; the purpose of this attribute is to arbitrarily force
10691070
#: everything onto the earliest encountered one.
1070-
span: tuple[str, ...]
1071+
span_id: str | None
10711072

10721073
__slots__ = tuple(__annotations__)
10731074

@@ -1084,7 +1085,7 @@ def __init__(self, name: str):
10841085
self.all_durations = defaultdict(float)
10851086
self.last_worker = None
10861087
self.last_worker_tasks_left = 0
1087-
self.span = ()
1088+
self.span_id = None
10881089

10891090
def add_duration(self, action: str, start: float, stop: float) -> None:
10901091
duration = stop - start
@@ -3338,6 +3339,7 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
33383339
"resource_restrictions": ts.resource_restrictions,
33393340
"actor": ts.actor,
33403341
"annotations": ts.annotations,
3342+
"span_id": ts.group.span_id,
33413343
}
33423344
if self.validate:
33433345
assert all(msg["who_has"].values())
@@ -4448,13 +4450,11 @@ def update_graph(
44484450

44494451
spans_ext: SpansExtension | None = self.extensions.get("spans")
44504452
if spans_ext:
4451-
span_annotations = spans_ext.new_tasks(new_tasks)
4452-
if span_annotations:
4453-
resolved_annotations["span"] = span_annotations
4454-
else:
4455-
# Edge case where some tasks define a span, while earlier tasks in the
4456-
# same TaskGroup don't define any
4457-
resolved_annotations.pop("span", None)
4453+
spans_ext.new_tasks(new_tasks)
4454+
# TaskGroup.span_id could be completely different from the one in the
4455+
# original annotations, so it has been dropped. Drop it here as well in
4456+
# order not to confuse SchedulerPlugin authors.
4457+
resolved_annotations.pop("span", None)
44584458

44594459
for plugin in list(self.plugins.values()):
44604460
try:

distributed/spans.py

Lines changed: 163 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3+
import uuid
4+
import weakref
35
from collections import defaultdict
46
from collections.abc import Iterable, Iterator
57
from contextlib import contextmanager
6-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any
79

810
import dask.config
911

@@ -15,18 +17,26 @@
1517

1618

1719
@contextmanager
18-
def span(*tags: str) -> Iterator[None]:
20+
def span(*tags: str) -> Iterator[str]:
1921
"""Tag group of tasks to be part of a certain group, called a span.
2022
21-
This context manager can be nested, thus creating sub-spans.
22-
Every cluster defines a global "default" span when no span has been defined by the client.
23+
This context manager can be nested, thus creating sub-spans. If you close and
24+
re-open a span context manager with the same tag, you'll end up with two separate
25+
spans.
26+
27+
Every cluster defines a global "default" span when no span has been defined by the
28+
client; the default span is automatically closed and reopened when all tasks
29+
associated to it have been completed; in other words the cluster is idle save for
30+
tasks that are explicitly annotated by a span. Note that, in some edge cases, you
31+
may end up with overlapping default spans, e.g. if a worker crashes and all unique
32+
tasks that were in memory on it need to be recomputed.
2333
2434
Examples
2535
--------
2636
>>> import dask.array as da
2737
>>> import distributed
2838
>>> client = distributed.Client()
29-
>>> with span("my_workflow"):
39+
>>> with span("my workflow"):
3040
... with span("phase 1"):
3141
... a = da.random.random(10)
3242
... b = a + 1
@@ -36,39 +46,59 @@ def span(*tags: str) -> Iterator[None]:
3646
>>> d.compute()
3747
3848
In the above example,
39-
- Tasks of collections a and b will be annotated on the scheduler and workers with
40-
``{'span': ('my_workflow', 'phase 1')}``
41-
- Tasks of collection c (that aren't already part of a or b) will be annotated with
42-
``{'span': ('my_workflow', 'phase 2')}``
43-
- Tasks of collection d (that aren't already part of a, b, or c) will *not* be
44-
annotated but will nonetheless be attached to span ``('default', )``
49+
- Tasks of collections a and b are annotated to belong to span
50+
``('my workflow', 'phase 1')``;
51+
- Tasks of collection c (that aren't already part of a or b) are annotated to belong
52+
to span ``('my workflow', 'phase 2')``;
53+
- Tasks of collection d (that aren't already part of a, b, or c) are *not*
54+
annotated but will nonetheless be attached to span ``('default', )``.
4555
4656
You may also set more than one tag at once; e.g.
4757
>>> with span("workflow1", "version1"):
4858
... ...
4959
60+
Finally, you may capture the ID of a span on the client to match it with the
61+
:class:`Span` objects on the scheduler:
62+
>>> cluster = distributed.LocalCluster()
63+
>>> client = distributed.Client(cluster)
64+
>>> with span("my workflow") as span_id:
65+
... client.submit(lambda: "Hello world!").result()
66+
>>> span = client.cluster.scheduler.extensions["spans"].spans[span_id]
5067
51-
Note
52-
----
68+
Notes
69+
-----
5370
Spans are based on annotations, and just like annotations they can be lost during
54-
optimization. Set config ``optimizatione.fuse.active: false`` to prevent this issue.
71+
optimization. Set config ``optimization.fuse.active: false`` to prevent this issue.
5572
"""
56-
prev_id = dask.config.get("annotations.span", ())
57-
with dask.config.set({"annotations.span": prev_id + tags}):
58-
yield
73+
if not tags:
74+
raise ValueError("Must specify at least one span tag")
75+
76+
prev_tags = dask.config.get("annotations.span.name", ())
77+
# You must specify the full history of IDs, not just the parent, because
78+
# otherwise you would not be able to uniquely identify grandparents when
79+
# they have no tasks of their own.
80+
prev_ids = dask.config.get("annotations.span.ids", ())
81+
ids = tuple(str(uuid.uuid4()) for _ in tags)
82+
with dask.annotate(span={"name": prev_tags + tags, "ids": prev_ids + ids}):
83+
yield ids[-1]
5984

6085

6186
class Span:
6287
#: (<tag>, <tag>, ...)
63-
#: Matches ``TaskState.annotations["span"]``, both on the scheduler and the worker,
64-
#: as well as ``TaskGroup.span``.
65-
#: Tasks with no 'span' annotation will be attached to Span ``("default", )``.
66-
id: tuple[str, ...]
88+
#: Matches ``TaskState.annotations["span"]["name"]``, both on the scheduler and the
89+
#: worker.
90+
name: tuple[str, ...]
91+
92+
#: <uuid>
93+
#: Taken from ``TaskState.annotations["span"]["ids"][-1]``.
94+
#: Matches ``distributed.scheduler.TaskState.group.span_id``
95+
#: and ``distributed.worker_state_machine.TaskState.span_id``.
96+
id: str
97+
98+
_parent: weakref.ref[Span] | None
6799

68-
#: Direct children of this span tree
69-
#: Note: you can get the parent through
70-
#: ``distributed.extensions["spans"].spans[self.id[:-1]]``
71-
children: set[Span]
100+
#: Direct children of this span, sorted by creation time
101+
children: list[Span]
72102

73103
#: Task groups *directly* belonging to this span.
74104
#:
@@ -94,25 +124,47 @@ class Span:
94124
#: stop
95125
enqueued: float
96126

127+
# Support for weakrefs to a class with __slots__
128+
__weakref__: Any
129+
97130
__slots__ = tuple(__annotations__)
98131

99-
def __init__(self, span_id: tuple[str, ...], enqueued: float):
100-
self.id = span_id
101-
self.enqueued = enqueued
102-
self.children = set()
132+
def __init__(self, name: tuple[str, ...], id_: str, parent: Span | None):
133+
self.name = name
134+
self.id = id_
135+
self._parent = weakref.ref(parent) if parent is not None else None
136+
self.enqueued = time()
137+
self.children = []
103138
self.groups = set()
104139

105140
def __repr__(self) -> str:
106-
return f"Span{self.id}"
141+
return f"Span<name={self.name}, id={self.id}>"
142+
143+
def __getstate__(self) -> tuple[None, dict]:
144+
"""Break link to parent Span upon pickle"""
145+
return (
146+
None,
147+
{k: getattr(self, k) if k != "_parent" else None for k in self.__slots__},
148+
)
149+
150+
@property
151+
def parent(self) -> Span | None:
152+
if self._parent:
153+
out = self._parent()
154+
assert out
155+
return out
156+
return None
107157

108158
def traverse_spans(self) -> Iterator[Span]:
109-
"""Top-down recursion of all spans belonging to this span tree, including self"""
159+
"""Top-down recursion of all spans belonging to this branch of the span tree,
160+
including self
161+
"""
110162
yield self
111163
for child in self.children:
112164
yield from child.traverse_spans()
113165

114166
def traverse_groups(self) -> Iterator[TaskGroup]:
115-
"""All TaskGroups belonging to this span tree"""
167+
"""All TaskGroups belonging to this branch of span tree"""
116168
for span in self.traverse_spans():
117169
yield from span.groups
118170

@@ -161,10 +213,26 @@ def states(self) -> defaultdict[TaskStateState, int]:
161213
"""
162214
out: defaultdict[TaskStateState, int] = defaultdict(int)
163215
for tg in self.traverse_groups():
164-
for state, cnt in tg.states.items():
165-
out[state] += cnt
216+
for state, count in tg.states.items():
217+
out[state] += count
166218
return out
167219

220+
@property
221+
def done(self) -> bool:
222+
"""Return True if all tasks in this span tree are completed; False otherwise.
223+
224+
Notes
225+
-----
226+
This property may transition from True to False, e.g. when a new sub-span is
227+
added or when a worker that contained the only replica of a task in memory
228+
crashes and the task needs to be recomputed.
229+
230+
See also
231+
--------
232+
distributed.scheduler.TaskGroup.done
233+
"""
234+
return all(tg.done for tg in self.traverse_groups())
235+
168236
@property
169237
def all_durations(self) -> defaultdict[str, float]:
170238
"""Cumulative duration of all completed actions in this span tree, by action
@@ -205,72 +273,92 @@ def nbytes_total(self) -> int:
205273
class SpansExtension:
206274
"""Scheduler extension for spans support"""
207275

208-
#: All Span objects by span_id
209-
spans: dict[tuple[str, ...], Span]
276+
#: All Span objects by id
277+
spans: dict[str, Span]
278+
279+
#: Only the spans that don't have any parents, sorted by creation time.
280+
#: This is a convenience helper structure to speed up searches.
281+
root_spans: list[Span]
210282

211-
#: Only the spans that don't have any parents {client_id: Span}.
283+
#: All spans, keyed by their full name and sorted by creation time.
212284
#: This is a convenience helper structure to speed up searches.
213-
root_spans: dict[str, Span]
285+
spans_search_by_name: defaultdict[tuple[str, ...], list[Span]]
214286

215-
#: All spans, keyed by the individual tags that make up their span_id.
287+
#: All spans, keyed by the individual tags that make up their name and sorted by
288+
#: creation time.
216289
#: This is a convenience helper structure to speed up searches.
217-
spans_search_by_tag: defaultdict[str, set[Span]]
290+
spans_search_by_tag: defaultdict[str, list[Span]]
218291

219292
def __init__(self, scheduler: Scheduler):
220293
self.spans = {}
221-
self.root_spans = {}
222-
self.spans_search_by_tag = defaultdict(set)
294+
self.root_spans = []
295+
self.spans_search_by_name = defaultdict(list)
296+
self.spans_search_by_tag = defaultdict(list)
223297

224-
def new_tasks(self, tss: Iterable[TaskState]) -> dict[str, tuple[str, ...]]:
298+
def new_tasks(self, tss: Iterable[TaskState]) -> None:
225299
"""Acknowledge the creation of new tasks on the scheduler.
226300
Attach tasks to either the desired span or to ("default", ).
227-
Update TaskState.annotations["span"] and TaskGroup.span.
228-
229-
Returns
230-
-------
231-
{task key: span id}, only for tasks that explicitly define a span
301+
Update TaskGroup.span_id and wipe TaskState.annotations["span"].
232302
"""
233-
out = {}
303+
default_span = None
304+
234305
for ts in tss:
235306
# You may have different tasks belonging to the same TaskGroup but to
236307
# different spans. If that happens, arbitrarily force everything onto the
237308
# span of the earliest encountered TaskGroup.
238309
tg = ts.group
239-
if tg.span:
240-
span_id = tg.span
241-
else:
242-
span_id = ts.annotations.get("span", ("default",))
243-
assert isinstance(span_id, tuple)
244-
tg.span = span_id
245-
span = self._ensure_span(span_id)
310+
if not tg.span_id:
311+
ann = ts.annotations.get("span")
312+
if ann:
313+
span = self._ensure_span(ann["name"], ann["ids"])
314+
else:
315+
if not default_span:
316+
default_span = self._ensure_default_span()
317+
span = default_span
318+
319+
tg.span_id = span.id
246320
span.groups.add(tg)
247321

248-
# Override ts.annotations["span"] with span_id from task group
249-
if span_id == ("default",):
250-
ts.annotations.pop("span", None)
251-
else:
252-
ts.annotations["span"] = out[ts.key] = span_id
253-
254-
return out
322+
# The span may be completely different from the one referenced by the
323+
# annotation, due to the TaskGroup collision issue explained above.
324+
# Remove the annotation to avoid confusion, and instead rely on
325+
# distributed.scheduler.TaskState.group.span_id and
326+
# distributed.worker_state_machine.TaskState.span_id.
327+
ts.annotations.pop("span", None)
328+
329+
def _ensure_default_span(self) -> Span:
330+
"""Return the currently active default span, or create one if the previous one
331+
terminated. In other words, do not reuse the previous default span if all tasks
332+
that were not explicitly annotated with :func:`span` on the client side are
333+
finished.
334+
"""
335+
defaults = self.spans_search_by_name["default",]
336+
if defaults and not defaults[-1].done:
337+
return defaults[-1]
338+
return self._ensure_span(("default",), (str(uuid.uuid4()),))
255339

256-
def _ensure_span(self, span_id: tuple[str, ...], enqueued: float = 0.0) -> Span:
340+
def _ensure_span(self, name: tuple[str, ...], ids: tuple[str, ...]) -> Span:
257341
"""Create Span if it doesn't exist and return it"""
258342
try:
259-
return self.spans[span_id]
343+
return self.spans[ids[-1]]
260344
except KeyError:
261345
pass
262346

263-
# When recursively creating parent spans, make sure that parents are not newer
264-
# than the children
265-
enqueued = enqueued or time()
347+
assert len(name) == len(ids)
348+
assert len(name) > 0
349+
350+
parent = None
351+
for i in range(1, len(name)):
352+
parent = self._ensure_span(name[:i], ids[:i])
266353

267-
span = self.spans[span_id] = Span(span_id, enqueued)
268-
for tag in span_id:
269-
self.spans_search_by_tag[tag].add(span)
270-
if len(span_id) > 1:
271-
parent = self._ensure_span(span_id[:-1], enqueued)
272-
parent.children.add(span)
354+
span = Span(name=name, id_=ids[-1], parent=parent)
355+
self.spans[span.id] = span
356+
self.spans_search_by_name[name].append(span)
357+
for tag in name:
358+
self.spans_search_by_tag[tag].append(span)
359+
if parent:
360+
parent.children.append(span)
273361
else:
274-
self.root_spans[span_id[0]] = span
362+
self.root_spans.append(span)
275363

276364
return span

0 commit comments

Comments
 (0)