Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from pathlib import Path

from airflow import DAG
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import Variable
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.sensors.time_delta import TimeDeltaSensorAsync
import pytest

from system.openlineage.operator import OpenLineageTestOperator

Expand All @@ -45,21 +47,22 @@ def check_start_amount_func():
schedule=None,
catchup=False,
) as dag:
# Timedelta is compared to the DAGRun start timestamp, which can occur long before a worker picks up the
# task. We need to ensure the sensor gets deferred at least once, so setting 180s.
wait = TimeDeltaSensorAsync(task_id="wait", delta=timedelta(seconds=180))
with pytest.warns(AirflowProviderDeprecationWarning):
# Timedelta is compared to the DAGRun start timestamp, which can occur long before a worker picks up the
# task. We need to ensure the sensor gets deferred at least once, so setting 180s.
wait = TimeDeltaSensorAsync(task_id="wait", delta=timedelta(seconds=180))

check_start_events_amount = PythonOperator(
task_id="check_start_events_amount", python_callable=check_start_amount_func
)
check_start_events_amount = PythonOperator(
task_id="check_start_events_amount", python_callable=check_start_amount_func
)

check_events = OpenLineageTestOperator(
task_id="check_events",
file_path=str(Path(__file__).parent / "example_openlineage_defer.json"),
allow_duplicate_events=True,
)
check_events = OpenLineageTestOperator(
task_id="check_events",
file_path=str(Path(__file__).parent / "example_openlineage_defer.json"),
allow_duplicate_events=True,
)

wait >> check_start_events_amount >> check_events
wait >> check_start_events_amount >> check_events


from tests_common.test_utils.system_tests import get_test_run # noqa: E402
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
# under the License.
from __future__ import annotations

import warnings
from datetime import datetime, timedelta
from time import sleep
from typing import TYPE_CHECKING, Any, NoReturn

from deprecated.classic import deprecated
from packaging.version import Version

from airflow.configuration import conf
from airflow.exceptions import AirflowSkipException
from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.sensors.base import BaseSensorOperator
Expand Down Expand Up @@ -52,16 +54,26 @@ class TimeDeltaSensor(BaseSensorOperator):
otherwise run_after will be used.

:param delta: time to wait before succeeding.
:param deferrable: Run sensor in deferrable mode. If set to True, task will defer itself to avoid taking up a worker slot while it is waiting.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensor`

"""

def __init__(self, *, delta, **kwargs):
def __init__(
self,
*,
delta: timedelta,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
end_from_trigger: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.delta = delta
self.deferrable = deferrable
self.end_from_trigger = end_from_trigger

def _derive_base_time(self, context: Context) -> datetime:
"""
Expand Down Expand Up @@ -90,27 +102,21 @@ def poke(self, context: Context) -> bool:
self.log.info("Checking if the delta has elapsed base_time=%s, delta=%s", base_time, self.delta)
return timezone.utcnow() > target_dttm


class TimeDeltaSensorAsync(TimeDeltaSensor):
"""
A deferrable drop-in replacement for TimeDeltaSensor.

Will defers itself to avoid taking up a worker slot while it is waiting.

:param delta: time length to wait after the data interval before succeeding.
:param end_from_trigger: End the task directly from the triggerer without going into the worker.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensorAsync`

Asynchronous execution
"""

def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
super().__init__(delta=delta, **kwargs)
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> bool | NoReturn:
"""
Depending on the deferrable flag, either execute the sensor in a blocking way or defer it.

- Sync path → use BaseSensorOperator.execute() which loops over ``poke``.
- Async path → defer to DateTimeTrigger and free the worker slot.
"""
if not self.deferrable:
return super().execute(context=context)

# Deferrable path
base_time = self._derive_base_time(context=context)
target_dttm: datetime = base_time + self.delta

Expand Down Expand Up @@ -146,6 +152,26 @@ def execute_complete(self, context: Context, event: Any = None) -> None:
return None


# TODO: Remove in the next major release
@deprecated(
"Use `TimeDeltaSensor` with `deferrable=True` instead", category=AirflowProviderDeprecationWarning
)
class TimeDeltaSensorAsync(TimeDeltaSensor):
"""
Deprecated. Use TimeDeltaSensor with deferrable=True instead.

:sphinx-autoapi-skip:
"""

def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
warnings.warn(
"TimeDeltaSensorAsync is deprecated and will be removed in a future version. Use `TimeDeltaSensor` with `deferrable=True` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
super().__init__(delta=delta, deferrable=True, end_from_trigger=end_from_trigger, **kwargs)


class WaitSensor(BaseSensorOperator):
"""
A sensor that waits a specified period of time before completing.
Expand Down
6 changes: 4 additions & 2 deletions providers/standard/tests/system/standard/example_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from airflow.providers.standard.sensors.filesystem import FileSensor
from airflow.providers.standard.sensors.python import PythonSensor
from airflow.providers.standard.sensors.time import TimeSensor
from airflow.providers.standard.sensors.time_delta import TimeDeltaSensor, TimeDeltaSensorAsync
from airflow.providers.standard.sensors.time_delta import TimeDeltaSensor
from airflow.providers.standard.sensors.weekday import DayOfWeekSensor
from airflow.providers.standard.utils.weekday import WeekDay
from airflow.sdk import DAG
Expand Down Expand Up @@ -57,7 +57,9 @@ def failure_callable():
# [END example_time_delta_sensor]

