Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ class DagRunInfo(InfoJsonEncodable):
"data_interval_start",
"data_interval_end",
"external_trigger", # Removed in Airflow 3, use run_type instead
"execution_date", # Airflow 2
"logical_date", # Airflow 3
"run_after", # Airflow 3
"run_id",
Expand Down Expand Up @@ -802,14 +803,20 @@ class TaskInfo(InfoJsonEncodable):
"run_as_user",
"sla",
"task_id",
"trigger_dag_id",
"trigger_run_id",
"external_dag_id",
"external_task_id",
"trigger_rule",
"upstream_task_ids",
"wait_for_downstream",
"wait_for_past_depends_before_skipping",
# Operator-specific useful attributes
"trigger_dag_id", # TriggerDagRunOperator
"trigger_run_id", # TriggerDagRunOperator
"external_dag_id", # ExternalTaskSensor and ExternalTaskMarker (if run, as it's EmptyOperator)
"external_task_id", # ExternalTaskSensor and ExternalTaskMarker (if run, as it's EmptyOperator)
"external_task_ids", # ExternalTaskSensor
"external_task_group_id", # ExternalTaskSensor
"external_dates_filter", # ExternalTaskSensor
"logical_date", # AF 3 ExternalTaskMarker (if run, as it's EmptyOperator)
"execution_date", # AF 2 ExternalTaskMarker (if run, as it's EmptyOperator)
]
casts = {
"operator_class": lambda task: task.task_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def test_get_airflow_dag_run_facet():
dagrun_mock.external_trigger = True
dagrun_mock.run_id = "manual_2024-06-01T00:00:00+00:00"
dagrun_mock.run_type = DagRunType.MANUAL
dagrun_mock.execution_date = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc)
dagrun_mock.logical_date = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc)
dagrun_mock.run_after = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc)
dagrun_mock.start_date = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc)
Expand Down Expand Up @@ -205,6 +206,7 @@ def test_get_airflow_dag_run_facet():
"start_date": "2024-06-01T01:02:04+00:00",
"end_date": "2024-06-01T01:02:14.034172+00:00",
"duration": 10.034172,
"execution_date": "2024-06-01T01:02:04+00:00",
"logical_date": "2024-06-01T01:02:04+00:00",
"run_after": "2024-06-01T01:02:04+00:00",
"dag_bundle_name": "bundle_name",
Expand Down Expand Up @@ -2345,6 +2347,7 @@ def test_dagrun_info_af2():
"run_type": DagRunType.MANUAL,
"external_trigger": False,
"start_date": "2024-06-01T00:00:00+00:00",
"execution_date": "2024-06-01T00:00:00+00:00",
"logical_date": "2024-06-01T00:00:00+00:00",
"dag_bundle_name": None,
"dag_bundle_version": None,
Expand Down Expand Up @@ -2424,10 +2427,17 @@ def test_taskinstance_info_af2():
def test_task_info_af3():
class CustomOperator(PythonOperator):
def __init__(self, *args, **kwargs):
    # Mock operator-specific attributes that normally live on other
    # operators (TriggerDagRunOperator, ExternalTaskSensor,
    # ExternalTaskMarker) so one operator exposes all of them for
    # TaskInfo serialization.
    operator_specific_attrs = {
        "deferrable": True,
        "trigger_dag_id": "trigger_dag_id",
        "trigger_run_id": "trigger_run_id",
        "external_dag_id": "external_dag_id",
        "external_task_id": "external_task_id",
        "external_task_ids": "external_task_ids",
        "external_task_group_id": "external_task_group_id",
        "external_dates_filter": "external_dates_filter",
        "logical_date": "logical_date",
        "execution_date": "execution_date",
    }
    for attribute_name, attribute_value in operator_specific_attrs.items():
        setattr(self, attribute_name, attribute_value)
    super().__init__(*args, **kwargs)

