diff --git a/scripts/collections-manager b/scripts/collections-manager index 3e421f1031c1..51e83a9bf834 100755 --- a/scripts/collections-manager +++ b/scripts/collections-manager @@ -1,7 +1,6 @@ #!/usr/bin/env python import asyncio import glob -import inspect import os import subprocess import sys @@ -92,7 +91,7 @@ def run_function(function_path: str): function = getattr(module, function_name) for collection_path in collection_paths(): - if inspect.iscoroutinefunction(function): + if asyncio.iscoroutinefunction(function): asyncio.run(function(collection_path)) else: function(collection_path) diff --git a/src/integrations/prefect-dask/prefect_dask/task_runners.py b/src/integrations/prefect-dask/prefect_dask/task_runners.py index 768de0b711a5..f4edbab25295 100644 --- a/src/integrations/prefect-dask/prefect_dask/task_runners.py +++ b/src/integrations/prefect-dask/prefect_dask/task_runners.py @@ -71,7 +71,7 @@ def count_to(highest_number): ``` """ -import inspect +import asyncio from contextlib import ExitStack from typing import ( Any, @@ -145,7 +145,7 @@ def result( ) # state.result is a `sync_compatible` function that may or may not return an awaitable # depending on whether the parent frame is sync or not - if inspect.isawaitable(_result): + if asyncio.iscoroutine(_result): _result = run_coro_as_sync(_result) return _result @@ -439,7 +439,7 @@ def __enter__(self): if self.adapt_kwargs: maybe_coro = self._cluster.adapt(**self.adapt_kwargs) - if inspect.isawaitable(maybe_coro): + if asyncio.iscoroutine(maybe_coro): run_coro_as_sync(maybe_coro) self._client = exit_stack.enter_context( diff --git a/src/integrations/prefect-dask/pyproject.toml b/src/integrations/prefect-dask/pyproject.toml index 6d1b033928d2..81525fa4c7f6 100644 --- a/src/integrations/prefect-dask/pyproject.toml +++ b/src/integrations/prefect-dask/pyproject.toml @@ -34,6 +34,10 @@ classifiers = [ [project.optional-dependencies] dev = [ "coverage", + # Dask and its distributed scheduler follow the same release schedule and version, + # so here we apply the same restrictions as for the distributed package up above in + # project.dependencies + "dask[dataframe]>=2022.5.0,!=2023.3.2,!=2023.3.2.1,!=2023.4.*,!=2023.5.*", "interrogate", "mkdocs-gen-files", "mkdocs-material", diff --git a/src/integrations/prefect-dask/tests/test_task_runners.py b/src/integrations/prefect-dask/tests/test_task_runners.py index b95285cd7724..748de299c65e 100644 --- a/src/integrations/prefect-dask/tests/test_task_runners.py +++ b/src/integrations/prefect-dask/tests/test_task_runners.py @@ -2,7 +2,9 @@ import time from typing import List +import dask.dataframe as dd import distributed +import pandas as pd import pytest from distributed import LocalCluster from prefect_dask import DaskTaskRunner @@ -373,6 +375,33 @@ def test_flow(): assert "A future was garbage collected before it resolved" not in caplog.text + async def test_successful_dataframe_flow_run(self, task_runner): + @task + def task_a(): + return dd.DataFrame.from_dict( + {"x": [1, 1, 1], "y": [2, 2, 2]}, npartitions=1 + ) + + @task + def task_b(ddf): + return ddf.sum() + + @task + def task_c(ddf): + return ddf.compute() + + @flow(version="test", task_runner=task_runner) + def test_flow(): + a = task_a.submit() + b = task_b.submit(a) + c = task_c.submit(b) + + return c.result() + + result = test_flow() + + assert result.equals(pd.Series([3, 6], index=["x", "y"])) + class TestInputArguments: async def test_dataclasses_can_be_passed_to_task_runners(self, task_runner): """ diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/flows.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/flows.py index 1c422301e4c2..60436b3064ef 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/flows.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/flows.py @@ -1,6 +1,6 @@ """A module to define flows interacting with Kubernetes resources.""" -import inspect +import asyncio from typing import Any, Callable, Dict, Optional from prefect import flow, task @@ -76,13 +76,13 @@ async def run_namespaced_job_async( """ kubernetes_job_run = ( await maybe_coro - if inspect.iscoroutine((maybe_coro := task(kubernetes_job.trigger)())) + if asyncio.iscoroutine((maybe_coro := task(kubernetes_job.trigger)())) else maybe_coro ) ( await maybe_coro - if inspect.iscoroutine( + if asyncio.iscoroutine( maybe_coro := task(kubernetes_job_run.wait_for_completion)(print_func) ) else maybe_coro @@ -90,6 +90,6 @@ async def run_namespaced_job_async( return ( await maybe_coro - if inspect.iscoroutine(maybe_coro := task(kubernetes_job_run.fetch_result)()) + if asyncio.iscoroutine(maybe_coro := task(kubernetes_job_run.fetch_result)()) else maybe_coro ) diff --git a/src/integrations/prefect-ray/prefect_ray/task_runners.py b/src/integrations/prefect-ray/prefect_ray/task_runners.py index 0b09d78143ab..c8bcd9a62621 100644 --- a/src/integrations/prefect-ray/prefect_ray/task_runners.py +++ b/src/integrations/prefect-ray/prefect_ray/task_runners.py @@ -72,7 +72,6 @@ def count_to(highest_number): """ import asyncio # noqa: I001 -import inspect from typing import ( TYPE_CHECKING, Any, @@ -147,7 +146,7 @@ def result( ) # state.result is a `sync_compatible` function that may or may not return an awaitable # depending on whether the parent frame is sync or not - if inspect.isawaitable(_result): + if asyncio.iscoroutine(_result): _result = run_coro_as_sync(_result) return _result diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index 12e86d6f809c..7d982df61f25 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -1,4 +1,4 @@ -import inspect +import asyncio import logging import os import time @@ -244,7 +244,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": else: _result = self._return_value - if inspect.isawaitable(_result): + if asyncio.iscoroutine(_result): # getting the value for a BaseResult may return an awaitable # depending on whether the parent frame is sync or not _result = run_coro_as_sync(_result) @@ -263,7 +263,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": _result = self.state.result(raise_on_failure=raise_on_failure, fetch=True) # type: ignore # state.result is a `sync_compatible` function that may or may not return an awaitable # depending on whether the parent frame is sync or not - if inspect.isawaitable(_result): + if asyncio.iscoroutine(_result): _result = run_coro_as_sync(_result) return _result @@ -477,7 +477,7 @@ def call_hooks(self, state: Optional[State] = None): f" {state.name!r}" ) result = hook(flow, flow_run, state) - if inspect.isawaitable(result): + if asyncio.iscoroutine(result): run_coro_as_sync(result) except Exception: self.logger.error( diff --git a/src/prefect/flows.py b/src/prefect/flows.py index 7ffee3996e1c..d135d116ad51 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -286,7 +286,7 @@ def __init__( # the flow is considered async if its function is async or an async # generator - self.isasync = inspect.iscoroutinefunction( + self.isasync = asyncio.iscoroutinefunction( self.fn ) or inspect.isasyncgenfunction(self.fn) diff --git a/src/prefect/futures.py b/src/prefect/futures.py index 8b305cca45df..c90442135ec1 100644 --- a/src/prefect/futures.py +++ b/src/prefect/futures.py @@ -1,7 +1,7 @@ import abc +import asyncio import collections import concurrent.futures -import inspect import threading import uuid from collections.abc import Generator, Iterator @@ -166,7 +166,7 @@ def result( ) # state.result is a `sync_compatible` function that may or may not return an awaitable # depending on whether the parent frame is sync or not - if inspect.isawaitable(_result): + if asyncio.iscoroutine(_result): _result = run_coro_as_sync(_result) return _result diff --git a/src/prefect/runner/runner.py b/src/prefect/runner/runner.py index 5993cc2336c8..71622028bbd6 100644 --- a/src/prefect/runner/runner.py +++ b/src/prefect/runner/runner.py @@ -32,7 +32,6 @@ def fast_flow(): import asyncio import datetime -import inspect import logging import os import shutil @@ -1213,7 +1212,7 @@ async def wrapper(task_status): task_status.started() result = fn(*args, **kwargs) - if inspect.iscoroutine(result): + if asyncio.iscoroutine(result): await result await self._runs_task_group.start(wrapper) diff --git a/src/prefect/server/database/dependencies.py b/src/prefect/server/database/dependencies.py index 1ec0d0ca34d6..57a4d092f6ff 100644 --- a/src/prefect/server/database/dependencies.py +++ b/src/prefect/server/database/dependencies.py @@ -2,7 +2,7 @@ Injected database interface dependencies """ -import inspect +import asyncio from contextlib import ExitStack, contextmanager from functools import wraps from typing import Callable, Type, TypeVar @@ -129,7 +129,7 @@ def sync_wrapper(*args, **kwargs): inject(kwargs) return fn(*args, **kwargs) - if inspect.iscoroutinefunction(fn): + if asyncio.iscoroutinefunction(fn): return async_wrapper return sync_wrapper @@ -167,7 +167,7 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: db = provide_database_interface() return await func(db, *args, **kwargs) # type: ignore - if inspect.iscoroutinefunction(func): + if asyncio.iscoroutinefunction(func): return async_wrapper # type: ignore else: return sync_wrapper diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 36388fd05756..4e5a417cbf8b 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -1,3 +1,4 @@ +import asyncio import inspect import logging import threading @@ -291,7 +292,7 @@ def can_retry(self, exc: Exception) -> bool: data=exc, message=f"Task run encountered unexpected exception: {repr(exc)}", ) - if inspect.iscoroutinefunction(retry_condition): + if asyncio.iscoroutinefunction(retry_condition): should_retry = run_coro_as_sync( retry_condition(self.task, self.task_run, state) ) @@ -335,7 +336,7 @@ def call_hooks(self, state: Optional[State] = None): f" {state.name!r}" ) result = hook(task, task_run, state) - if inspect.isawaitable(result): + if asyncio.iscoroutine(result): run_coro_as_sync(result) except Exception: self.logger.error( @@ -419,7 +420,7 @@ def set_state(self, state: State, force: bool = False) -> State: # Avoid fetching the result unless it is cached, otherwise we defeat # the purpose of disabling `cache_result_in_memory` result = state.result(raise_on_failure=False, fetch=True) - if inspect.isawaitable(result): + if asyncio.iscoroutine(result): result = run_coro_as_sync(result) elif isinstance(state.data, ResultRecord): result = state.data.result @@ -443,7 +444,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": # if the return value is a BaseResult, we need to fetch it if isinstance(self._return_value, BaseResult): _result = self._return_value.get() - if inspect.isawaitable(_result): + if asyncio.iscoroutine(_result): _result = run_coro_as_sync(_result) return _result elif isinstance(self._return_value, ResultRecord): @@ -813,7 +814,7 @@ async def can_retry(self, exc: Exception) -> bool: data=exc, message=f"Task run encountered unexpected exception: {repr(exc)}", ) - if inspect.iscoroutinefunction(retry_condition): + if asyncio.iscoroutinefunction(retry_condition): should_retry = await retry_condition(self.task, self.task_run, state) elif inspect.isfunction(retry_condition): should_retry = retry_condition(self.task, self.task_run, state) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 6fac57389e1a..cdc95ae55661 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -4,6 +4,7 @@ # This file requires type-checking with pyright because mypy does not yet support PEP612 # See https://github.com/python/mypy/issues/8645 +import asyncio import datetime import inspect from copy import copy @@ -370,7 +371,7 @@ def __init__( # the task is considered async if its function is async or an async # generator - self.isasync = inspect.iscoroutinefunction( + self.isasync = asyncio.iscoroutinefunction( self.fn ) or inspect.isasyncgenfunction(self.fn) diff --git a/src/prefect/utilities/asyncutils.py b/src/prefect/utilities/asyncutils.py index 9361b2dc003c..3939632e1641 100644 --- a/src/prefect/utilities/asyncutils.py +++ b/src/prefect/utilities/asyncutils.py @@ -83,7 +83,7 @@ def is_async_fn( while hasattr(func, "__wrapped__"): func = func.__wrapped__ - return inspect.iscoroutinefunction(func) + return asyncio.iscoroutinefunction(func) def is_async_gen_fn(func): diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py index ca412e1b15d5..0f13a44a7a74 100644 --- a/src/prefect/utilities/engine.py +++ b/src/prefect/utilities/engine.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import inspect import os import signal import time @@ -503,7 +502,7 @@ def propose_state_sync( # Avoid fetching the result unless it is cached, otherwise we defeat # the purpose of disabling `cache_result_in_memory` result = state.result(raise_on_failure=False, fetch=True) - if inspect.isawaitable(result): + if asyncio.iscoroutine(result): result = run_coro_as_sync(result) elif isinstance(state.data, ResultRecord): result = state.data.result @@ -870,7 +869,7 @@ def resolve_to_final_result(expr, context): ) _result = state.result(raise_on_failure=False, fetch=True) - if inspect.isawaitable(_result): + if asyncio.iscoroutine(_result): _result = run_coro_as_sync(_result) return _result diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py index f93a13977b11..6f31925a1281 100644 --- a/src/prefect/workers/base.py +++ b/src/prefect/workers/base.py @@ -1,5 +1,5 @@ import abc -import inspect +import asyncio import threading from contextlib import AsyncExitStack from functools import partial @@ -1078,7 +1078,7 @@ async def wrapper(task_status): task_status.started() result = fn(*args, **kwargs) - if inspect.iscoroutine(result): + if asyncio.iscoroutine(result): await result await self._runs_task_group.start(wrapper) diff --git a/tests/public/flows/test_flow_calls.py b/tests/public/flows/test_flow_calls.py index 7f2d7389c47c..a4d72eac6ef9 100644 --- a/tests/public/flows/test_flow_calls.py +++ b/tests/public/flows/test_flow_calls.py @@ -1,5 +1,4 @@ import asyncio -import inspect import anyio @@ -18,7 +17,7 @@ async def aidentity_flow(x): def test_async_flow_called_with_asyncio(): coro = aidentity_flow(1) - assert inspect.isawaitable(coro) + assert asyncio.iscoroutine(coro) assert asyncio.run(coro) == 1 @@ -28,7 +27,7 @@ def test_async_flow_called_with_anyio(): async def test_async_flow_called_with_running_loop(): coro = aidentity_flow(1) - assert inspect.isawaitable(coro) + assert asyncio.iscoroutine(coro) assert await coro == 1 diff --git a/tests/server/database/test_dependencies.py b/tests/server/database/test_dependencies.py index b7532c0cfdf0..7a5c50f4b0d9 100644 --- a/tests/server/database/test_dependencies.py +++ b/tests/server/database/test_dependencies.py @@ -1,5 +1,5 @@ +import asyncio import datetime -import inspect from uuid import UUID import pytest @@ -170,7 +170,7 @@ class Returner: async def return_1(self, db): return 1 - assert inspect.iscoroutinefunction(Returner().return_1) + assert asyncio.iscoroutinefunction(Returner().return_1) async def test_inject_interface_class(): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 4941ffee3e47..3cdd7569a965 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -2438,7 +2438,7 @@ def upstream_downstream_flow(result): # because it runs on the main thread with an active event loop. We need to update # result retrieval to be sync. result = upstream_state.result() - if inspect.isawaitable(result): + if asyncio.iscoroutine(result): result = run_coro_as_sync(result) downstream_state = downstream(result, return_state=True) return upstream_state, downstream_state