Skip to content

Commit

Permalink
Unique spans
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 4, 2023
1 parent 669429d commit 8118ad8
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 195 deletions.
20 changes: 10 additions & 10 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,11 +1063,12 @@ class TaskGroup:
#: Cumulative duration of all completed actions, by action
all_durations: defaultdict[str, float]

#: Span ID (see distributed.spans).
#: Span ID (see ``distributed.spans``).
#: Matches ``distributed.worker_state_machine.TaskState.span_id``.
#: It is possible to end up in a situation where different tasks of the same TaskGroup
#: belong to different spans; the purpose of this attribute is to arbitrarily force
#: everything onto the earliest encountered one.
span: tuple[str, ...]
span_id: str | None

__slots__ = tuple(__annotations__)

Expand All @@ -1084,7 +1085,7 @@ def __init__(self, name: str):
self.all_durations = defaultdict(float)
self.last_worker = None
self.last_worker_tasks_left = 0
self.span = ()
self.span_id = None

def add_duration(self, action: str, start: float, stop: float) -> None:
duration = stop - start
Expand Down Expand Up @@ -3338,6 +3339,7 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
"resource_restrictions": ts.resource_restrictions,
"actor": ts.actor,
"annotations": ts.annotations,
"span_id": ts.group.span_id,
}
if self.validate:
assert all(msg["who_has"].values())
Expand Down Expand Up @@ -4448,13 +4450,11 @@ def update_graph(

spans_ext: SpansExtension | None = self.extensions.get("spans")
if spans_ext:
span_annotations = spans_ext.new_tasks(new_tasks)
if span_annotations:
resolved_annotations["span"] = span_annotations
else:
# Edge case where some tasks define a span, while earlier tasks in the
# same TaskGroup don't define any
resolved_annotations.pop("span", None)
spans_ext.new_tasks(new_tasks)
# TaskGroup.span_id could be completely different from the one in the
# original annotations, so it has been dropped. Drop it here as well in
# order not to confuse SchedulerPlugin authors.
resolved_annotations.pop("span", None)

for plugin in list(self.plugins.values()):
try:
Expand Down
231 changes: 156 additions & 75 deletions distributed/spans.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import uuid
import weakref
from collections import defaultdict
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import dask.config

Expand All @@ -15,18 +17,26 @@


@contextmanager
def span(*tags: str) -> Iterator[None]:
def span(*tags: str) -> Iterator[str]:
"""Tag group of tasks to be part of a certain group, called a span.
This context manager can be nested, thus creating sub-spans.
Every cluster defines a global "default" span when no span has been defined by the client.
This context manager can be nested, thus creating sub-spans. If you close and
re-open a span context manager with the same tag, you'll end up with two separate
spans.
Every cluster defines a global "default" span when no span has been defined by the
client; the default span is automatically closed and reopened when all tasks
associated to it have been completed; in other words the cluster is idle save for
tasks that are explicitly annotated by a span. Note that, in some edge cases, you
may end up with overlapping default spans, e.g. if a worker crashes and all unique
tasks that were in memory on it need to be recomputed.
Examples
--------
>>> import dask.array as da
>>> import distributed
>>> client = distributed.Client()
>>> with span("my_workflow"):
>>> with span("my workflow"):
... with span("phase 1"):
... a = da.random.random(10)
... b = a + 1
Expand All @@ -36,39 +46,59 @@ def span(*tags: str) -> Iterator[None]:
>>> d.compute()
In the above example,
- Tasks of collections a and b will be annotated on the scheduler and workers with
``{'span': ('my_workflow', 'phase 1')}``
- Tasks of collection c (that aren't already part of a or b) will be annotated with
``{'span': ('my_workflow', 'phase 2')}``
- Tasks of collection d (that aren't already part of a, b, or c) will *not* be
annotated but will nonetheless be attached to span ``('default', )``
- Tasks of collections a and b are annotated to belong to span
``('my workflow', 'phase 1')``;
- Tasks of collection c (that aren't already part of a or b) are annotated to belong
to span ``('my workflow', 'phase 2')``;
- Tasks of collection d (that aren't already part of a, b, or c) are *not*
annotated but will nonetheless be attached to span ``('default', )``.
You may also set more than one tag at once; e.g.
>>> with span("workflow1", "version1"):
... ...
Finally, you may capture the ID of a span on the client to match it with the
:class:`Span` objects on the scheduler:
>>> cluster = distributed.LocalCluster()
>>> client = distributed.Client(cluster)
>>> with span("my workflow") as span_id:
... client.submit(lambda: "Hello world!").result()
>>> span = client.cluster.scheduler.extensions["spans"].spans[span_id]
Note
----
Notes
-----
Spans are based on annotations, and just like annotations they can be lost during
optimization. Set config ``optimizatione.fuse.active: false`` to prevent this issue.
optimization. Set config ``optimization.fuse.active: false`` to prevent this issue.
"""
prev_id = dask.config.get("annotations.span", ())
with dask.config.set({"annotations.span": prev_id + tags}):
yield
if not tags:
raise ValueError("Must specify at least one span tag")

prev_tags = dask.config.get("annotations.span.name", ())
# You must specify the full history of IDs, not just the parent, because
# otherwise you would not be able to uniquely identify grandparents when
# they have no tasks of their own.
prev_ids = dask.config.get("annotations.span.ids", ())
ids = tuple(str(uuid.uuid4()) for _ in tags)
with dask.annotate(span={"name": prev_tags + tags, "ids": prev_ids + ids}):
yield ids[-1]


class Span:
#: (<tag>, <tag>, ...)
#: Matches ``TaskState.annotations["span"]``, both on the scheduler and the worker,
#: as well as ``TaskGroup.span``.
#: Tasks with no 'span' annotation will be attached to Span ``("default", )``.
id: tuple[str, ...]
#: Matches ``TaskState.annotations["span"]["name"]``, both on the scheduler and the
#: worker.
name: tuple[str, ...]

#: <uuid>
#: Taken from ``TaskState.annotations["span"]["ids"][-1]``.
#: Matches ``distributed.scheduler.TaskState.group.span_id``
#: and ``distributed.worker_state_machine.TaskState.span_id``.
id: str

#: Direct children of this span tree
#: Note: you can get the parent through
#: ``distributed.extensions["spans"].spans[self.id[:-1]]``
children: set[Span]
_parent: weakref.ref[Span] | None

#: Direct children of this span, sorted by creation time
children: list[Span]

#: Task groups *directly* belonging to this span.
#:
Expand All @@ -94,25 +124,40 @@ class Span:
#: stop
enqueued: float

# Support for weakrefs to a class with __slots__
__weakref__: Any

__slots__ = tuple(__annotations__)

def __init__(self, span_id: tuple[str, ...], enqueued: float):
self.id = span_id
self.enqueued = enqueued
self.children = set()
def __init__(self, name: tuple[str, ...], id_: str, parent: Span | None):
self.name = name
self.id = id_
self._parent = weakref.ref(parent) if parent is not None else None
self.enqueued = time()
self.children = []
self.groups = set()

def __repr__(self) -> str:
return f"Span{self.id}"
return f"Span<name={self.name}, id={self.id}>"

@property
def parent(self) -> Span | None:
if self._parent:
out = self._parent()
assert out
return out
return None

def traverse_spans(self) -> Iterator[Span]:
"""Top-down recursion of all spans belonging to this span tree, including self"""
"""Top-down recursion of all spans belonging to this branch off span tree,
including self
"""
yield self
for child in self.children:
yield from child.traverse_spans()

def traverse_groups(self) -> Iterator[TaskGroup]:
"""All TaskGroups belonging to this span tree"""
"""All TaskGroups belonging to this branch of span tree"""
for span in self.traverse_spans():
yield from span.groups

Expand Down Expand Up @@ -161,10 +206,26 @@ def states(self) -> defaultdict[TaskStateState, int]:
"""
out: defaultdict[TaskStateState, int] = defaultdict(int)
for tg in self.traverse_groups():
for state, cnt in tg.states.items():
out[state] += cnt
for state, count in tg.states.items():
out[state] += count
return out

@property
def done(self) -> bool:
"""Return True if all tasks in this span tree are completed; False otherwise.
Notes
-----
This property may transition from True to False, e.g. when a new sub-span is
added or when a worker that contained the only replica of a task in memory
crashes and the task needs to be recomputed.
See also
--------
distributed.scheduler.TaskGroup.done
"""
return all(tg.done for tg in self.traverse_groups())

@property
def all_durations(self) -> defaultdict[str, float]:
"""Cumulative duration of all completed actions in this span tree, by action
Expand Down Expand Up @@ -205,72 +266,92 @@ def nbytes_total(self) -> int:
class SpansExtension:
"""Scheduler extension for spans support"""

#: All Span objects by span_id
spans: dict[tuple[str, ...], Span]
#: All Span objects by id
spans: dict[str, Span]

#: Only the spans that don't have any parents {client_id: Span}.
#: Only the spans that don't have any parents, sorted by creation time.
#: This is a convenience helper structure to speed up searches.
root_spans: dict[str, Span]
root_spans: list[Span]

#: All spans, keyed by the individual tags that make up their span_id.
#: All spans, keyed by their full name and sorted by creation time.
#: This is a convenience helper structure to speed up searches.
spans_search_by_tag: defaultdict[str, set[Span]]
spans_search_by_name: defaultdict[tuple[str, ...], list[Span]]

#: All spans, keyed by the individual tags that make up their name and sorted by
#: creation time.
#: This is a convenience helper structure to speed up searches.
spans_search_by_tag: defaultdict[str, list[Span]]

def __init__(self, scheduler: Scheduler):
self.spans = {}
self.root_spans = {}
self.spans_search_by_tag = defaultdict(set)
self.root_spans = []
self.spans_search_by_name = defaultdict(list)
self.spans_search_by_tag = defaultdict(list)

def new_tasks(self, tss: Iterable[TaskState]) -> dict[str, tuple[str, ...]]:
def new_tasks(self, tss: Iterable[TaskState]) -> None:
"""Acknowledge the creation of new tasks on the scheduler.
Attach tasks to either the desired span or to ("default", ).
Update TaskState.annotations["span"] and TaskGroup.span.
Returns
-------
{task key: span id}, only for tasks that explicitly define a span
Update TaskGroup.span_id and wipe TaskState.annotations["span"].
"""
out = {}
default_span = None

for ts in tss:
# You may have different tasks belonging to the same TaskGroup but to
# different spans. If that happens, arbitrarily force everything onto the
# span of the earliest encountered TaskGroup.
tg = ts.group
if tg.span:
span_id = tg.span
else:
span_id = ts.annotations.get("span", ("default",))
assert isinstance(span_id, tuple)
tg.span = span_id
span = self._ensure_span(span_id)
if not tg.span_id:
ann = ts.annotations.get("span")
if ann:
span = self._ensure_span(ann["name"], ann["ids"])
else:
if not default_span:
default_span = self._ensure_default_span()
span = default_span

tg.span_id = span.id
span.groups.add(tg)

# Override ts.annotations["span"] with span_id from task group
if span_id == ("default",):
ts.annotations.pop("span", None)
else:
ts.annotations["span"] = out[ts.key] = span_id

return out
# The span may be completely different from the one referenced by the
# annotation, due to the TaskGroup collision issue explained above.
# Remove the annotation to avoid confusion, and instead rely on
# distributed.scheduler.TaskState.group.span_id and
# distributed.worker_state_machine.TaskState.span_id.
ts.annotations.pop("span", None)

def _ensure_default_span(self) -> Span:
"""Return the currently active default span, or create one if the previous one
terminated. In other words, do not reuse the previous default span if all tasks
that were not explicitly annotated with :func:`span` on the client side are
finished.
"""
defaults = self.spans_search_by_name["default",]
if defaults and not defaults[-1].done:
return defaults[-1]
return self._ensure_span(("default",), (str(uuid.uuid4()),))

def _ensure_span(self, span_id: tuple[str, ...], enqueued: float = 0.0) -> Span:
def _ensure_span(self, name: tuple[str, ...], ids: tuple[str, ...]) -> Span:
"""Create Span if it doesn't exist and return it"""
try:
return self.spans[span_id]
return self.spans[ids[-1]]
except KeyError:
pass

# When recursively creating parent spans, make sure that parents are not newer
# than the children
enqueued = enqueued or time()
assert len(name) == len(ids)
assert len(name) > 0

parent = None
for i in range(1, len(name)):
parent = self._ensure_span(name[:i], ids[:i])

span = self.spans[span_id] = Span(span_id, enqueued)
for tag in span_id:
self.spans_search_by_tag[tag].add(span)
if len(span_id) > 1:
parent = self._ensure_span(span_id[:-1], enqueued)
parent.children.add(span)
span = Span(name=name, id_=ids[-1], parent=parent)
self.spans[span.id] = span
self.spans_search_by_name[name].append(span)
for tag in name:
self.spans_search_by_tag[tag].append(span)
if parent:
parent.children.append(span)
else:
self.root_spans[span_id[0]] = span
self.root_spans.append(span)

return span
Loading

0 comments on commit 8118ad8

Please sign in to comment.