Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions distributed/http/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ async def test_prometheus_api_doc(c, s, a):
"""
pytest.importorskip("prometheus_client")

# Some metrics only appear after a task is executed
await c.submit(inc, 1)
# Some metrics only appear if there are tasks on the cluster
fut = c.submit(inc, 1)
await fut
# Semaphore metrics only appear after semaphores are used
sem = await Semaphore()
await sem.acquire()
Expand Down
12 changes: 7 additions & 5 deletions distributed/http/worker/prometheus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ def collect(self) -> Iterator[Metric]:
"Number of tasks at worker.",
labels=["state"],
)
for k, n in ws.task_counts.items():
if k == "memory" and hasattr(self.server.data, "slow"):
for state, n in ws.task_counter.current_count(by_prefix=False).items():
if state == "memory" and hasattr(self.server.data, "slow"):
n_spilled = len(self.server.data.slow)
tasks.add_metric(["memory"], n - n_spilled)
tasks.add_metric(["disk"], n_spilled)
if n - n_spilled > 0:
tasks.add_metric(["memory"], n - n_spilled)
if n_spilled > 0:
tasks.add_metric(["disk"], n_spilled)
else:
tasks.add_metric([k], n)
tasks.add_metric([state], n)
yield tasks

yield GaugeMetricFamily(
Expand Down
46 changes: 20 additions & 26 deletions distributed/http/worker/tests/test_worker_http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import json
from unittest import mock

Expand All @@ -10,6 +9,7 @@
from distributed import Event, Worker, wait
from distributed.sizeof import sizeof
from distributed.utils_test import (
async_wait_for,
fetch_metrics,
fetch_metrics_body,
fetch_metrics_sample_names,
Expand All @@ -21,6 +21,10 @@
async def test_prometheus(c, s, a):
pytest.importorskip("prometheus_client")

# We need *some* tasks or dask_worker_tasks won't appear
fut = c.submit(lambda: 1)
await wait(fut)

active_metrics = await fetch_metrics_sample_names(
a.http_server.port, prefix="dask_worker_"
)
Expand Down Expand Up @@ -88,22 +92,7 @@ async def test_metrics_when_prometheus_client_not_installed(
async def test_prometheus_collect_task_states(c, s, a):
pytest.importorskip("prometheus_client")

async def assert_metrics(**kwargs):
expect = {
"constrained": 0,
"executing": 0,
"fetch": 0,
"flight": 0,
"long-running": 0,
"memory": 0,
"disk": 0,
"missing": 0,
"other": 0,
"ready": 0,
"waiting": 0,
}
expect.update(kwargs)

async def assert_metrics(**expect):
families = await fetch_metrics(a.http_server.port, prefix="dask_worker_")
actual = {
sample.labels["state"]: sample.value
Expand All @@ -117,24 +106,29 @@ async def assert_metrics(**kwargs):
ev = Event()

# submit a task which should show up in the prometheus scraping
future = c.submit(ev.wait)
while not a.state.executing:
await asyncio.sleep(0.001)
fut1 = c.submit(ev.wait)
await async_wait_for(lambda: a.state.executing, timeout=5)

await assert_metrics(executing=1)

await ev.set()
await c.gather(future)
await wait(fut1)

await assert_metrics(memory=1)
a.data.evict()
await assert_metrics(disk=1)

future.release()
fut2 = c.submit(lambda: 1)
await wait(fut2)
await assert_metrics(memory=2)

a.data.evict()
await assert_metrics(memory=1, disk=1)
a.data.evict()
await assert_metrics(disk=2)

while future.key in a.state.tasks:
await asyncio.sleep(0.001)
fut1.release()
fut2.release()

await async_wait_for(lambda: not a.state.tasks, timeout=5)
await assert_metrics()


Expand Down
9 changes: 2 additions & 7 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3817,14 +3817,9 @@ async def test_transition_counter_max_worker(c, s, a):
# This is set by @gen_cluster; it's False in production
assert s.transition_counter_max > 0
a.state.transition_counter_max = 1
fut = c.submit(inc, 2)
with captured_logger("distributed.worker") as logger:
fut = c.submit(inc, 2)
while True:
try:
a.validate_state()
except AssertionError:
break
await asyncio.sleep(0.01)
await async_wait_for(lambda: a.state.transition_counter > 0, timeout=5)

assert "TransitionCounterMaxExceeded" in logger.getvalue()
# Worker state is corrupted. Avoid test failure on gen_cluster teardown.
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3428,6 +3428,8 @@ async def test_Worker__to_dict(c, s, a):
"transition_counter",
"tasks",
"data_needed",
"task_counts",
"task_cumulative_elapsed",
}
assert d["tasks"]["x"]["key"] == "x"
assert d["data"] == {"x": None}
Expand Down
121 changes: 75 additions & 46 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pickle
from collections import defaultdict
from collections.abc import Iterator
from time import sleep

import pytest
from tlz import first
Expand All @@ -13,13 +14,15 @@

import distributed.profile as profile
from distributed import Nanny, Worker, wait
from distributed.compatibility import MACOS, WINDOWS
from distributed.protocol.serialize import Serialize
from distributed.scheduler import TaskState as SchedulerTaskState
from distributed.utils import recursive_to_dict
from distributed.utils_test import (
NO_AMM,
_LockedCommPool,
assert_story,
async_wait_for,
freeze_data_fetching,
gen_cluster,
inc,
Expand Down Expand Up @@ -205,8 +208,18 @@ def test_WorkerState__to_dict(ws):
"state": "memory",
},
},
"task_counts": {"['x', 'flight']": 1, "['y', 'memory']": 1},
"task_cumulative_elapsed": {
"['x', 'flight']": "SNIP",
"['y', 'memory']": "SNIP",
},
"transition_counter": 3,
}

# timings data (a few microseconds each)
for k in actual["task_cumulative_elapsed"]:
actual["task_cumulative_elapsed"][k] = "SNIP"

assert actual == expect


Expand Down Expand Up @@ -649,16 +662,16 @@ async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):

assert len(s.tasks["x"].who_has) == 2
await w2.close()
while len(s.tasks["x"].who_has) > 1:
await asyncio.sleep(0.01)
await async_wait_for(lambda: len(s.tasks["x"].who_has) == 1, timeout=5)

if as_deps:
y2 = c.submit(inc, x, key="y2", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

while w1.state.tasks["x"].who_has != {w3.address}:
await asyncio.sleep(0.01)
await async_wait_for(
lambda: w1.state.tasks["x"].who_has == {w3.address}, timeout=5
)

await wait_for_state("x", "memory", w1)
assert_story(
Expand Down Expand Up @@ -1209,7 +1222,7 @@ def test_resumed_task_releases_resources(ws_with_running_task, done_ev_cls):
assert ws.available_resources == {"R": 0}
ws2 = "127.0.0.1:2"

ws.handle_stimulus(FreeKeysEvent("cancel", ["x"]))
ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="cancel"))
assert ws.tasks["x"].state == "cancelled"
assert ws.available_resources == {"R": 0}

Expand Down Expand Up @@ -1681,47 +1694,63 @@ def test_fetch_count(ws):
assert len(ws.missing_dep_flight) == 1


def test_task_counts(ws):
assert ws.task_counts == {
"constrained": 0,
"executing": 0,
"fetch": 0,
"flight": 0,
"long-running": 0,
"memory": 0,
"missing": 0,
"other": 0,
"ready": 0,
"waiting": 0,
}

def test_task_counter(ws):
ws2 = "127.0.0.1:2"
ws3 = "127.0.0.1:3"
for by_prefix in (False, True):
assert ws.task_counter.current_count(by_prefix=by_prefix) == {}
assert ws.task_counter.cumulative_elapsed(by_prefix=by_prefix) == {}

def test_task_counts_with_actors(ws):
ws.handle_stimulus(ComputeTaskEvent.dummy("x", actor=True, stimulus_id="s1"))
assert ws.actors == {"x": None}
assert ws.task_counts == {
"constrained": 0,
"executing": 1,
"fetch": 0,
"flight": 0,
"long-running": 0,
"memory": 0,
"missing": 0,
"other": 0,
"ready": 0,
"waiting": 0,
ws.handle_stimulus(
ComputeTaskEvent.dummy(
"('y-123', 7)", who_has={"('x-456', 8)": [ws2]}, stimulus_id="s1"
),
AcquireReplicasEvent(
who_has={"('x-789', 0)": [ws3], "z": [ws3]},
nbytes={"('x-789', 0)": 1, "z": 1},
stimulus_id="s2",
),
)
assert ws.task_counter.current_count() == {
("x", "flight"): 2,
("y", "waiting"): 1,
("z", "flight"): 1,
}
ws.handle_stimulus(ExecuteSuccessEvent.dummy("x", value=123, stimulus_id="s2"))
assert ws.actors == {"x": 123}
assert ws.task_counts == {
"constrained": 0,
"executing": 0,
"fetch": 0,
"flight": 0,
"long-running": 0,
"memory": 1,
"missing": 0,
"other": 0,
"ready": 0,
"waiting": 0,
assert ws.task_counter.current_count(by_prefix=False) == {"waiting": 1, "flight": 3}

def assert_time(actual, expect):
# timer accuracy in Windows can be very poor;
# see awful hack in distributed.metrics
margin_lo = 0.099 if WINDOWS else 0
# sleep() has been observed to have up to 450ms lag on MacOSX GitHub CI
margin_hi = 0.6 if MACOS else 0.1
assert expect - margin_lo <= actual < expect + margin_hi

sleep(0.1)
elapsed = ws.task_counter.cumulative_elapsed()
# Transitory states are not recorded
assert len(elapsed) == 3
assert_time(elapsed["x", "flight"], 0.2)
assert_time(elapsed["y", "waiting"], 0.1)
assert_time(elapsed["z", "flight"], 0.1)

elapsed = ws.task_counter.cumulative_elapsed(by_prefix=False)
assert len(elapsed) == 2
assert_time(elapsed["flight"], 0.3)
assert_time(elapsed["waiting"], 0.1)

# Forgotten keys disappear from current_count() and stop accruing time in
# cumulative_elapsed()
ws.handle_stimulus(FreeKeysEvent(keys=["('y-123', 7)"], stimulus_id="s3"))
assert ws.task_counter.current_count() == {
("x", "cancelled"): 1,
("x", "flight"): 1,
("z", "flight"): 1,
}
sleep(0.15)
elapsed = ws.task_counter.cumulative_elapsed()
assert len(elapsed) == 4
assert_time(elapsed["x", "flight"], 0.35)
assert_time(elapsed["x", "cancelled"], 0.15)
assert_time(elapsed["y", "waiting"], 0.1)
assert_time(elapsed["z", "flight"], 0.25)
2 changes: 1 addition & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,7 +1012,7 @@ async def get_metrics(self) -> dict:
spilled_memory, spilled_disk = 0, 0

out = dict(
task_counts=self.state.task_counts,
task_counts=self.state.task_counter.current_count(by_prefix=False),
bandwidth={
"total": self.bandwidth,
"workers": dict(self.bandwidth_workers),
Expand Down
Loading