Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

from collections.abc import Iterable
from datetime import timedelta
from typing import TYPE_CHECKING

import pendulum
Expand All @@ -29,8 +30,9 @@
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from airflow.models import DAG, DagRun
from airflow.timetables.base import DagRunInfo
from pendulum.datetime import DateTime

from airflow.models import DagRun

try:
from airflow.sdk.definitions.context import Context
Expand Down Expand Up @@ -62,54 +64,64 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
dag_run: DagRun = context["dag_run"] # type: ignore[assignment]
if dag_run.run_type == DagRunType.MANUAL:
self.log.info("Manually triggered DAG_Run: allowing execution to proceed.")
return list(context["task"].get_direct_relative_ids(upstream=False))
return list(self.get_direct_relative_ids(upstream=False))

next_info = self._get_next_run_info(context, dag_run)
now = pendulum.now("UTC")
dates = self._get_compare_dates(dag_run)

if next_info is None:
if dates is None:
self.log.info("Last scheduled execution: allowing execution to proceed.")
return list(context["task"].get_direct_relative_ids(upstream=False))
return list(self.get_direct_relative_ids(upstream=False))

left_window, right_window = next_info.data_interval
now = pendulum.now("UTC")
left_window, right_window = dates
self.log.info(
"Checking latest only with left_window: %s right_window: %s now: %s",
left_window,
right_window,
now,
)

if left_window == right_window:
self.log.info(
"Zero-length interval [%s, %s) from timetable (%s); treating current run as latest.",
left_window,
right_window,
self.dag.timetable.__class__,
)
return list(context["task"].get_direct_relative_ids(upstream=False))

if not left_window < now <= right_window:
self.log.info("Not latest execution, skipping downstream.")
# we return an empty list, thus the parent BaseBranchOperator
# won't exclude any downstream tasks from skipping.
return []
self.log.info("Latest, allowing execution to proceed.")
return list(context["task"].get_direct_relative_ids(upstream=False))

def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None:
dag: DAG = context["dag"] # type: ignore[assignment]
self.log.info("Latest, allowing execution to proceed.")
return list(self.get_direct_relative_ids(upstream=False))

def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime] | None:
dagrun_date: DateTime
if AIRFLOW_V_3_0_PLUS:
from airflow.timetables.base import DataInterval, TimeRestriction
dagrun_date = dag_run.logical_date or dag_run.run_after
else:
dagrun_date = dag_run.logical_date

time_restriction = TimeRestriction(earliest=None, latest=None, catchup=True)
current_interval = DataInterval(start=dag_run.data_interval_start, end=dag_run.data_interval_end)
from airflow.timetables.base import DataInterval, TimeRestriction

next_info = dag.timetable.next_dagrun_info(
last_automated_data_interval=current_interval,
restriction=time_restriction,
)
current_interval = DataInterval(
start=dag_run.data_interval_start or dagrun_date,
end=dag_run.data_interval_end or dagrun_date,
)

time_restriction = TimeRestriction(
earliest=None, latest=current_interval.end - timedelta(microseconds=1), catchup=True
)
if prev_info := self.dag.timetable.next_dagrun_info(
last_automated_data_interval=current_interval,
restriction=time_restriction,
):
left = prev_info.data_interval.end
else:
next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
return next_info
left = current_interval.start

time_restriction = TimeRestriction(earliest=current_interval.end, latest=None, catchup=True)
next_info = self.dag.timetable.next_dagrun_info(
last_automated_data_interval=current_interval,
restriction=time_restriction,
)

if not next_info:
return None

return (left, next_info.data_interval.end)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import datetime
import operator

import pytest
import time_machine
Expand Down Expand Up @@ -115,9 +116,7 @@ def test_skipping_non_latest(self, dag_maker):
start_date=timezone.utcnow(),
logical_date=timezone.datetime(2016, 1, 1, 12),
state=State.RUNNING,
data_interval=DataInterval(
timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12) + INTERVAL
),
data_interval=DataInterval(timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12)),
**triggered_by_kwargs,
)

