Merged
28 commits
f55818f
TriggerDagRunOperator can now have default null as logical date
vatsrahul1001 Feb 10, 2025
ec5ec27
Merge branch 'main' of github.com:astronomer/airflow into TriggerDagR…
vatsrahul1001 Feb 11, 2025
03b26e4
Merge branch 'main' of github.com:astronomer/airflow into TriggerDagR…
vatsrahul1001 Feb 11, 2025
532e7fc
Merge branch 'main' into TriggerDagRunOperator-logical-date-default-v…
vatsrahul1001 Feb 11, 2025
2155c82
Merge branch 'main' of github.com:astronomer/airflow into TriggerDagR…
vatsrahul1001 Feb 12, 2025
be75d89
refactor deferrable code and also fix TriggerDagRunOperator tests
vatsrahul1001 Feb 12, 2025
0ee6540
Merge branch 'TriggerDagRunOperator-logical-date-default-value-null' …
vatsrahul1001 Feb 12, 2025
4f0d4cc
fix test_external_task.py
vatsrahul1001 Feb 12, 2025
b82d9e6
Merge branch 'main' into TriggerDagRunOperator-logical-date-default-v…
vatsrahul1001 Feb 12, 2025
985b06a
Merge branch 'main' of github.com:astronomer/airflow into TriggerDagR…
vatsrahul1001 Feb 12, 2025
cff5b99
Merge branch 'TriggerDagRunOperator-logical-date-default-value-null' …
vatsrahul1001 Feb 12, 2025
0118ca3
Merge branch 'main' into TriggerDagRunOperator-logical-date-default-v…
vatsrahul1001 Feb 12, 2025
754cece
Merge branch 'main' into TriggerDagRunOperator-logical-date-default-v…
vatsrahul1001 Feb 12, 2025
4b2f87c
fix external_task tests + run_id changes in WorkflowTrigger
vatsrahul1001 Feb 13, 2025
49452e1
Merge branch 'TriggerDagRunOperator-logical-date-default-value-null' …
vatsrahul1001 Feb 13, 2025
1278717
Merge branch 'main' into TriggerDagRunOperator-logical-date-default-v…
vatsrahul1001 Feb 13, 2025
fff668d
fix test_external_tasks
vatsrahul1001 Feb 13, 2025
445eb91
Merge branch 'TriggerDagRunOperator-logical-date-default-value-null' …
vatsrahul1001 Feb 13, 2025
3ddffc1
Merge branch 'main' of github.com:astronomer/airflow into TriggerDagR…
vatsrahul1001 Feb 15, 2025
5cb3fd1
implement review comments + fix 2.9 tests
vatsrahul1001 Feb 15, 2025
5374ced
Merge branch 'main' into TriggerDagRunOperator-logical-date-default-v…
vatsrahul1001 Feb 15, 2025
bade5b5
fix test_external_tasks
vatsrahul1001 Feb 16, 2025
32990e7
Merge branch 'TriggerDagRunOperator-logical-date-default-value-null' …
vatsrahul1001 Feb 16, 2025
2359b55
fix static checks
vatsrahul1001 Feb 16, 2025
b9f6300
fix test failure for mysql and sqlite
vatsrahul1001 Feb 16, 2025
bc21118
Merge branch 'main' of github.com:astronomer/airflow into TriggerDagR…
vatsrahul1001 Feb 16, 2025
45258b0
update dag_run_id in tests
vatsrahul1001 Feb 17, 2025
b9b4891
Merge branch 'main' into TriggerDagRunOperator-logical-date-default-v…
vatsrahul1001 Feb 17, 2025
2 changes: 1 addition & 1 deletion airflow/api/common/trigger_dag.py
@@ -106,7 +106,7 @@ def _trigger_dag(
     dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
     dag_run = dag.create_dagrun(
         run_id=run_id,
-        logical_date=logical_date,
+        logical_date=coerced_logical_date,
         data_interval=data_interval,
         run_after=run_after,
         conf=run_conf,
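The one-line change above passes a coerced value into create_dagrun. A minimal sketch of the coercion this presumably relies on, assuming airflow.utils.timezone.coerce_datetime semantics (None passes through unchanged, and a naive datetime gains the default UTC timezone):

import datetime

from airflow.utils import timezone

# None stays None, so a run triggered without a logical date keeps it null.
assert timezone.coerce_datetime(None) is None

# A naive datetime becomes timezone-aware (UTC by default).
print(timezone.coerce_datetime(datetime.datetime(2025, 2, 17)))
# 2025-02-17 00:00:00+00:00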
@@ -177,12 +177,10 @@ def __init__(
         self.logical_date = logical_date

     def execute(self, context: Context):
-        if isinstance(self.logical_date, datetime.datetime):
+        if self.logical_date is None or isinstance(self.logical_date, datetime.datetime):
             parsed_logical_date = self.logical_date
         elif isinstance(self.logical_date, str):
-            parsed_logical_date = timezone.parse(self.logical_date)
-        else:
-            parsed_logical_date = timezone.utcnow()
+            parsed_logical_date = timezone.parse(self.logical_date)

         try:
             json.dumps(self.conf)
@@ -195,7 +193,7 @@ def execute(self, context: Context):
         run_id = DagRun.generate_run_id(
             run_type=DagRunType.MANUAL,
             logical_date=parsed_logical_date,
-            run_after=parsed_logical_date,
+            run_after=parsed_logical_date or timezone.utcnow(),
         )

         try:
@@ -211,7 +209,7 @@ def execute(self, context: Context):
         except DagRunAlreadyExists as e:
             if self.reset_dag_run:
                 dag_run = e.dag_run
-                self.log.info("Clearing %s on %s", self.trigger_dag_id, dag_run.logical_date)
+                self.log.info("Clearing %s on %s", self.trigger_dag_id, dag_run.run_id)

                 # Get target dag object and call clear()
                 dag_model = DagModel.get_current(self.trigger_dag_id)
@@ -221,7 +219,7 @@ def execute(self, context: Context):
                 # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
                 dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
                 dag = dag_bag.get_dag(self.trigger_dag_id)
-                dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
+                dag.clear(run_id=dag_run.run_id)
             else:
                 if self.skip_when_already_exists:
                     raise AirflowSkipException(
@@ -242,7 +240,7 @@ def execute(self, context: Context):
                     trigger=DagStateTrigger(
                         dag_id=self.trigger_dag_id,
                         states=self.allowed_states + self.failed_states,
-                        logical_dates=[dag_run.logical_date],
+                        run_ids=[run_id],
                         poll_interval=self.poke_interval,
                     ),
                     method_name="execute_complete",
@@ -252,7 +250,7 @@ def execute(self, context: Context):
                 self.log.info(
                     "Waiting for %s on %s to become allowed state %s ...",
                     self.trigger_dag_id,
-                    dag_run.logical_date,
+                    run_id,
                     self.allowed_states,
                 )
                 time.sleep(self.poke_interval)
@@ -268,18 +266,16 @@ def execute(self, context: Context):

     @provide_session
     def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]):
-        # This logical_date is parsed from the return trigger event
-        provided_logical_date = event[1]["logical_dates"][0]
+        # This run_id is parsed from the return trigger event
+        provided_run_id = event[1]["run_ids"][0]
         try:
             # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
             dag_run = session.execute(
-                select(DagRun).where(
-                    DagRun.dag_id == self.trigger_dag_id, DagRun.logical_date == provided_logical_date
-                )
+                select(DagRun).where(DagRun.dag_id == self.trigger_dag_id, DagRun.run_id == provided_run_id)
             ).scalar_one()
         except NoResultFound:
             raise AirflowException(
-                f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}"
+                f"No DAG run found for DAG {self.trigger_dag_id} and run ID {provided_run_id}"
             )

         state = dag_run.state
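Taken together, the operator diff above keys clearing, deferral, and the wait loop off run_id rather than logical_date, so triggering without a logical date works end to end. A usage sketch under that assumption (the dag IDs and intervals below are illustrative, not taken from this PR):

import datetime

from airflow import DAG
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator

with DAG(dag_id="upstream_dag", start_date=datetime.datetime(2025, 1, 1), schedule=None):
    # logical_date may now be omitted (defaults to None); the triggered run's
    # run_id is generated from the current time instead.
    TriggerDagRunOperator(
        task_id="trigger_downstream",
        trigger_dag_id="downstream_dag",
        logical_date=None,       # explicit here only for emphasis
        reset_dag_run=True,      # re-running now clears via dag.clear(run_id=...)
        wait_for_completion=True,
        deferrable=True,         # defers on DagStateTrigger(run_ids=[run_id])
        poke_interval=10,
    )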
@@ -41,21 +41,21 @@ class WorkflowTrigger(BaseTrigger):
     """
     A trigger to monitor tasks, task group and dag execution in Apache Airflow.

-    :param external_dag_id: The ID of the external DAG.
-    :param logical_dates: A list of logical dates for the external DAG.
+    :param external_dag_id: The ID of the external dag.
+    :param run_ids: A list of run ids for the external dag.
     :param external_task_ids: A collection of external task IDs to wait for.
     :param external_task_group_id: The ID of the external task group to wait for.
     :param failed_states: States considered as failed for external tasks.
     :param skipped_states: States considered as skipped for external tasks.
     :param allowed_states: States considered as successful for external tasks.
     :param poke_interval: The interval (in seconds) for poking the external tasks.
-    :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure.
+    :param soft_fail: If True, the trigger will not fail the entire dag on external task failure.
     """

     def __init__(
         self,
         external_dag_id: str,
-        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
         execution_dates: list[datetime] | None = None,
         external_task_ids: typing.Collection[str] | None = None,
         external_task_group_id: str | None = None,
@@ -72,33 +72,30 @@ def __init__(
         self.failed_states = failed_states
         self.skipped_states = skipped_states
         self.allowed_states = allowed_states
-        self.logical_dates = logical_dates
+        self.run_ids = run_ids
         self.poke_interval = poke_interval
         self.soft_fail = soft_fail
         self.execution_dates = execution_dates
         super().__init__(**kwargs)

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize the trigger param and module path."""
-        _dates = (
-            {"logical_dates": self.logical_dates}
-            if AIRFLOW_V_3_0_PLUS
-            else {"execution_dates": self.execution_dates}
-        )
-        return (
-            "airflow.providers.standard.triggers.external_task.WorkflowTrigger",
-            {
-                "external_dag_id": self.external_dag_id,
-                "external_task_ids": self.external_task_ids,
-                "external_task_group_id": self.external_task_group_id,
-                "failed_states": self.failed_states,
-                "skipped_states": self.skipped_states,
-                "allowed_states": self.allowed_states,
-                **_dates,
-                "poke_interval": self.poke_interval,
-                "soft_fail": self.soft_fail,
-            },
-        )
+        data: dict[str, typing.Any] = {
+            "external_dag_id": self.external_dag_id,
+            "external_task_ids": self.external_task_ids,
+            "external_task_group_id": self.external_task_group_id,
+            "failed_states": self.failed_states,
+            "skipped_states": self.skipped_states,
+            "allowed_states": self.allowed_states,
+            "poke_interval": self.poke_interval,
+            "soft_fail": self.soft_fail,
+        }
+        if AIRFLOW_V_3_0_PLUS:
+            data["run_ids"] = self.run_ids
+        else:
+            data["execution_dates"] = self.execution_dates
+
+        return "airflow.providers.standard.triggers.external_task.WorkflowTrigger", data

     async def run(self) -> typing.AsyncIterator[TriggerEvent]:
         """Check periodically tasks, task group or dag status."""
@@ -117,7 +114,7 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
             yield TriggerEvent({"status": "skipped"})
             return
         allowed_count = await self._get_count(self.allowed_states)
-        _dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates
+        _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates
         if allowed_count == len(_dates):  # type: ignore[arg-type]
             yield TriggerEvent({"status": "success"})
             return
@@ -133,7 +130,7 @@ def _get_count(self, states: typing.Iterable[str] | None) -> int:
         :return The count of records.
         """
         return _get_count(
-            dttm_filter=self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates,
+            dttm_filter=self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates,
             external_task_ids=self.external_task_ids,
             external_task_group_id=self.external_task_group_id,
             external_dag_id=self.external_dag_id,
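Because serialize() now builds its payload conditionally, the kwargs the triggerer rehydrates differ by Airflow version. A hedged round-trip sketch of the 3.x case (the IDs below are illustrative):

from airflow.providers.standard.triggers.external_task import WorkflowTrigger

trigger = WorkflowTrigger(
    external_dag_id="external_task",
    run_ids=["external_task_run_id"],
    external_task_ids=["external_task_op"],
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.standard.triggers.external_task.WorkflowTrigger"
# On 3.x the payload carries "run_ids"; on 2.x it would carry "execution_dates".
assert kwargs["run_ids"] == ["external_task_run_id"]
assert "execution_dates" not in kwargs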
@@ -143,11 +140,11 @@

 class DagStateTrigger(BaseTrigger):
     """
-    Waits asynchronously for a DAG to complete for a specific logical date.
+    Waits asynchronously for a dag to complete for a specific run_id.

     :param dag_id: The dag_id that contains the task you want to wait for
     :param states: allowed states, default is ``['success']``
-    :param logical_dates: The logical date at which DAG run.
+    :param run_ids: The run_id of dag run.
     :param poll_interval: The time interval in seconds to check the state.
         The default value is 5.0 sec.
     """
@@ -156,40 +153,38 @@ def __init__(
         self,
         dag_id: str,
         states: list[DagRunState],
-        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
         execution_dates: list[datetime] | None = None,
         poll_interval: float = 5.0,
     ):
         super().__init__()
         self.dag_id = dag_id
         self.states = states
-        self.logical_dates = logical_dates
+        self.run_ids = run_ids
         self.execution_dates = execution_dates
         self.poll_interval = poll_interval

     def serialize(self) -> tuple[str, dict[str, typing.Any]]:
         """Serialize DagStateTrigger arguments and classpath."""
-        _dates = (
-            {"logical_dates": self.logical_dates}
-            if AIRFLOW_V_3_0_PLUS
-            else {"execution_dates": self.execution_dates}
-        )
-        return (
-            "airflow.providers.standard.triggers.external_task.DagStateTrigger",
-            {
-                "dag_id": self.dag_id,
-                "states": self.states,
-                **_dates,
-                "poll_interval": self.poll_interval,
-            },
-        )
+        data = {
+            "dag_id": self.dag_id,
+            "states": self.states,
+            "poll_interval": self.poll_interval,
+        }
+
+        if AIRFLOW_V_3_0_PLUS:
+            data["run_ids"] = self.run_ids
+        else:
+            data["execution_dates"] = self.execution_dates
+
+        return "airflow.providers.standard.triggers.external_task.DagStateTrigger", data

     async def run(self) -> typing.AsyncIterator[TriggerEvent]:
         """Check periodically if the dag run exists, and has hit one of the states yet, or not."""
         while True:
             # mypy confuses typing here
             num_dags = await self.count_dags()  # type: ignore[call-arg]
-            _dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates
+            _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates
             if num_dags == len(_dates):  # type: ignore[arg-type]
                 yield TriggerEvent(self.serialize())
                 return
@@ -200,7 +195,7 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
     def count_dags(self, *, session: Session = NEW_SESSION) -> int | None:
         """Count how many dag runs in the database match our criteria."""
         _dag_run_date_condition = (
-            DagRun.logical_date.in_(self.logical_dates)
+            DagRun.run_id.in_(self.run_ids)
             if AIRFLOW_V_3_0_PLUS
             else DagRun.execution_date.in_(self.execution_dates)
         )
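On the operator side, execute() now defers on DagStateTrigger(run_ids=[run_id], ...), and count_dags() resolves completion with a DagRun.run_id.in_(...) query. A hedged construction sketch (the run_id value is illustrative):

from airflow.providers.standard.triggers.external_task import DagStateTrigger
from airflow.utils.state import DagRunState

trigger = DagStateTrigger(
    dag_id="downstream_dag",
    states=[DagRunState.SUCCESS, DagRunState.FAILED],
    run_ids=["manual__2025-02-17T00:00:00+00:00"],
    poll_interval=5.0,
)
# The trigger fires once count_dags(), now a DagRun.run_id.in_(...) query,
# matches len(run_ids); the event payload round-trips via serialize().
classpath, kwargs = trigger.serialize()
assert kwargs["run_ids"] == ["manual__2025-02-17T00:00:00+00:00"]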
@@ -32,10 +32,11 @@
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

 _DATES = (
-    {"logical_dates": [timezone.datetime(2022, 1, 1)]}
+    {"run_ids": ["external_task_run_id"]}
     if AIRFLOW_V_3_0_PLUS
     else {"execution_dates": [timezone.datetime(2022, 1, 1)]}
 )
+key, value = next(iter(_DATES.items()))


 class TestWorkflowTrigger:
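The new module-level unpacking gives the assertions below a single version-independent pair: key is "run_ids" on Airflow 3.x and "execution_dates" on 2.x, and value is the matching one-element list passed as dttm_filter. A small illustration of what the pair holds on the 3.x branch:

_DATES = {"run_ids": ["external_task_run_id"]}  # the AIRFLOW_V_3_0_PLUS branch
key, value = next(iter(_DATES.items()))
assert key == "run_ids"
assert value == ["external_task_run_id"]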
@@ -67,8 +68,9 @@ async def test_task_workflow_trigger_success(self, mock_get_count):
         assert trigger_task.done()
         result = trigger_task.result()
         assert result.payload == {"status": "success"}
+
         mock_get_count.assert_called_once_with(
-            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            dttm_filter=value,
             external_task_ids=["external_task_op"],
             external_task_group_id=None,
             external_dag_id="external_task",
@@ -102,7 +104,7 @@ async def test_task_workflow_trigger_failed(self, mock_get_count):
         assert isinstance(result, TriggerEvent)
         assert result.payload == {"status": "failed"}
         mock_get_count.assert_called_once_with(
-            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            dttm_filter=value,
             external_task_ids=["external_task_op"],
             external_task_group_id=None,
             external_dag_id="external_task",
@@ -133,7 +135,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count):
         assert isinstance(result, TriggerEvent)
         assert result.payload == {"status": "success"}
         mock_get_count.assert_called_once_with(
-            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            dttm_filter=value,
             external_task_ids=["external_task_op"],
             external_task_group_id=None,
             external_dag_id="external_task",
@@ -167,7 +169,7 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count):
         assert isinstance(result, TriggerEvent)
         assert result.payload == {"status": "skipped"}
         mock_get_count.assert_called_once_with(
-            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            dttm_filter=value,
             external_task_ids=["external_task_op"],
             external_task_group_id=None,
             external_dag_id="external_task",
@@ -243,17 +245,12 @@ async def test_dag_state_trigger(self, session):
         reaches an allowed state (i.e. SUCCESS).
         """
         dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1))
-        logical_date_or_execution_date = (
-            {"logical_date": timezone.datetime(2022, 1, 1)}
+        run_id_or_execution_date = (
+            {"run_id": "external_task_run_id"}
             if AIRFLOW_V_3_0_PLUS
-            else {"execution_date": timezone.datetime(2022, 1, 1)}
+            else {"execution_date": timezone.datetime(2022, 1, 1), "run_id": "external_task_run_id"}
         )
-        dag_run = DagRun(
-            dag_id=dag.dag_id,
-            run_type="manual",
-            **logical_date_or_execution_date,
-            run_id=self.RUN_ID,
-        )
+        dag_run = DagRun(dag_id=dag.dag_id, run_type="manual", **run_id_or_execution_date)
         session.add(dag_run)
         session.commit()