with DAG(
Expand Down Expand Up @@ -2464,12 +2474,17 @@ def __init__(self, *args, **kwargs):
"deferrable": True,
"depends_on_past": False,
"downstream_task_ids": "['task_1']",
"execution_date": "execution_date",
"execution_timeout": None,
"executor_config": {},
"external_dag_id": "external_dag_id",
"external_dates_filter": "external_dates_filter",
"external_task_id": "external_task_id",
"external_task_ids": "external_task_ids",
"external_task_group_id": "external_task_group_id",
"ignore_first_depends_on_past": False,
"inlets": "[{'uri': 'uri1', 'extra': {'a': 1}}]",
"logical_date": "logical_date",
"mapped": False,
"max_active_tis_per_dag": None,
"max_active_tis_per_dagrun": None,
Expand All @@ -2488,6 +2503,7 @@ def __init__(self, *args, **kwargs):
"task_group": tg_info,
"task_id": "section_1.task_3",
"trigger_dag_id": "trigger_dag_id",
"trigger_run_id": "trigger_run_id",
"trigger_rule": "all_success",
"upstream_task_ids": "['task_0']",
"wait_for_downstream": False,
Expand All @@ -2499,10 +2515,17 @@ def __init__(self, *args, **kwargs):
def test_task_info_af2():
class CustomOperator(PythonOperator):
def __init__(self, *args, **kwargs):
    # Mock operator-specific attributes that normally live on other
    # operators (TriggerDagRunOperator, ExternalTaskSensor,
    # ExternalTaskMarker) so one operator exposes all of them for
    # TaskInfo serialization.
    operator_specific_attrs = {
        "deferrable": True,
        "trigger_dag_id": "trigger_dag_id",
        "trigger_run_id": "trigger_run_id",
        "external_dag_id": "external_dag_id",
        "external_task_id": "external_task_id",
        "external_task_ids": "external_task_ids",
        "external_task_group_id": "external_task_group_id",
        "external_dates_filter": "external_dates_filter",
        "logical_date": "logical_date",
        "execution_date": "execution_date",
    }
    for attribute_name, attribute_value in operator_specific_attrs.items():
        setattr(self, attribute_name, attribute_value)
    super().__init__(*args, **kwargs)

with DAG(
Expand Down Expand Up @@ -2539,15 +2562,20 @@ def __init__(self, *args, **kwargs):
"deferrable": True,
"depends_on_past": False,
"downstream_task_ids": "['task_1']",
"execution_date": "execution_date",
"execution_timeout": None,
"executor_config": {},
"external_dag_id": "external_dag_id",
"external_dates_filter": "external_dates_filter",
"external_task_id": "external_task_id",
"external_task_ids": "external_task_ids",
"external_task_group_id": "external_task_group_id",
"ignore_first_depends_on_past": True,
"is_setup": False,
"is_teardown": False,
"sla": None,
"inlets": "[{'uri': 'uri1', 'extra': {'a': 1}}]",
"logical_date": "logical_date",
"mapped": False,
"max_active_tis_per_dag": None,
"max_active_tis_per_dagrun": None,
Expand All @@ -2566,6 +2594,7 @@ def __init__(self, *args, **kwargs):
"task_group": tg_info,
"task_id": "section_1.task_3",
"trigger_dag_id": "trigger_dag_id",
"trigger_run_id": "trigger_run_id",
"trigger_rule": "all_success",
"upstream_task_ids": "['task_0']",
"wait_for_downstream": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def __init__(
self._has_checked_existence = False
self.deferrable = deferrable
self.poll_interval = poll_interval
self.external_dates_filter: str | None = None

def _get_dttm_filter(self, context: Context) -> Sequence[datetime.datetime]:
logical_date = self._get_logical_date(context)
Expand All @@ -261,13 +262,19 @@ def _get_dttm_filter(self, context: Context) -> Sequence[datetime.datetime]:
return result if isinstance(result, list) else [result]
return [logical_date]

@staticmethod
def _serialize_dttm_filter(dttm_filter: Sequence[datetime.datetime]) -> str:
return ",".join(dt.isoformat() for dt in dttm_filter)

def poke(self, context: Context) -> bool:
# delay check to poke rather than __init__ in case it was supplied as XComArgs
if self.external_task_ids and len(self.external_task_ids) > len(set(self.external_task_ids)):
raise ValueError("Duplicate task_ids passed in external_task_ids parameter")

dttm_filter = self._get_dttm_filter(context)
serialized_dttm_filter = ",".join(dt.isoformat() for dt in dttm_filter)
serialized_dttm_filter = self._serialize_dttm_filter(dttm_filter)
# Save as attribute - to be used by listeners
self.external_dates_filter = serialized_dttm_filter

if self.external_task_ids:
self.log.info(
Expand Down Expand Up @@ -456,6 +463,9 @@ def execute_complete(self, context: Context, event: dict[str, typing.Any] | None
if event is None:
raise ExternalTaskNotFoundError("No event received from trigger")

# Re-set as attribute after coming back from deferral - to be used by listeners
self.external_dates_filter = self._serialize_dttm_filter(self._get_dttm_filter(context))

if event["status"] == "success":
self.log.info("External tasks %s has executed successfully.", self.external_task_ids)
elif event["status"] == "skipped":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ def test_external_task_group_when_there_is_no_TIs(self):
def test_fail_poke(
self, _get_dttm_filter, get_count, soft_fail, expected_exception, kwargs, expected_message
):
_get_dttm_filter.return_value = []
_get_dttm_filter.return_value = [DEFAULT_DATE]
get_count.return_value = 1
op = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids",
Expand All @@ -912,6 +912,7 @@ def test_fail_poke(
deferrable=False,
**kwargs,
)
assert op.external_dates_filter is None

# We need to handle the specific exception types based on kwargs
if not soft_fail:
Expand All @@ -931,6 +932,8 @@ def test_fail_poke(
with pytest.raises(expected_exception, match=expected_message):
op.execute(context={})

assert op.external_dates_filter == DEFAULT_DATE.isoformat()

@pytest.mark.parametrize(
("response_get_current", "response_exists", "kwargs", "expected_message"),
(
Expand Down Expand Up @@ -1105,6 +1108,7 @@ def test_external_task_sensor_success(self, dag_maker):
states=["success"],
task_ids=["test_external_task_sensor_success"],
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_failure(self, dag_maker):
Expand All @@ -1128,6 +1132,7 @@ def test_external_task_sensor_failure(self, dag_maker):
states=[State.FAILED],
task_ids=["test_external_task_sensor_failure"],
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_soft_fail(self, dag_maker):
Expand All @@ -1152,6 +1157,7 @@ def test_external_task_sensor_soft_fail(self, dag_maker):
states=[State.FAILED],
task_ids=["test_external_task_sensor_soft_fail"],
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_multiple_task_ids(self, dag_maker):
Expand All @@ -1172,6 +1178,7 @@ def test_external_task_sensor_multiple_task_ids(self, dag_maker):
states=["success"],
task_ids=["task1", "task2"],
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_skipped_states(self, dag_maker):
Expand All @@ -1193,6 +1200,7 @@ def test_external_task_sensor_skipped_states(self, dag_maker):
states=[State.SKIPPED],
task_ids=["test_external_task_sensor_skipped_states"],
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

def test_external_task_sensor_invalid_combination(self, dag_maker):
"""Test that the sensor raises an error with invalid parameter combinations."""
Expand Down Expand Up @@ -1239,6 +1247,7 @@ def test_external_task_sensor_task_group(self, dag_maker):
logical_dates=[DEFAULT_DATE],
task_group_id="test_group",
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_execution_date_fn(self, dag_maker):
Expand All @@ -1264,6 +1273,7 @@ def execution_date_fn(dt):
states=["success"],
task_ids=["test_task"],
)
assert op.external_dates_filter == expected_date.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_execution_delta(self, dag_maker):
Expand All @@ -1286,6 +1296,7 @@ def test_external_task_sensor_execution_delta(self, dag_maker):
states=["success"],
task_ids=["test_task"],
)
assert op.external_dates_filter == expected_date.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_duplicate_task_ids(self, dag_maker):
Expand Down Expand Up @@ -1338,6 +1349,7 @@ def test_external_task_sensor_only_dag_id(self, dag_maker):
logical_dates=[DEFAULT_DATE],
states=["success"],
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_task_group_failed_states(self, dag_maker):
Expand All @@ -1359,6 +1371,7 @@ def test_external_task_sensor_task_group_failed_states(self, dag_maker):
logical_dates=[DEFAULT_DATE],
task_group_id="test_group",
)
assert op.external_dates_filter == DEFAULT_DATE.isoformat()

def test_get_logical_date(self):
"""For AF 3, we check for logical date or dag_run.run_after in context."""
Expand Down Expand Up @@ -1416,7 +1429,7 @@ def test_defer_and_fire_task_state_trigger(self):
deferrable=True,
)

context = {"execution_date": datetime(2025, 1, 1), "logical_date": datetime(2025, 1, 1)}
context = {"execution_date": DEFAULT_DATE, "logical_date": DEFAULT_DATE}
with pytest.raises(TaskDeferred) as exc:
sensor.execute(context=context)

Expand All @@ -1431,9 +1444,10 @@ def test_defer_and_fire_failed_state_trigger(self):
deferrable=True,
)

context = {"execution_date": DEFAULT_DATE, "logical_date": DEFAULT_DATE}
with pytest.raises(ExternalTaskNotFoundError):
sensor.execute_complete(
context=mock.MagicMock(), event={"status": "error", "message": "test failure message"}
context=context, event={"status": "error", "message": "test failure message"}
)

def test_defer_and_fire_timeout_state_trigger(self):
Expand All @@ -1445,9 +1459,10 @@ def test_defer_and_fire_timeout_state_trigger(self):
deferrable=True,
)

context = {"execution_date": DEFAULT_DATE, "logical_date": DEFAULT_DATE}
with pytest.raises(ExternalTaskNotFoundError):
sensor.execute_complete(
context=mock.MagicMock(),
context=context,
event={"status": "timeout", "message": "Dag was not started within 1 minute, assuming fail."},
)

Expand All @@ -1460,9 +1475,10 @@ def test_defer_execute_check_correct_logging(self):
deferrable=True,
)

context = {"execution_date": DEFAULT_DATE, "logical_date": DEFAULT_DATE}
with mock.patch.object(sensor.log, "info") as mock_log_info:
sensor.execute_complete(
context=mock.MagicMock(),
context=context,
event={"status": "success"},
)
mock_log_info.assert_called_with("External tasks %s has executed successfully.", [EXTERNAL_TASK_ID])
Expand All @@ -1476,9 +1492,10 @@ def test_defer_execute_check_failed_status(self):
deferrable=True,
)

