Skip to content

Commit

Permalink
Merge pull request #2302 from PrefectHQ/task-flow-calls
Browse files Browse the repository at this point in the history
Update flow and task calls to return results
  • Loading branch information
zanieb authored Jul 11, 2022
2 parents 62df28d + a939482 commit 7499540
Show file tree
Hide file tree
Showing 21 changed files with 1,225 additions and 524 deletions.
14 changes: 10 additions & 4 deletions src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,10 @@ class FlowRunContext(RunContext):
flow_run: The API metadata for the flow run
task_runner: The task runner instance being used for the flow run
result_storage: A block to used to persist run state data
task_run_futures: A list of futures for task runs created within this flow run
subflow_states: A list of states for flow runs created within this flow run
task_run_futures: A list of futures for task runs submitted within this flow run
task_run_states: A list of states for task runs created within this flow run
task_run_results: A mapping of result ids to task run states for this flow run
flow_run_states: A list of states for flow runs created within this flow run
sync_portal: A blocking portal for sync task/flow runs in an async flow
timeout_scope: The cancellation scope for flow level timeouts
"""
Expand All @@ -204,10 +206,14 @@ class FlowRunContext(RunContext):
task_runner: BaseTaskRunner
result_storage: StorageBlock

# Tracking created objects
# Counter for task calls allowing unique
task_run_dynamic_keys: Dict[str, int] = Field(default_factory=dict)

# Tracking for objects created by this flow run
task_run_futures: List[PrefectFuture] = Field(default_factory=list)
subflow_states: List[State] = Field(default_factory=list)
task_run_states: List[State] = Field(default_factory=list)
task_run_results: Dict[int, State] = Field(default_factory=dict)
flow_run_states: List[State] = Field(default_factory=list)

# The synchronous portal is only created for async flows for creating engine calls
# from synchronous task and subflow calls
Expand Down
164 changes: 136 additions & 28 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
return_value_to_state,
safe_encode_exception,
)
from prefect.task_runners import BaseTaskRunner
from prefect.tasks import Task
from prefect.utilities.asyncutils import (
gather,
Expand All @@ -74,11 +75,15 @@
from prefect.utilities.pydantic import PartialModel

R = TypeVar("R")
EngineReturnType = Literal["future", "state", "result"]


UNTRACKABLE_TYPES = {bool, type(None), type(...), type(NotImplemented)}
engine_logger = get_logger("engine")


def enter_flow_run_engine_from_flow_call(
flow: Flow, parameters: Dict[str, Any]
flow: Flow, parameters: Dict[str, Any], return_type: EngineReturnType
) -> Union[State, Awaitable[State]]:
"""
Sync entrypoint for flow calls.
Expand All @@ -99,7 +104,7 @@ def enter_flow_run_engine_from_flow_call(

if TaskRunContext.get():
raise RuntimeError(
"Flows cannot be called from within tasks. Did you mean to call this "
"Flows cannot be run from within tasks. Did you mean to call this "
"flow in a flow?"
)

Expand All @@ -110,6 +115,7 @@ def enter_flow_run_engine_from_flow_call(
create_and_begin_subflow_run if is_subflow_run else create_then_begin_flow_run,
flow=flow,
parameters=parameters,
return_type=return_type,
)

# Async flow run
Expand Down Expand Up @@ -150,13 +156,19 @@ def enter_flow_run_engine_from_subprocess(flow_run_id: UUID) -> State:

@inject_client
async def create_then_begin_flow_run(
flow: Flow, parameters: Dict[str, Any], client: OrionClient
) -> State:
flow: Flow,
parameters: Dict[str, Any],
return_type: EngineReturnType,
client: OrionClient,
) -> Any:
"""
Async entrypoint for flow calls
Creates the flow run in the backend, then enters the main flow run engine.
"""
# TODO: Returns a `State` depending on `return_type` and we can add an overload to
# the function signature to clarify this eventually.