Expand All @@ -126,7 +125,7 @@ def test_skipping_non_latest(self, dag_maker):
start_date=timezone.utcnow(),
logical_date=END_DATE,
state=State.RUNNING,
data_interval=DataInterval(END_DATE, END_DATE + INTERVAL),
data_interval=DataInterval(END_DATE + INTERVAL, END_DATE + INTERVAL),
**triggered_by_kwargs,
)

Expand All @@ -145,6 +144,7 @@ def test_skipping_non_latest(self, dag_maker):
latest_ti0.run()

assert exc_info.value.tasks == [("downstream", -1)]
# TODO: Set state is needed until #45549 is completed.
latest_ti0.set_state(State.SUCCESS)
dr0.get_task_instance(task_id="downstream").set_state(State.SKIPPED)

Expand All @@ -156,6 +156,7 @@ def test_skipping_non_latest(self, dag_maker):
latest_ti1.run()

assert exc_info.value.tasks == [("downstream", -1)]
# TODO: Set state is needed until #45549 is completed.
latest_ti1.set_state(State.SUCCESS)
dr1.get_task_instance(task_id="downstream").set_state(State.SKIPPED)

Expand All @@ -165,77 +166,49 @@ def test_skipping_non_latest(self, dag_maker):
latest_ti2.task = latest_task
latest_ti2.run()

latest_ti2.set_state(State.SUCCESS)

# Verify the state of the other downstream tasks
downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)

downstream_instances = get_task_instances("downstream")
exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): "skipped",
timezone.datetime(2016, 1, 1, 12): "skipped",
timezone.datetime(2016, 1, 2): "success",
}

downstream_instances = get_task_instances("downstream_2")
exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): None,
timezone.datetime(2016, 1, 1, 12): None,
timezone.datetime(2016, 1, 2): "success",
}

downstream_instances = get_task_instances("downstream_3")
exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): "success",
timezone.datetime(2016, 1, 1, 12): "success",
timezone.datetime(2016, 1, 2): "success",
}

date_getter = operator.attrgetter("logical_date")
else:
latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
date_getter = operator.attrgetter("execution_date")

downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)
latest_instances = get_task_instances("latest")
exec_date_to_latest_state = {date_getter(ti): ti.state for ti in latest_instances}
assert exec_date_to_latest_state == {
timezone.datetime(2016, 1, 1): "success",
timezone.datetime(2016, 1, 1, 12): "success",
timezone.datetime(2016, 1, 2): "success",
}

latest_instances = get_task_instances("latest")
exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances}
assert exec_date_to_latest_state == {
timezone.datetime(2016, 1, 1): "success",
timezone.datetime(2016, 1, 1, 12): "success",
timezone.datetime(2016, 1, 2): "success",
}
# Verify the state of the other downstream tasks
downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)

downstream_instances = get_task_instances("downstream")
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): "skipped",
timezone.datetime(2016, 1, 1, 12): "skipped",
timezone.datetime(2016, 1, 2): "success",
}
downstream_instances = get_task_instances("downstream")
exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): "skipped",
timezone.datetime(2016, 1, 1, 12): "skipped",
timezone.datetime(2016, 1, 2): "success",
}

downstream_instances = get_task_instances("downstream_2")
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): None,
timezone.datetime(2016, 1, 1, 12): None,
timezone.datetime(2016, 1, 2): "success",
}
downstream_instances = get_task_instances("downstream_2")
exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): None,
timezone.datetime(2016, 1, 1, 12): None,
timezone.datetime(2016, 1, 2): "success",
}

downstream_instances = get_task_instances("downstream_3")
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): "success",
timezone.datetime(2016, 1, 1, 12): "success",
timezone.datetime(2016, 1, 2): "success",
}
downstream_instances = get_task_instances("downstream_3")
exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances}
assert exec_date_to_downstream_state == {
timezone.datetime(2016, 1, 1): "success",
timezone.datetime(2016, 1, 1, 12): "success",
timezone.datetime(2016, 1, 2): "success",
}

def test_not_skipping_external(self, dag_maker):
def test_not_skipping_manual(self, dag_maker):
with dag_maker(
default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, schedule=INTERVAL, serialized=True
):
Expand Down