context = {"execution_date": DEFAULT_DATE, "logical_date": DEFAULT_DATE}
with pytest.raises(ExternalDagFailedError, match="External job has failed."):
sensor.execute_complete(
context=mock.MagicMock(),
context=context,
event={"status": "failed"},
)

Expand All @@ -1492,9 +1509,10 @@ def test_defer_execute_check_failed_status_soft_fail(self):
soft_fail=True,
)

context = {"execution_date": DEFAULT_DATE, "logical_date": DEFAULT_DATE}
with pytest.raises(AirflowSkipException, match="External job has failed skipping."):
sensor.execute_complete(
context=mock.MagicMock(),
context=context,
event={"status": "failed"},
)

Expand All @@ -1517,6 +1535,20 @@ def test_defer_with_failed_states(self):
assert isinstance(trigger, WorkflowTrigger), "Trigger is not a WorkflowTrigger"
assert trigger.failed_states == failed_states, "failed_states not properly passed to WorkflowTrigger"

def test_defer_execute_complete_re_sets_external_dates_filter_attr(self):
    """Verify ``execute_complete`` repopulates ``external_dates_filter``.

    The attribute is saved for consumption by listeners; since a deferred
    sensor resumes via ``execute_complete`` (not ``poke``), the attribute
    must be re-derived there from the context's dates.
    """
    sensor = ExternalTaskSensor(
        task_id=TASK_ID,
        external_task_id=EXTERNAL_TASK_ID,
        external_dag_id=EXTERNAL_DAG_ID,
        deferrable=True,
    )
    # A freshly constructed sensor has not computed any date filter yet.
    assert sensor.external_dates_filter is None

    context = {"execution_date": DEFAULT_DATE, "logical_date": DEFAULT_DATE}
    sensor.execute_complete(context=context, event={"status": "success"})

    # After resuming from deferral, the filter reflects the context's date.
    assert sensor.external_dates_filter == DEFAULT_DATE.isoformat()


@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Needs Flask app context fixture for AF 2")
@pytest.mark.parametrize(
Expand Down
Loading