Skip to content

Commit ad45e4a

Browse files
committed
Do not reuse offload() for wrapping generic executor
1 parent bde5f9a commit ad45e4a

File tree

5 files changed

+71
-34
lines changed

5 files changed

+71
-34
lines changed

distributed/tests/test_utils.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
parse_ports,
5050
read_block,
5151
recursive_to_dict,
52+
run_in_executor_with_context,
5253
seek_delimiter,
5354
set_thread_state,
5455
sync,
@@ -663,6 +664,37 @@ def test_parse_ports():
663664
parse_ports("100.5")
664665

665666

667+
@gen_test()
668+
async def test_run_in_executor_with_context():
669+
class MyExecutor(Executor):
670+
call_count = 0
671+
672+
def submit(self, __fn, *args, **kwargs):
673+
self.call_count += 1
674+
f = Future()
675+
f.set_result(__fn(*args, **kwargs))
676+
return f
677+
678+
ex = MyExecutor()
679+
out = await run_in_executor_with_context(ex, inc, 1)
680+
assert out == 2
681+
assert ex.call_count == 1
682+
683+
684+
@gen_test()
685+
async def test_run_in_executor_with_context_preserves_contextvars():
686+
var = contextvars.ContextVar("var")
687+
688+
with ThreadPoolExecutor(2) as ex:
689+
690+
async def set_var(v: str) -> None:
691+
var.set(v)
692+
r = await run_in_executor_with_context(ex, var.get)
693+
assert r == v
694+
695+
await asyncio.gather(set_var("foo"), set_var("bar"))
696+
697+
666698
@gen_test()
667699
async def test_offload():
668700
assert (await offload(inc, 1)) == 2
@@ -681,23 +713,6 @@ async def set_var(v: str) -> None:
681713
await asyncio.gather(set_var("foo"), set_var("bar"))
682714

683715

684-
@gen_test()
685-
async def test_offload_custom_executor():
686-
class MyExecutor(Executor):
687-
call_count = 0
688-
689-
def submit(self, __fn, *args, **kwargs):
690-
self.call_count += 1
691-
f = Future()
692-
f.set_result(__fn(*args, **kwargs))
693-
return f
694-
695-
ex = MyExecutor()
696-
out = await offload(inc, 1, executor=ex)
697-
assert out == 2
698-
assert ex.call_count == 1
699-
700-
701716
def test_serialize_for_cli_deprecated():
702717
with pytest.warns(FutureWarning, match="serialize_for_cli is deprecated"):
703718
from distributed.utils import serialize_for_cli

distributed/utils.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,32 +1421,49 @@ def import_term(name: str) -> AnyType:
14211421
return getattr(module, attr_name)
14221422

14231423

1424-
async def offload( # type: ignore[valid-type]
1425-
fn: Callable[P, T],
1424+
async def run_in_executor_with_context(
1425+
executor: Executor | None,
1426+
func: Callable[P, T],
1427+
/,
14261428
*args: P.args,
1427-
executor: Executor | None = None,
14281429
**kwargs: P.kwargs,
14291430
) -> T:
14301431
"""Variant of :meth:`~asyncio.AbstractEventLoop.run_in_executor`, which
14311432
propagates contextvars.
1432-
By default, it offloads to an ad-hoc thread pool with a single worker.
1433+
Note that this limits the type of Executor to those that do not pickle objects.
14331434
14341435
See also
14351436
--------
1437+
asyncio.AbstractEventLoop.run_in_executor
1438+
offload
14361439
https://bugs.python.org/issue34014
14371440
"""
1438-
if executor is None:
1439-
# Not the same as defaulting to _offload_executor in the parameters, as this
1440-
# allows monkey-patching the _offload_executor during unit tests
1441-
executor = _offload_executor
1442-
14431441
loop = asyncio.get_running_loop()
14441442
context = contextvars.copy_context()
14451443
return await loop.run_in_executor(
1446-
executor, lambda: context.run(fn, *args, **kwargs)
1444+
executor, lambda: context.run(func, *args, **kwargs)
14471445
)
14481446

14491447

1448+
def offload(
1449+
func: Callable[P, T],
1450+
/,
1451+
*args: P.args,
1452+
**kwargs: P.kwargs,
1453+
) -> Awaitable[T]:
1454+
"""Run a synchronous function in a separate thread.
1455+
Unlike :meth:`asyncio.to_thread`, this propagates contextvars and offloads to an
1456+
ad-hoc thread pool with a single worker.
1457+
1458+
See also
1459+
--------
1460+
asyncio.to_thread
1461+
run_in_executor_with_context
1462+
https://bugs.python.org/issue34014
1463+
"""
1464+
return run_in_executor_with_context(_offload_executor, func, *args, **kwargs)
1465+
1466+
14501467
class EmptyContext:
14511468
def __enter__(self):
14521469
pass

distributed/worker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
offload,
106106
parse_ports,
107107
recursive_to_dict,
108+
run_in_executor_with_context,
108109
silence_logging,
109110
thread_state,
110111
wait_for,
@@ -2275,7 +2276,8 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
22752276
self.scheduler_delay,
22762277
)
22772278
elif "ThreadPoolExecutor" in str(type(e)):
2278-
result = await offload(
2279+
result = await run_in_executor_with_context(
2280+
e,
22792281
apply_function,
22802282
function,
22812283
args2,
@@ -2285,7 +2287,6 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
22852287
self.active_threads,
22862288
self.active_threads_lock,
22872289
self.scheduler_delay,
2288-
executor=e,
22892290
)
22902291
else:
22912292
# Can't capture contextvars across processes

distributed/worker_memory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,15 +260,19 @@ def metrics_callback(label: Hashable, value: float, unit: str) -> None:
260260
label = (label,)
261261
worker.digest_metric(("memory-monitor", *label, unit), value)
262262

263-
# Work around bug with Tornado PeriodicCallback, which does not properly
264-
# insulate contextvars
263+
# Work around bug with Tornado 6.2 PeriodicCallback, which does not properly
264+
# insulate contextvars. Without this hack, you would see metrics that are
265+
# clearly emitted by Worker.execute labelled with 'memory-monitor'.So we're
266+
# wrapping our change in contextvars (inside add_callback) inside create_task(),
267+
# which copies and insulates the context.
265268
async def _() -> None:
266269
with context_meter.add_callback(metrics_callback):
267270
# Measure delta between the measures from the SpillBuffer and the total
268271
# end-to-end duration of _spill
269272
await self._spill(worker, memory)
270273

271274
await asyncio.create_task(_(), name="memory-monitor-spill")
275+
# End work around
272276

273277
async def _spill(self, worker: Worker, memory: int) -> None:
274278
"""Evict keys until the process memory goes below the ``target`` threshold"""

distributed/worker_state_machine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3123,9 +3123,9 @@ def _execute_done_common(
31233123
def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs:
31243124
"""Task completed successfully"""
31253125
ts, recs, instr = self._execute_done_common(ev)
3126-
# This is used for scheduler-side heuristics such as work stealing; it's
3127-
# important that it does not contain overhead from the thread pool or the
3128-
# worker's event loop (which are not the task's fault and are unpredictable).
3126+
# This is used for scheduler-side occupancy heuristics; it's important that it
3127+
# does not contain overhead from the thread pool or the worker's event loop
3128+
# (which are not the task's fault and are unpredictable).
31293129
ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop})
31303130
ts.nbytes = ev.nbytes
31313131
ts.type = ev.type

0 commit comments

Comments
 (0)