# [START example_time_delta_sensor_async]
t0a = TimeDeltaSensorAsync(task_id="wait_some_seconds_async", delta=datetime.timedelta(seconds=2))
t0a = TimeDeltaSensor(
task_id="wait_some_seconds_async", delta=datetime.timedelta(seconds=2), deferrable=True
)
# [END example_time_delta_sensor_async]

# [START example_time_sensors]
Expand Down
137 changes: 97 additions & 40 deletions providers/standard/tests/unit/standard/sensors/test_time_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@
from __future__ import annotations

from datetime import timedelta
from typing import Any
from unittest import mock

import pendulum
import pytest
import time_machine

from airflow.exceptions import TaskDeferred
from airflow.exceptions import AirflowProviderDeprecationWarning, TaskDeferred
from airflow.models import DagBag
from airflow.models.dag import DAG
from airflow.providers.standard.sensors.time_delta import (
TimeDeltaSensor,
TimeDeltaSensorAsync,
WaitSensor,
)
from airflow.providers.standard.triggers.temporal import DateTimeTrigger
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone
from airflow.utils.timezone import datetime
Expand Down Expand Up @@ -105,6 +107,57 @@ def test_timedelta_sensor_run_after_vs_interval(run_after, interval_end, dag_mak
assert actual == expected


@pytest.mark.parametrize(
"run_after, interval_end",
[
(timezone.utcnow() + timedelta(days=1), timezone.utcnow() + timedelta(days=2)),
(timezone.utcnow() + timedelta(days=1), None),
],
)
def test_timedelta_sensor_deferrable_run_after_vs_interval(run_after, interval_end, dag_maker):
"""Test that TimeDeltaSensor defers correctly when flag is enabled."""
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context: dict[str, Any] = {}
if interval_end:
context["data_interval_end"] = interval_end

with dag_maker() as dag:
kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)

delta = timedelta(minutes=5)
sensor = TimeDeltaSensor(
task_id="timedelta_sensor_deferrable",
delta=delta,
dag=dag,
deferrable=True, # <-- the feature under test
)

dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
**kwargs,
)
context.update(dag_run=dr)

expected_base = interval_end or run_after
expected_fire_time = expected_base + delta

with pytest.raises(TaskDeferred) as td:
sensor.execute(context)

# The sensor should defer once with a DateTimeTrigger
trigger = td.value.trigger
assert isinstance(trigger, DateTimeTrigger)
assert trigger.moment == expected_fire_time


class TestTimeDeltaSensorAsync:
def setup_method(self):
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
Expand All @@ -117,17 +170,20 @@ def setup_method(self):
)
@mock.patch(DEFER_PATH)
def test_timedelta_sensor(self, defer_mock, should_defer):
delta = timedelta(hours=1)
op = TimeDeltaSensorAsync(task_id="timedelta_sensor_check", delta=delta, dag=self.dag)
if should_defer:
data_interval_end = pendulum.now("UTC").add(hours=1)
else:
data_interval_end = pendulum.now("UTC").replace(microsecond=0, second=0, minute=0).add(hours=-1)
op.execute({"data_interval_end": data_interval_end})
if should_defer:
defer_mock.assert_called_once()
else:
defer_mock.assert_not_called()
with pytest.warns(AirflowProviderDeprecationWarning):
delta = timedelta(hours=1)
op = TimeDeltaSensorAsync(task_id="timedelta_sensor_check", delta=delta, dag=self.dag)
if should_defer:
data_interval_end = pendulum.now("UTC").add(hours=1)
else:
data_interval_end = (
pendulum.now("UTC").replace(microsecond=0, second=0, minute=0).add(hours=-1)
)
op.execute({"data_interval_end": data_interval_end})
if should_defer:
defer_mock.assert_called_once()
else:
defer_mock.assert_not_called()

@pytest.mark.parametrize(
"should_defer",
Expand Down Expand Up @@ -157,31 +213,32 @@ def test_wait_sensor(self, sleep_mock, defer_mock, should_defer):
)
def test_timedelta_sensor_async_run_after_vs_interval(self, run_after, interval_end, dag_maker):
"""Interval end should be used as base time when present else run_after"""
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context = {}
if interval_end:
context["data_interval_end"] = interval_end
with dag_maker() as dag:
kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)

dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
**kwargs,
)
context.update(dag_run=dr)
delta = timedelta(seconds=1)
op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag)
base_time = interval_end or run_after
expected_time = base_time + delta
with pytest.raises(TaskDeferred) as caught:
op.execute(context)

assert caught.value.trigger.moment == expected_time
with pytest.warns(AirflowProviderDeprecationWarning):
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context = {}
if interval_end:
context["data_interval_end"] = interval_end
with dag_maker() as dag:
kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)

dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
**kwargs,
)
context.update(dag_run=dr)
delta = timedelta(seconds=1)
op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag)
base_time = interval_end or run_after
expected_time = base_time + delta
with pytest.raises(TaskDeferred) as caught:
op.execute(context)

assert caught.value.trigger.moment == expected_time
Loading