Skip to content

Commit

Permalink
Fix scheduling of async tasks in async flows
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb committed Feb 25, 2024
1 parent e7139b4 commit 25f12bc
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 23 deletions.
23 changes: 16 additions & 7 deletions src/prefect/_internal/concurrency/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ def set_runner(self, portal: "Portal") -> None:

self.runner = portal

async def arun(self) -> Optional[T]:
return self.run()

def run(self) -> Optional[Awaitable[T]]:
"""
Execute the call and place the result on the future.
Expand All @@ -272,7 +275,7 @@ def run(self) -> Optional[Awaitable[T]]:
logger.debug("Skipping execution of cancelled call %r", self)
return None

logger.debug(
logger.info(
"Running call %r in thread %r%s",
self,
threading.current_thread().name,
Expand Down Expand Up @@ -371,14 +374,14 @@ def _run_sync(self):
else:
raise
except BaseException as exc:
logger.debug("Encountered exception in call %r", self, exc_info=True)
logger.info("Encountered exception in call %r", self, exc_info=True)
self.future.set_exception(exc)

# Prevent reference cycle in `exc`
del self
else:
self.future.set_result(result) # noqa: F821
logger.debug("Finished call %r", self) # noqa: F821
logger.info("Finished call %r", self) # noqa: F821

async def _run_async(self, coro):
from prefect._internal.concurrency.threads import in_global_loop
Expand Down Expand Up @@ -411,14 +414,14 @@ async def _run_async(self, coro):
else:
raise
except BaseException as exc:
logger.debug("Encountered exception in async call %r", self, exc_info=True)
logger.info("Encountered exception in async call %r", self, exc_info=True)

self.future.set_exception(exc)
# Prevent reference cycle in `exc`
del self
else:
self.future.set_result(result) # noqa: F821
logger.debug("Finished async call %r", self) # noqa: F821
logger.info("Finished async call %r", self) # noqa: F821

def __call__(self) -> T:
"""
Expand All @@ -442,13 +445,19 @@ async def run_and_return_result():
def __repr__(self) -> str:
name = getattr(self.fn, "__name__", str(self.fn))

def fmt(val):
if isinstance(val, float):
return f"{val:.2f}"
else:
return repr(val)

args, kwargs = self.args, self.kwargs
if args is None or kwargs is None:
call_args = "<dropped>"
call_args = "..."
else:
call_args = ", ".join(
[repr(arg) for arg in args]
+ [f"{key}={repr(val)}" for key, val in kwargs.items()]
+ [f"{key}={fmt(val)}" for key, val in kwargs.items()]
)

# Enforce a maximum length
Expand Down
2 changes: 2 additions & 0 deletions src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
For more user-accessible information about the current run, see [`prefect.runtime`](../runtime/flow_run).
"""
import asyncio
import os
import sys
import warnings
Expand Down Expand Up @@ -256,6 +257,7 @@ class EngineContext(RunContext):
# The synchronous portal is only created for async flows for creating engine calls
# from synchronous task and subflow calls
sync_portal: Optional[anyio.abc.BlockingPortal] = None
loop: Optional[asyncio.AbstractEventLoop] = None
timeout_scope: Optional[anyio.abc.CancelScope] = None

# Task group that can be used for background tasks during the flow run
Expand Down
76 changes: 60 additions & 16 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@
import prefect
import prefect.context
import prefect.plugins
from prefect._internal.concurrency.event_loop import (
get_running_loop,
run_coroutine_in_loop_from_async,
)
from prefect._internal.compatibility.deprecated import deprecated_parameter
from prefect._internal.compatibility.experimental import experimental_parameter
from prefect._internal.concurrency.api import create_call, from_async, from_sync
Expand Down Expand Up @@ -841,7 +845,9 @@ async def orchestrate_flow_run(
"Beginning execution...", extra={"state_message": True}
)

flow_call = create_call(flow.fn, *args, **kwargs)
flow_call = create_call(
wrap_flow_for_execution(flow_run_context, flow), *args, **kwargs
)

# This check for a parent call is needed for cases where the engine
# was entered directly during testing
Expand All @@ -855,13 +861,17 @@ async def orchestrate_flow_run(
# Unless the parent is async and the child is sync, run the
# child flow in the parent thread; running a sync child in
# an async parent could be bad for async performance.
not (parent_flow_run_context.flow.isasync and not flow.isasync)
not (parent_flow_run_context.flow.isasync != flow.isasync)
)
):
logger.debug(
"Executing flow in existing thread %s", user_thread.name
)
from_async.call_soon_in_waiting_thread(
flow_call, thread=user_thread, timeout=flow.timeout_seconds
)
else:
logger.debug("Executing flow in new thread")
from_async.call_soon_in_new_thread(
flow_call, timeout=flow.timeout_seconds
)
Expand Down Expand Up @@ -968,6 +978,39 @@ async def orchestrate_flow_run(
return state


def wrap_flow_for_execution(flow_run_context: FlowRunContext, flow: Flow):
"""
Wrap a user's flow function with engine state.
The wrapper will execute on the user's thread, allowing us to wait for futures
before the thread is shut down and capture information about the thread for
nested runs.
"""
if flow_run_context.flow.isasync:

async def execute_flow_and_wait_for_tasks(*args, **kwargs):
# Set the loop for the flow run so we can schedule tasks on it
object.__setattr__(flow_run_context, "loop", get_running_loop())
retval = await flow.fn(*args, **kwargs)

# Wait for task all of the task run futures, we must do this on the user's thread
# to ensure that all tasks finish before the event loop is closed
await gather(
*(future._wait for future in flow_run_context.task_run_futures)
)
return retval

else:

def execute_flow_and_wait_for_tasks(*args, **kwargs):
retval = flow.fn(*args, **kwargs)
for future in flow_run_context.task_run_futures:
future.wait()
return retval

return execute_flow_and_wait_for_tasks


@overload
async def pause_flow_run(
wait_for_input: None = None,
Expand Down Expand Up @@ -2132,22 +2175,23 @@ async def tick():

call = create_call(task.fn, *args, **kwargs)

if (
if flow_run_context and task.isasync and flow_run_context.flow.isasync:
# Async tasks can always be executed on asynchronous flow;
# schedule them to run directly on the flow's event loop
await run_coroutine_in_loop_from_async(
flow_run_context.loop, call.arun()
)
elif (
flow_run_context
and user_thread
and (
# Async and sync tasks can be executed on synchronous flows
# if the task runner is sequential; if the task is sync and a
# concurrent task runner is used, we must execute it in a worker
# thread instead.
(
concurrency_type == TaskConcurrencyType.SEQUENTIAL
and not flow_run_context.flow.isasync
)
# Async tasks can always be executed on asynchronous flow; if the
# flow is async we do not want to block the event loop with
# synchronous tasks
or (flow_run_context.flow.isasync and task.isasync)
and
# Async and sync tasks can be executed on synchronous flows
# if the task runner is sequential; if the task is sync and a
# concurrent task runner is used, we must execute it in a worker
# thread instead.
(
concurrency_type == TaskConcurrencyType.SEQUENTIAL
and not flow_run_context.flow.isasync
)
):
from_async.call_soon_in_waiting_thread(
Expand Down

0 comments on commit 25f12bc

Please sign in to comment.