connect_error = await client.api_healthcheck()
if connect_error:
raise RuntimeError(
Expand Down Expand Up @@ -187,11 +199,17 @@ async def create_then_begin_flow_run(
engine_logger.info(
f"Flow run {flow_run.name!r} received invalid parameters and is marked as failed."
)
return state
else:
state = await begin_flow_run(
flow=flow, flow_run=flow_run, parameters=parameters, client=client
)

return await begin_flow_run(
flow=flow, flow_run=flow_run, parameters=parameters, client=client
)
if return_type == "state":
return state
elif return_type == "result":
return state.result()
else:
raise ValueError(f"Invalid return type for flow engine {return_type!r}.")


@inject_client
Expand Down Expand Up @@ -339,8 +357,9 @@ async def begin_flow_run(
async def create_and_begin_subflow_run(
flow: Flow,
parameters: Dict[str, Any],
return_type: EngineReturnType,
client: OrionClient,
) -> State:
) -> Any:
"""
Async entrypoint for flows calls within a flow run
Expand Down Expand Up @@ -448,9 +467,14 @@ async def create_and_begin_subflow_run(
)

# Track the subflow state so the parent flow can use it to determine its final state
parent_flow_run_context.subflow_states.append(terminal_state)
parent_flow_run_context.flow_run_states.append(terminal_state)

return terminal_state
if return_type == "state":
return terminal_state
elif return_type == "result":
return terminal_state.result()
else:
raise ValueError(f"Invalid return type for flow engine {return_type!r}.")


async def orchestrate_flow_run(
Expand Down Expand Up @@ -545,7 +569,9 @@ async def orchestrate_flow_run(
# All tasks and subflows are reference tasks if there is no return value
# If there are no tasks, use `None` instead of an empty iterable
result = (
flow_run_context.task_run_futures + flow_run_context.subflow_states
flow_run_context.task_run_futures
+ flow_run_context.task_run_states
+ flow_run_context.flow_run_states
) or None

terminal_state = await return_value_to_state(
Expand Down Expand Up @@ -602,32 +628,35 @@ def enter_task_run_engine(
task: Task,
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]],
return_type: EngineReturnType,
task_runner: Optional[BaseTaskRunner],
) -> Union[PrefectFuture, Awaitable[PrefectFuture]]:
"""
Sync entrypoint for task calls
"""
flow_run_context = FlowRunContext.get()
if not flow_run_context:
raise RuntimeError(
"Tasks cannot be called outside of a flow. To call the underlying task function outside of a flow use `task.fn()`."
"Tasks cannot be run outside of a flow. To call the underlying task function outside of a flow use `task.fn()`."
)

if TaskRunContext.get():
raise RuntimeError(
"Tasks cannot be called from within tasks. Did you mean to call this "
"Tasks cannot be run from within tasks. Did you mean to call this "
"task in a flow?"
)

if flow_run_context.timeout_scope and flow_run_context.timeout_scope.cancel_called:
raise TimeoutError("Flow run timed out")

begin_run = partial(
create_and_submit_task_run,
create_task_run_then_submit,
task=task,
flow_run_context=flow_run_context,
parameters=parameters,
dynamic_key=_dynamic_key_for_task_run(flow_run_context, task),
wait_for=wait_for,
return_type=return_type,
task_runner=task_runner,
)

# Async task run in async flow run
Expand Down Expand Up @@ -664,10 +693,13 @@ async def collect_task_run_inputs(
async def add_futures_and_states_to_inputs(obj):
if isinstance(obj, PrefectFuture):
inputs.add(core.TaskRunResult(id=obj.task_run.id))

if isinstance(obj, State):
elif isinstance(obj, State):
if obj.state_details.task_run_id:
inputs.add(core.TaskRunResult(id=obj.state_details.task_run_id))
else:
state = get_state_for_result(obj)
if state and state.state_details.task_run_id:
inputs.add(core.TaskRunResult(id=state.state_details.task_run_id))

await visit_collection(
expr, visit_fn=add_futures_and_states_to_inputs, return_data=False
Expand All @@ -676,20 +708,48 @@ async def add_futures_and_states_to_inputs(obj):
return inputs


async def create_and_submit_task_run(
async def create_task_run_then_submit(
task: Task,
flow_run_context: FlowRunContext,
parameters: Dict[str, Any],
dynamic_key: str,
wait_for: Optional[Iterable[PrefectFuture]],
) -> PrefectFuture:
"""
Async entrypoint for task calls.
return_type: EngineReturnType,
task_runner: Optional[BaseTaskRunner],
) -> Union[PrefectFuture, State]:
task_run = await create_task_run(
task=task,
flow_run_context=flow_run_context,
parameters=parameters,
dynamic_key=_dynamic_key_for_task_run(flow_run_context, task),
wait_for=wait_for,
)

Tasks must be called within a flow. When tasks are called, they create a task run
and submit orchestration of the run to the flow run's task runner. The task runner
returns a future that is returned immediately.
"""
future = await submit_task_run(
task=task,
flow_run_context=flow_run_context,
parameters=parameters,
task_run=task_run,
wait_for=wait_for,
task_runner=task_runner or flow_run_context.task_runner,
)

if return_type == "future":
return future
elif return_type == "state":
return await future._wait()
elif return_type == "result":
return await future._result()
else:
raise ValueError(f"Invalid return type for task engine {return_type!r}.")


async def create_task_run(
task: Task,
flow_run_context: FlowRunContext,
parameters: Dict[str, Any],
dynamic_key: str,
wait_for: Optional[Iterable[PrefectFuture]],
) -> TaskRun:
task_inputs = {k: await collect_task_run_inputs(v) for k, v in parameters.items()}
if wait_for:
task_inputs["wait_for"] = await collect_task_run_inputs(wait_for)
Expand All @@ -707,7 +767,27 @@ async def create_and_submit_task_run(

logger.info(f"Created task run {task_run.name!r} for task {task.name!r}")

future = await flow_run_context.task_runner.submit(
return task_run


async def submit_task_run(
task: Task,
flow_run_context: FlowRunContext,
parameters: Dict[str, Any],
task_run: TaskRun,
wait_for: Optional[Iterable[PrefectFuture]],
task_runner: BaseTaskRunner,
) -> PrefectFuture:
"""
Async entrypoint for task calls.
Tasks must be called within a flow. When tasks are called, they create a task run
and submit orchestration of the run to the flow run's task runner. The task runner
returns a future that is returned immediately.
"""
logger = get_run_logger(flow_run_context)

future = await task_runner.submit(
task_run=task_run,
run_key=f"{task_run.name}-{task_run.id.hex}-{flow_run_context.flow_run.run_count}",
run_fn=begin_task_run,
Expand Down Expand Up @@ -1116,3 +1196,31 @@ def _dynamic_key_for_task_run(context: FlowRunContext, task: Task) -> int:
)
# Let the exit code be determined by the base exception type
raise


def get_state_for_result(obj: Any) -> Optional[State]:
"""
Get the state related to a result object.
`link_state_to_result` must have been called first.
"""
flow_run_context = FlowRunContext.get()
if flow_run_context:
return flow_run_context.task_run_results.get(id(obj))


def link_state_to_result(state: State, result: Any) -> None:
"""
Stores information about the state on the result or in the global context for
relationship tracking.
"""
if type(result) in UNTRACKABLE_TYPES:
return

# Cache the state onto the flow_run_context, associated by the id of the
# result. This allows a best-effort attempt to get the state from an object
# that wouldn't allow the __prefect_state__ attribute to be set. It also
# acts as a complete cache of states for reporting in a flow run state.
flow_run_context = FlowRunContext.get()
if flow_run_context:
flow_run_context.task_run_results[id(result)] = state
42 changes: 36 additions & 6 deletions src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,19 @@ def serialize_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
return serialized_parameters

@overload
def __call__(
self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs
) -> State[T]:
def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> T:
# `NoReturn` matches if a type can't be inferred for the function which stops a
# sync function from matching the `Coroutine` overload
...

@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]", *args: P.args, **kwargs: P.kwargs
) -> Awaitable[State[T]]:
) -> Awaitable[T]:
...

@overload
def __call__(self: "Flow[P, T]", *args: P.args, **kwargs: P.kwargs) -> State[T]:
def __call__(self: "Flow[P, T]", *args: P.args, **kwargs: P.kwargs) -> T:
...

def __call__(
Expand Down Expand Up @@ -367,7 +365,39 @@ def __call__(
# Convert the call args/kwargs to a parameter dict
parameters = get_call_parameters(self.fn, args, kwargs)

return enter_flow_run_engine_from_flow_call(self, parameters)
return enter_flow_run_engine_from_flow_call(
self, parameters, return_type="result"
)

@overload
def run(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> State[T]:
# `NoReturn` matches if a type can't be inferred for the function which stops a
# sync function from matching the `Coroutine` overload
...

@overload
def run(
self: "Flow[P, Coroutine[Any, Any, T]]", *args: P.args, **kwargs: P.kwargs
) -> Awaitable[T]:
...

@overload
def run(self: "Flow[P, T]", *args: P.args, **kwargs: P.kwargs) -> State[T]:
...

def run(
self,
*args: "P.args",
**kwargs: "P.kwargs",
):
from prefect.engine import enter_flow_run_engine_from_flow_call

# Convert the call args/kwargs to a parameter dict
parameters = get_call_parameters(self.fn, args, kwargs)

return enter_flow_run_engine_from_flow_call(
self, parameters, return_type="state"
)


@overload
Expand Down
Loading

0 comments on commit 7499540

Please sign in to comment.