Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from airflow.providers.standard.triggers.external_task import DagStateTrigger
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType

Expand Down Expand Up @@ -337,33 +336,36 @@ def _trigger_dag_run_af_3_execute_complete(self, event: tuple[str, dict[str, Any
f" {failed_run_id_conditions}"
)

@provide_session
def _trigger_dag_run_af_2_execute_complete(
self, event: tuple[str, dict[str, Any]], session: Session = NEW_SESSION
):
# This logical_date is parsed from the return trigger event
provided_logical_date = event[1]["execution_dates"][0]
try:
# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_run = session.execute(
select(DagRun).where(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
if not AIRFLOW_V_3_0_PLUS:
from airflow.utils.session import NEW_SESSION, provide_session # type: ignore[misc]

@provide_session
def _trigger_dag_run_af_2_execute_complete(
self, event: tuple[str, dict[str, Any]], session: Session = NEW_SESSION
):
# This logical_date is parsed from the return trigger event
provided_logical_date = event[1]["execution_dates"][0]
try:
# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_run = session.execute(
select(DagRun).where(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
)
).scalar_one()
except NoResultFound:
raise AirflowException(
f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}"
)
).scalar_one()
except NoResultFound:
raise AirflowException(
f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}"
)

state = dag_run.state
state = dag_run.state

if state in self.failed_states:
raise AirflowException(f"{self.trigger_dag_id} failed with failed state {state}")
if state in self.allowed_states:
self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
return
if state in self.failed_states:
raise AirflowException(f"{self.trigger_dag_id} failed with failed state {state}")
if state in self.allowed_states:
self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
return

raise AirflowException(
f"{self.trigger_dag_id} return {state} which is not in {self.failed_states}"
f" or {self.allowed_states}"
)
raise AirflowException(
f"{self.trigger_dag_id} return {state} which is not in {self.failed_states}"
f" or {self.allowed_states}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
from airflow.providers.standard.utils.sensor_helper import _get_count, _get_external_task_group_task_ids
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.file import correct_maybe_zipped
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.bases.sensor import BaseSensorOperator
else:
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -400,21 +400,23 @@ def _handle_skipped_states(self, count_skipped: float | int) -> None:
"Skipping."
)

@provide_session
def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session = NEW_SESSION) -> bool:
if self.check_existence and not self._has_checked_existence:
self._check_for_existence(session=session)
if not AIRFLOW_V_3_0_PLUS:

if self.failed_states:
count_failed = self.get_count(dttm_filter, session, self.failed_states)
self._handle_failed_states(count_failed)
@provide_session
def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session = NEW_SESSION) -> bool:
if self.check_existence and not self._has_checked_existence:
self._check_for_existence(session=session)

if self.skipped_states:
count_skipped = self.get_count(dttm_filter, session, self.skipped_states)
self._handle_skipped_states(count_skipped)
if self.failed_states:
count_failed = self.get_count(dttm_filter, session, self.failed_states)
self._handle_failed_states(count_failed)

count_allowed = self.get_count(dttm_filter, session, self.allowed_states)
return count_allowed == len(dttm_filter)
if self.skipped_states:
count_skipped = self.get_count(dttm_filter, session, self.skipped_states)
self._handle_skipped_states(count_skipped)

count_allowed = self.get_count(dttm_filter, session, self.allowed_states)
return count_allowed == len(dttm_filter)

def execute(self, context: Context) -> None:
"""Run on the worker and defer using the triggers if deferrable is set to True."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from airflow.providers.standard.utils.sensor_helper import _get_count
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import NEW_SESSION, provide_session

if typing.TYPE_CHECKING:
from datetime import datetime
Expand Down Expand Up @@ -266,22 +265,25 @@ async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> tupl
return cls_path, data
await asyncio.sleep(self.poll_interval)

@sync_to_async
@provide_session
def count_dags(self, *, session: Session = NEW_SESSION) -> int:
"""Count how many dag runs in the database match our criteria."""
_dag_run_date_condition = (
DagRun.run_id.in_(self.run_ids)
if AIRFLOW_V_3_0_PLUS
else DagRun.execution_date.in_(self.execution_dates)
)
count = (
session.query(func.count("*")) # .count() is inefficient
.filter(
DagRun.dag_id == self.dag_id,
DagRun.state.in_(self.states),
_dag_run_date_condition,
if not AIRFLOW_V_3_0_PLUS:
from airflow.utils.session import NEW_SESSION, provide_session # type: ignore[misc]

@sync_to_async
@provide_session
def count_dags(self, *, session: Session = NEW_SESSION) -> int:
"""Count how many dag runs in the database match our criteria."""
_dag_run_date_condition = (
DagRun.run_id.in_(self.run_ids)
if AIRFLOW_V_3_0_PLUS
else DagRun.execution_date.in_(self.execution_dates)
)
.scalar()
)
return typing.cast("int", count)
count = (
session.query(func.count("*")) # .count() is inefficient
.filter(
DagRun.dag_id == self.dag_id,
DagRun.state.in_(self.states),
_dag_run_date_condition,
)
.scalar()
)
return typing.cast("int", count)