Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions providers/standard/docs/sensors/datetime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ TimeSensor

Use the :class:`~airflow.providers.standard.sensors.time_sensor.TimeSensor` to end sensing after time specified.

Time will be evaluated against ``data_interval_end`` if present for the dag run, otherwise ``run_after`` will be used.

.. exampleinclude:: /../../airflow/example_dags/example_sensors.py
:language: python
:dedent: 4
Expand All @@ -71,6 +73,8 @@ TimeSensorAsync
Use the :class:`~airflow.providers.standard.sensors.time_sensor.TimeSensorAsync` to end sensing after time specified.
It is an async version of the operator and requires Triggerer to run.

Time will be evaluated against ``data_interval_end`` if present for the dag run, otherwise ``run_after`` will be used.

.. exampleinclude:: /../../airflow/example_dags/example_sensors.py
:language: python
:dedent: 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,48 @@ def _get_airflow_version():

class TimeDeltaSensor(BaseSensorOperator):
"""
Waits for a timedelta after the run's data interval.
Waits for a timedelta.

:param delta: time length to wait after the data interval before succeeding.
The delta will be evaluated against data_interval_end if present for the dag run,
otherwise run_after will be used.

:param delta: time to wait before succeeding.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensor`


"""

def __init__(self, *, delta, **kwargs):
super().__init__(**kwargs)
self.delta = delta

def poke(self, context: Context):
data_interval_end = context["data_interval_end"]
def _derive_base_time(self, context: Context) -> datetime:
"""
Get the "base time" against which the delta should be calculated.

If data_interval_end is populated, use it; else use run_after.
"""
data_interval_end = context.get("data_interval_end")
if data_interval_end:
if not isinstance(data_interval_end, datetime):
raise ValueError("`data_interval_end` returned non-datetime object")

if not isinstance(data_interval_end, datetime):
raise ValueError("`data_interval_end` returned non-datetime object")
return data_interval_end

target_dttm: datetime = data_interval_end + self.delta
self.log.info("Checking if the time (%s) has come", target_dttm)
if not data_interval_end and not AIRFLOW_V_3_0_PLUS:
raise ValueError("`data_interval_end` not found in task context.")

dag_run = context.get("dag_run")
if not dag_run:
raise ValueError("`dag_run` not found in task context")
return dag_run.run_after

def poke(self, context: Context) -> bool:
base_time = self._derive_base_time(context=context)
target_dttm = base_time + self.delta
self.log.info("Checking if the delta has elapsed base_time=%s, delta=%s", base_time, self.delta)
return timezone.utcnow() > target_dttm


Expand All @@ -92,12 +111,8 @@ def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> bool | NoReturn:
data_interval_end = context["data_interval_end"]

if not isinstance(data_interval_end, datetime):
raise ValueError("`data_interval_end` returned non-datetime object")

target_dttm: datetime = data_interval_end + self.delta
base_time = self._derive_base_time(context=context)
target_dttm: datetime = base_time + self.delta

if timezone.utcnow() > target_dttm:
# If the target datetime is in the past, return immediately
Expand Down
88 changes: 88 additions & 0 deletions providers/standard/tests/unit/standard/sensors/test_time_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,20 @@
import pytest
import time_machine

from airflow.exceptions import TaskDeferred
from airflow.models import DagBag
from airflow.models.dag import DAG
from airflow.providers.standard.sensors.time_delta import (
TimeDeltaSensor,
TimeDeltaSensorAsync,
WaitSensor,
)
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType

from tests_common.test_utils import db

pytestmark = pytest.mark.db_test

Expand All @@ -40,6 +46,12 @@
TEST_DAG_ID = "unit_tests"


@pytest.fixture(autouse=True)
def clear_db():
db.clear_db_dags()
db.clear_db_runs()


class TestTimedeltaSensor:
def setup_method(self):
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
Expand All @@ -51,6 +63,44 @@ def test_timedelta_sensor(self):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)


@pytest.mark.parametrize(
"run_after, interval_end",
[
(timezone.utcnow() + timedelta(days=1), timezone.utcnow() + timedelta(days=2)),
(timezone.utcnow() + timedelta(days=1), None),
],
)
def test_timedelta_sensor_run_after_vs_interval(run_after, interval_end, dag_maker, session):
"""Interval end should be used as base time when present else run_after"""
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context = {}
if interval_end:
context["data_interval_end"] = interval_end
delta = timedelta(seconds=1)
with dag_maker() as dag:
op = TimeDeltaSensor(task_id="wait_sensor_check", delta=delta, dag=dag, mode="reschedule")

kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)
dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
session=session,
**kwargs,
)
ti = dr.task_instances[0]
context.update(dag_run=dr, ti=ti)
expected = interval_end or run_after
actual = op._derive_base_time(context)
assert actual == expected


class TestTimeDeltaSensorAsync:
def setup_method(self):
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
Expand Down Expand Up @@ -93,3 +143,41 @@ def test_wait_sensor(self, sleep_mock, defer_mock, should_defer):
else:
defer_mock.assert_not_called()
sleep_mock.assert_called_once_with(30)

@pytest.mark.parametrize(
"run_after, interval_end",
[
(timezone.utcnow() + timedelta(days=1), timezone.utcnow() + timedelta(days=2)),
(timezone.utcnow() + timedelta(days=1), None),
],
)
def test_timedelta_sensor_async_run_after_vs_interval(self, run_after, interval_end, dag_maker):
"""Interval end should be used as base time when present else run_after"""
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context = {}
if interval_end:
context["data_interval_end"] = interval_end
with dag_maker() as dag:
kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)

dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
**kwargs,
)
context.update(dag_run=dr)
delta = timedelta(seconds=1)
op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag)
base_time = interval_end or run_after
expected_time = base_time + delta
with pytest.raises(TaskDeferred) as caught:
op.execute(context)

assert caught.value.trigger.moment == expected_time