Skip to content

Commit

Permalink
Fix handling of Dask DataFrame and other Awaitables passed to `run_coro_as_sync` (#15687)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzvezdarov authored Oct 17, 2024
1 parent 82dc8ab commit 0eecd59
Show file tree
Hide file tree
Showing 19 changed files with 71 additions and 41 deletions.
3 changes: 1 addition & 2 deletions scripts/collections-manager
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python
import asyncio
import glob
import inspect
import os
import subprocess
import sys
Expand Down Expand Up @@ -92,7 +91,7 @@ def run_function(function_path: str):
function = getattr(module, function_name)

for collection_path in collection_paths():
if inspect.iscoroutinefunction(function):
if asyncio.iscoroutinefunction(function):
asyncio.run(function(collection_path))
else:
function(collection_path)
Expand Down
6 changes: 3 additions & 3 deletions src/integrations/prefect-dask/prefect_dask/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def count_to(highest_number):
```
"""

import inspect
import asyncio
from contextlib import ExitStack
from typing import (
Any,
Expand Down Expand Up @@ -145,7 +145,7 @@ def result(
)
# state.result is a `sync_compatible` function that may or may not return an awaitable
# depending on whether the parent frame is sync or not
if inspect.isawaitable(_result):
if asyncio.iscoroutine(_result):
_result = run_coro_as_sync(_result)
return _result

Expand Down Expand Up @@ -439,7 +439,7 @@ def __enter__(self):

if self.adapt_kwargs:
maybe_coro = self._cluster.adapt(**self.adapt_kwargs)
if inspect.isawaitable(maybe_coro):
if asyncio.iscoroutine(maybe_coro):
run_coro_as_sync(maybe_coro)

self._client = exit_stack.enter_context(
Expand Down
4 changes: 4 additions & 0 deletions src/integrations/prefect-dask/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ classifiers = [
[project.optional-dependencies]
dev = [
"coverage",
# Dask and its distributed scheduler follow the same release schedule and version,
# so here we apply the same restrictions as for the distributed package up above in
# project.dependencies
"dask[dataframe]>=2022.5.0,!=2023.3.2,!=2023.3.2.1,!=2023.4.*,!=2023.5.*",
"interrogate",
"mkdocs-gen-files",
"mkdocs-material",
Expand Down
29 changes: 29 additions & 0 deletions src/integrations/prefect-dask/tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import time
from typing import List

import dask.dataframe as dd
import distributed
import pandas as pd
import pytest
from distributed import LocalCluster
from prefect_dask import DaskTaskRunner
Expand Down Expand Up @@ -373,6 +375,33 @@ def test_flow():

assert "A future was garbage collected before it resolved" not in caplog.text

async def test_successful_dataframe_flow_run(self, task_runner):
    """Check that Dask collections can be handed from task to task in a flow.

    The flow builds a one-partition Dask DataFrame in ``task_a``, calls
    ``sum()`` on it in ``task_b``, calls ``compute()`` on that in ``task_c``,
    and returns the final future's result, which is compared against the
    expected pandas Series.
    """

    @task
    def task_a():
        columns = {"x": [1, 1, 1], "y": [2, 2, 2]}
        return dd.DataFrame.from_dict(columns, npartitions=1)

    @task
    def task_b(ddf):
        return ddf.sum()

    @task
    def task_c(ddf):
        return ddf.compute()

    @flow(version="test", task_runner=task_runner)
    def test_flow():
        # Each submission receives the previous task's future as its input.
        frame_future = task_a.submit()
        sums_future = task_b.submit(frame_future)
        computed_future = task_c.submit(sums_future)

        return computed_future.result()

    outcome = test_flow()
    expected = pd.Series([3, 6], index=["x", "y"])

    assert outcome.equals(expected)

class TestInputArguments:
async def test_dataclasses_can_be_passed_to_task_runners(self, task_runner):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A module to define flows interacting with Kubernetes resources."""

import inspect
import asyncio
from typing import Any, Callable, Dict, Optional

from prefect import flow, task
Expand Down Expand Up @@ -76,20 +76,20 @@ async def run_namespaced_job_async(
"""
kubernetes_job_run = (
await maybe_coro
if inspect.iscoroutine((maybe_coro := task(kubernetes_job.trigger)()))
if asyncio.iscoroutine((maybe_coro := task(kubernetes_job.trigger)()))
else maybe_coro
)

(
await maybe_coro
if inspect.iscoroutine(
if asyncio.iscoroutine(
maybe_coro := task(kubernetes_job_run.wait_for_completion)(print_func)
)
else maybe_coro
)

return (
await maybe_coro
if inspect.iscoroutine(maybe_coro := task(kubernetes_job_run.fetch_result)())
if asyncio.iscoroutine(maybe_coro := task(kubernetes_job_run.fetch_result)())
else maybe_coro
)
3 changes: 1 addition & 2 deletions src/integrations/prefect-ray/prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def count_to(highest_number):
"""

import asyncio # noqa: I001
import inspect
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -147,7 +146,7 @@ def result(
)
# state.result is a `sync_compatible` function that may or may not return an awaitable
# depending on whether the parent frame is sync or not
if inspect.isawaitable(_result):
if asyncio.iscoroutine(_result):
_result = run_coro_as_sync(_result)
return _result

Expand Down
8 changes: 4 additions & 4 deletions src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import inspect
import asyncio
import logging
import os
import time
Expand Down Expand Up @@ -244,7 +244,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
else:
_result = self._return_value

if inspect.isawaitable(_result):
if asyncio.iscoroutine(_result):
# getting the value for a BaseResult may return an awaitable
# depending on whether the parent frame is sync or not
_result = run_coro_as_sync(_result)
Expand All @@ -263,7 +263,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
_result = self.state.result(raise_on_failure=raise_on_failure, fetch=True) # type: ignore
# state.result is a `sync_compatible` function that may or may not return an awaitable
# depending on whether the parent frame is sync or not
if inspect.isawaitable(_result):
if asyncio.iscoroutine(_result):
_result = run_coro_as_sync(_result)
return _result

Expand Down Expand Up @@ -477,7 +477,7 @@ def call_hooks(self, state: Optional[State] = None):
f" {state.name!r}"
)
result = hook(flow, flow_run, state)
if inspect.isawaitable(result):
if asyncio.iscoroutine(result):
run_coro_as_sync(result)
except Exception:
self.logger.error(
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def __init__(

# the flow is considered async if its function is async or an async
# generator
self.isasync = inspect.iscoroutinefunction(
self.isasync = asyncio.iscoroutinefunction(
self.fn
) or inspect.isasyncgenfunction(self.fn)

Expand Down
4 changes: 2 additions & 2 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import asyncio
import collections
import concurrent.futures
import inspect
import threading
import uuid
from collections.abc import Generator, Iterator
Expand Down Expand Up @@ -166,7 +166,7 @@ def result(
)
# state.result is a `sync_compatible` function that may or may not return an awaitable
# depending on whether the parent frame is sync or not
if inspect.isawaitable(_result):
if asyncio.iscoroutine(_result):
_result = run_coro_as_sync(_result)
return _result

Expand Down
3 changes: 1 addition & 2 deletions src/prefect/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def fast_flow():

import asyncio
import datetime
import inspect
import logging
import os
import shutil
Expand Down Expand Up @@ -1213,7 +1212,7 @@ async def wrapper(task_status):
task_status.started()

result = fn(*args, **kwargs)
if inspect.iscoroutine(result):
if asyncio.iscoroutine(result):
await result

await self._runs_task_group.start(wrapper)
Expand Down
6 changes: 3 additions & 3 deletions src/prefect/server/database/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Injected database interface dependencies
"""

import inspect
import asyncio
from contextlib import ExitStack, contextmanager
from functools import wraps
from typing import Callable, Type, TypeVar
Expand Down Expand Up @@ -129,7 +129,7 @@ def sync_wrapper(*args, **kwargs):
inject(kwargs)
return fn(*args, **kwargs)

if inspect.iscoroutinefunction(fn):
if asyncio.iscoroutinefunction(fn):
return async_wrapper

return sync_wrapper
Expand Down Expand Up @@ -167,7 +167,7 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
db = provide_database_interface()
return await func(db, *args, **kwargs) # type: ignore

if inspect.iscoroutinefunction(func):
if asyncio.iscoroutinefunction(func):
return async_wrapper # type: ignore
else:
return sync_wrapper
Expand Down
11 changes: 6 additions & 5 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import logging
import threading
Expand Down Expand Up @@ -291,7 +292,7 @@ def can_retry(self, exc: Exception) -> bool:
data=exc,
message=f"Task run encountered unexpected exception: {repr(exc)}",
)
if inspect.iscoroutinefunction(retry_condition):
if asyncio.iscoroutinefunction(retry_condition):
should_retry = run_coro_as_sync(
retry_condition(self.task, self.task_run, state)
)
Expand Down Expand Up @@ -335,7 +336,7 @@ def call_hooks(self, state: Optional[State] = None):
f" {state.name!r}"
)
result = hook(task, task_run, state)
if inspect.isawaitable(result):
if asyncio.iscoroutine(result):
run_coro_as_sync(result)
except Exception:
self.logger.error(
Expand Down Expand Up @@ -419,7 +420,7 @@ def set_state(self, state: State, force: bool = False) -> State:
# Avoid fetching the result unless it is cached, otherwise we defeat
# the purpose of disabling `cache_result_in_memory`
result = state.result(raise_on_failure=False, fetch=True)
if inspect.isawaitable(result):
if asyncio.iscoroutine(result):
result = run_coro_as_sync(result)
elif isinstance(state.data, ResultRecord):
result = state.data.result
Expand All @@ -443,7 +444,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
# if the return value is a BaseResult, we need to fetch it
if isinstance(self._return_value, BaseResult):
_result = self._return_value.get()
if inspect.isawaitable(_result):
if asyncio.iscoroutine(_result):
_result = run_coro_as_sync(_result)
return _result
elif isinstance(self._return_value, ResultRecord):
Expand Down Expand Up @@ -813,7 +814,7 @@ async def can_retry(self, exc: Exception) -> bool:
data=exc,
message=f"Task run encountered unexpected exception: {repr(exc)}",
)
if inspect.iscoroutinefunction(retry_condition):
if asyncio.iscoroutinefunction(retry_condition):
should_retry = await retry_condition(self.task, self.task_run, state)
elif inspect.isfunction(retry_condition):
should_retry = retry_condition(self.task, self.task_run, state)
Expand Down
3 changes: 2 additions & 1 deletion src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This file requires type-checking with pyright because mypy does not yet support PEP612
# See https://github.com/python/mypy/issues/8645

import asyncio
import datetime
import inspect
from copy import copy
Expand Down Expand Up @@ -370,7 +371,7 @@ def __init__(

# the task is considered async if its function is async or an async
# generator
self.isasync = inspect.iscoroutinefunction(
self.isasync = asyncio.iscoroutinefunction(
self.fn
) or inspect.isasyncgenfunction(self.fn)

Expand Down
2 changes: 1 addition & 1 deletion src/prefect/utilities/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def is_async_fn(
while hasattr(func, "__wrapped__"):
func = func.__wrapped__

return inspect.iscoroutinefunction(func)
return asyncio.iscoroutinefunction(func)


def is_async_gen_fn(func):
Expand Down
5 changes: 2 additions & 3 deletions src/prefect/utilities/engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import contextlib
import inspect
import os
import signal
import time
Expand Down Expand Up @@ -503,7 +502,7 @@ def propose_state_sync(
# Avoid fetching the result unless it is cached, otherwise we defeat
# the purpose of disabling `cache_result_in_memory`
result = state.result(raise_on_failure=False, fetch=True)
if inspect.isawaitable(result):
if asyncio.iscoroutine(result):
result = run_coro_as_sync(result)
elif isinstance(state.data, ResultRecord):
result = state.data.result
Expand Down Expand Up @@ -870,7 +869,7 @@ def resolve_to_final_result(expr, context):
)

_result = state.result(raise_on_failure=False, fetch=True)
if inspect.isawaitable(_result):
if asyncio.iscoroutine(_result):
_result = run_coro_as_sync(_result)
return _result

Expand Down
4 changes: 2 additions & 2 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
import inspect
import asyncio
import threading
from contextlib import AsyncExitStack
from functools import partial
Expand Down Expand Up @@ -1078,7 +1078,7 @@ async def wrapper(task_status):
task_status.started()

result = fn(*args, **kwargs)
if inspect.iscoroutine(result):
if asyncio.iscoroutine(result):
await result

await self._runs_task_group.start(wrapper)
Expand Down
5 changes: 2 additions & 3 deletions tests/public/flows/test_flow_calls.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import inspect

import anyio

Expand All @@ -18,7 +17,7 @@ async def aidentity_flow(x):

def test_async_flow_called_with_asyncio():
coro = aidentity_flow(1)
assert inspect.isawaitable(coro)
assert asyncio.iscoroutine(coro)
assert asyncio.run(coro) == 1


Expand All @@ -28,7 +27,7 @@ def test_async_flow_called_with_anyio():

async def test_async_flow_called_with_running_loop():
coro = aidentity_flow(1)
assert inspect.isawaitable(coro)
assert asyncio.iscoroutine(coro)
assert await coro == 1


Expand Down
Loading

0 comments on commit 0eecd59

Please sign in to comment.