41 changes: 37 additions & 4 deletions airflow-core/src/airflow/serialization/serialized_objects.py
@@ -230,12 +230,26 @@ def __str__(self) -> str:


 def _encode_trigger(trigger: BaseEventTrigger | dict):

    [Review comment (Member): The changes in this PR potentially caused a regression: #51809. Misleading PR title.]
+    def _ensure_serialized(d):
+        """
+        Make sure the kwargs dict is JSON-serializable.
+
+        This is done with BaseSerialization logic. A simple check is added to
+        ensure we don't double-serialize, which is possible when a trigger goes
+        through multiple serialization layers.
+        """
+        if isinstance(d, dict) and Encoding.TYPE in d:
+            return d
+        return BaseSerialization.serialize(d)
+
     if isinstance(trigger, dict):
-        return trigger
-    classpath, kwargs = trigger.serialize()
+        classpath = trigger["classpath"]
+        kwargs = trigger["kwargs"]
+    else:
+        classpath, kwargs = trigger.serialize()
     return {
         "classpath": classpath,
-        "kwargs": kwargs,
+        "kwargs": {k: _ensure_serialized(v) for k, v in kwargs.items()},
     }
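
To make the guard's intent concrete, here is a minimal standalone sketch of the idempotent-encoding pattern (illustrative names only: the toy serialize stands in for BaseSerialization.serialize, and the literal "__type" key for Encoding.TYPE):

def serialize(value):
    # Toy stand-in for BaseSerialization.serialize: wrap a value with a type marker.
    return {"__type": type(value).__name__, "__var": value}


def ensure_serialized(d):
    # An already-encoded payload carries the type marker; pass it through untouched
    # so a trigger crossing several serialization layers is never double-wrapped.
    if isinstance(d, dict) and "__type" in d:
        return d
    return serialize(d)


once = ensure_serialized({"poke_interval": 60})
twice = ensure_serialized(once)
assert once == twice  # idempotent: no nested {"__type": ..., "__var": ...} wrappers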


@@ -303,14 +317,33 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:


 def decode_asset(var: dict[str, Any]):
+    def _smart_decode_trigger_kwargs(d):
+        """
+        Slightly clean up kwargs for display.
+
+        This detects one level of BaseSerialization and tries to deserialize the
+        content, removing some __type __var ugliness when the value is displayed
+        in UI to the user.
+        """
+        if not isinstance(d, dict) or Encoding.TYPE not in d:
+            return d
+        return BaseSerialization.deserialize(d)
+
     watchers = var.get("watchers", [])
     return Asset(
         name=var["name"],
         uri=var["uri"],
         group=var["group"],
         extra=var["extra"],
         watchers=[
-            SerializedAssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers
+            SerializedAssetWatcher(
+                name=watcher["name"],
+                trigger={
+                    "classpath": watcher["trigger"]["classpath"],
+                    "kwargs": _smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]),
+                },
+            )
+            for watcher in watchers
         ],
     )
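
A companion sketch for the display-side decode, using the same illustrative stand-ins as above (the real code delegates to BaseSerialization.deserialize):

def smart_decode(d):
    # Plain values, or dicts without the type marker, are displayed as-is.
    if not isinstance(d, dict) or "__type" not in d:
        return d
    # Unwrap exactly one level; stand-in for BaseSerialization.deserialize(d).
    return d["__var"]


assert smart_decode({"__type": "dict", "__var": {"poke_interval": 60}}) == {"poke_interval": 60}
assert smart_decode("already-plain") == "already-plain"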

15 changes: 8 additions & 7 deletions airflow-core/tests/unit/api_fastapi/conftest.py
@@ -27,7 +27,6 @@
 from airflow.api_fastapi.app import create_app
 from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
 from airflow.models import Connection
-from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.standard.operators.empty import EmptyOperator

 from tests_common.test_utils.config import conf_vars
@@ -141,7 +140,7 @@ def configure_git_connection_for_dag_bundle(session):


 @pytest.fixture
-def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_bundle):
+def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_bundle, session):
     """
     Create DAG with multiple versions

@@ -151,17 +150,19 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_
     """
     dag_id = "dag_with_multiple_versions"
     for version_number in range(1, 4):
-        with dag_maker(dag_id) as dag:
+        with dag_maker(
+            dag_id,
+            session=session,
+            bundle_version=f"some_commit_hash{version_number}",
+        ):
             for task_number in range(version_number):
                 EmptyOperator(task_id=f"task{task_number + 1}")
-        SerializedDagModel.write_dag(
-            dag, bundle_name="dag_maker", bundle_version=f"some_commit_hash{version_number}"
-        )
         dag_maker.create_dagrun(
             run_id=f"run{version_number}",
             logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc),
+            session=session,
        )
-        dag.sync_to_db()
+        session.commit()
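
A hedged usage sketch for the reworked fixture; the test name and query are illustrative, and it assumes dag_maker with bundle_version persists one SerializedDagModel row per version, which is what the loop above intends:

def test_dag_has_three_versions(make_dag_with_multiple_versions, session):
    from airflow.models.serialized_dag import SerializedDagModel

    # Three loop iterations in the fixture should leave three serialized versions.
    versions = (
        session.query(SerializedDagModel)
        .filter(SerializedDagModel.dag_id == "dag_with_multiple_versions")
        .all()
    )
    assert len(versions) == 3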


 @pytest.fixture(scope="module")
@@ -20,7 +20,6 @@

 import pytest

-from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.standard.operators.empty import EmptyOperator

 from tests_common.test_utils.db import clear_db_dags, clear_db_serialized_dags

@@ -35,16 +34,11 @@ def setup(request, dag_maker, session):
     clear_db_serialized_dags()

     with dag_maker(
-        "ANOTHER_DAG_ID",
-    ) as dag:
+        dag_id="ANOTHER_DAG_ID", bundle_version="some_commit_hash", bundle_name="another_bundle_name"
+    ):
         EmptyOperator(task_id="task_1")
         EmptyOperator(task_id="task_2")

-    dag.sync_to_db()
-    SerializedDagModel.write_dag(
-        dag, bundle_name="another_bundle_name", bundle_version="some_commit_hash"
-    )


 class TestGetDagVersion(TestDagVersionEndpoint):
     @pytest.mark.parametrize(
@@ -20,6 +20,7 @@
 import operator
 from datetime import datetime
 from unittest import mock
+from uuid import uuid4

 import pytest
 import uuid6
@@ -37,7 +38,13 @@
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState

-from tests_common.test_utils.db import clear_db_assets, clear_db_runs, clear_rendered_ti_fields
+from tests_common.test_utils.db import (
+    clear_db_assets,
+    clear_db_dags,
+    clear_db_runs,
+    clear_db_serialized_dags,
+    clear_rendered_ti_fields,
+)

 pytestmark = pytest.mark.db_test

@@ -114,9 +121,13 @@ def side_effect(cred, validators):
 class TestTIRunState:
     def setup_method(self):
         clear_db_runs()
+        clear_db_serialized_dags()
+        clear_db_dags()

     def teardown_method(self):
         clear_db_runs()
+        clear_db_serialized_dags()
+        clear_db_dags()

     @pytest.mark.parametrize(
         "max_tries, should_retry",

@@ -147,6 +158,7 @@ def test_ti_run_state_to_running(
             state=State.QUEUED,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
         ti.max_tries = max_tries
         session.commit()

@@ -165,7 +177,7 @@
         assert response.status_code == 200
         assert response.json() == {
             "dag_run": {
-                "dag_id": "dag",
+                "dag_id": ti.dag_id,
                 "run_id": "test",
                 "clear_number": 0,
                 "logical_date": instant_str,

@@ -179,7 +191,7 @@
                 "consumed_asset_events": [],
             },
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": max_tries,
             "should_retry": should_retry,
             "variables": [],
@@ -235,6 +247,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance,
             state=State.QUEUED,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )

         ti.next_method = "execute_complete"

@@ -258,7 +271,7 @@
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
Expand All @@ -282,6 +295,7 @@ def test_next_kwargs_determines_start_date_update(self, client, session, create_
             state=State.QUEUED,
             session=session,
             start_date=orig_task_start_time,
+            dag_id=str(uuid4()),
         )

         ti.start_date = orig_task_start_time

@@ -320,7 +334,7 @@
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
@@ -385,6 +399,7 @@ def test_xcom_not_cleared_for_deferral(self, client, session, create_task_instan
             state=State.RUNNING,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
         session.commit()
@@ -163,8 +163,6 @@ def _get_dag_version_bundle_names():
     # Create DAG version with 'my-test-bundle'
     with dag_maker(dag_id="test_dag", schedule=None):
         EmptyOperator(task_id="mytask")
-    with create_session() as session:
-        session.add(DagVersion(dag_id="test_dag", version_number=1, bundle_name="my-test-bundle"))

     # simulate bundle config change (now 'dags-folder' is active, 'my-test-bundle' becomes inactive)
     manager = DagBundlesManager()
10 changes: 9 additions & 1 deletion airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -4115,7 +4115,7 @@ def test_extra_operator_links_not_loaded_in_scheduler_loop(self, dag_maker):
         # Test that custom_task has no Operator Links (after de-serialization) in the Scheduling Loop
         assert not custom_task.operator_extra_links

-    def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker):
+    def test_scheduler_create_dag_runs_does_not_raise_error_when_no_serdag(self, caplog, dag_maker):
         """
         Test that scheduler._create_dag_runs does not raise an error when the DAG does not exist
         in serialized_dag table

@@ -4137,11 +4137,19 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker)
                 logger="airflow.jobs.scheduler_job_runner",
             ),
         ):
+            self._clear_serdags(dag_id=dag_maker.dag.dag_id, session=session)
             self.job_runner._create_dag_runs([dag_maker.dag_model], session)
         assert caplog.messages == [
             "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not found in serialized_dag table",
         ]

+    def _clear_serdags(self, dag_id, session):
+        SDM = SerializedDagModel
+        sdms = session.scalars(select(SDM).where(SDM.dag_id == dag_id))
+        for sdm in sdms:
+            session.delete(sdm)
+        session.commit()
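
For comparison, the same cleanup can be written as a single bulk DELETE; a sketch assuming SQLAlchemy 1.4+ statement style (clear_serdags_bulk is an illustrative name, not part of the suite):

from sqlalchemy import delete

from airflow.models.serialized_dag import SerializedDagModel


def clear_serdags_bulk(dag_id, session):
    # One DELETE statement instead of a per-row loop; same effect as _clear_serdags.
    session.execute(delete(SerializedDagModel).where(SerializedDagModel.dag_id == dag_id))
    session.commit()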

     def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker, testing_dag_bundle):
         """
         Test that externally triggered Dag Runs should not affect (by skipping) next
17 changes: 8 additions & 9 deletions airflow-core/tests/unit/models/test_cleartasks.py
@@ -700,14 +700,18 @@ def _get_ti(old_ti):
         ti1, ti2 = sorted(dr.get_task_instances(session=session), key=lambda ti: ti.task_id)
         ti1.task = op1
         ti2.task = op2

-        session.get(TaskInstance, ti2.id).try_number += 1
+        ti2.refresh_from_db(session=session)
+        ti2.try_number += 1
         session.commit()
         ti2.run(session=session)

         # Dependency not met
         assert ti2.try_number == 1
         assert ti2.max_tries == 1

+        ti1.refresh_from_db(session=session)
+        assert ti1.max_tries == 0
+        assert ti1.try_number == 0
+
         op2.clear(upstream=True, session=session)
         ti1.refresh_from_db(session)
         ti2.refresh_from_db(session)

@@ -716,14 +720,9 @@ def _get_ti(old_ti):
         # max tries will be set to retries + curr try number == 1 + 1 == 2
         assert ti2.max_tries == 2

-        ti1.try_number += 1
-        session.merge(ti1)
-        session.commit()
-
         ti1.run(session=session)
         ti1.refresh_from_db(session)
         ti2.refresh_from_db(session)
-        assert ti1.try_number == 1
+        assert ti1.try_number == 0

         ti2 = _get_ti(ti2)
         ti2.try_number += 1
22 changes: 20 additions & 2 deletions airflow-core/tests/unit/models/test_dagbag.py
@@ -32,6 +32,7 @@

 import pytest
 import time_machine
+from sqlalchemy import select

 import airflow.example_dags
 from airflow import settings

@@ -480,7 +481,7 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
         assert dag_id == dag.dag_id
         assert dagbag.process_file_calls == 2

-    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path):
+    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path, session):
         """
         Test that if a DAG does not exist in serialized_dag table (as the DAG file was removed),
         remove dags from the DagBag

@@ -493,14 +494,31 @@ def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path):
             start_date=tz.datetime(2021, 10, 12),
         ) as dag:
             EmptyOperator(task_id="task_1")
+        dag_maker.create_dagrun()

         dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
         dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))}
         dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))}
         dagbag.dags_hash = {dag.dag_id: mock.ANY}

+        # observe we have serdag and dag is in dagbag
+        assert SerializedDagModel.has_dag(dag.dag_id) is True
+        assert dagbag.get_dag(dag.dag_id) is not None
+
+        # now delete serdags for this dag
+        SDM = SerializedDagModel
+        sdms = session.scalars(select(SDM).where(SDM.dag_id == dag.dag_id))
+        for sdm in sdms:
+            session.delete(sdm)
+        session.commit()
+
+        # first, confirm that serdags are gone for this dag
+        assert SerializedDagModel.has_dag(dag.dag_id) is False
+
+        # now see the dag is still in dagbag
+        assert dagbag.get_dag(dag.dag_id) is not None
+
+        # but, let's recreate the dagbag and see if the dag will be there
+        dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
         assert dagbag.get_dag(dag.dag_id) is None
         assert dag.dag_id not in dagbag.dags
         assert dag.dag_id not in dagbag.dags_last_fetched