Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion airflow-core/src/airflow/cli/commands/db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from airflow import settings
from airflow.exceptions import AirflowException
from airflow.utils import cli as cli_utils, db
from airflow.utils.db import _REVISION_HEADS_MAP
from airflow.utils.db import _REVISION_HEADS_MAP, _SER_DAG_VERSIONS_MAP
from airflow.utils.db_cleanup import config_dict, drop_archived_tables, export_archived_records, run_cleanup
from airflow.utils.process_utils import execute_interactive
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
Expand Down Expand Up @@ -85,6 +85,45 @@ def _get_version_revision(version: str, revision_heads_map: dict[str, str] | Non
return None


def _get_serializer_version_for_target_version(version: str) -> str | None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did #55975 yesterday @ephraimbuddy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Didn't see that. Closing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry

"""Return serialized DAG version for the highest Airflow version <= target."""
if not version:
return None
target = parse_version(version)

best_match: str | None = None
for v in _SER_DAG_VERSIONS_MAP.keys():
pv = parse_version(v)
if pv <= target and (best_match is None or pv > parse_version(best_match)):
best_match = v
return _SER_DAG_VERSIONS_MAP[best_match] if best_match is not None else None


def _reserialize_dags_after_downgrade(args):
    """
    Reserialize stored DAGs to the serializer version implied by a downgrade target.

    The target Airflow version is taken from ``args.to_version`` when set; otherwise it is
    derived from ``args.to_revision`` via ``_REVISION_HEADS_MAP``. Nothing happens when no
    Airflow version or no serializer version can be determined. Reserialization errors are
    logged and swallowed so that they never fail the CLI command.
    """
    if args.to_version:
        airflow_version = args.to_version
    elif args.to_revision:
        alembic_config = db._get_alembic_config()
        # The first entry (in map iteration order) whose head equals the requested revision,
        # or is reported as greater than it, decides the Airflow version.
        airflow_version = next(
            (
                candidate
                for candidate, head in _REVISION_HEADS_MAP.items()
                if head == args.to_revision or db._revision_greater(alembic_config, head, args.to_revision)
            ),
            None,
        )
    else:
        airflow_version = None

    if airflow_version is None:
        return

    ser_version = _get_serializer_version_for_target_version(airflow_version)
    if ser_version is None:
        return

    try:
        db.reserialize_all_dags_to_serializer_version(int(ser_version))
    except Exception:
        # Best effort: the downgrade itself already succeeded, so only report the failure.
        log.exception("Reserializing DAGs to serializer version %s failed.", ser_version)


def run_db_migrate_command(args, command, revision_heads_map: dict[str, str]):
"""
Run the db migrate command.
Expand Down Expand Up @@ -189,6 +228,7 @@ def run_db_downgrade_command(args, command, revision_heads_map: dict[str, str]):
command(to_revision=to_revision, from_revision=from_revision, show_sql_only=args.show_sql_only)
if not args.show_sql_only:
print("Downgrade complete")
_reserialize_dags_after_downgrade(args)
else:
raise SystemExit("Cancelled")

Expand Down
118 changes: 118 additions & 0 deletions airflow-core/src/airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ class MappedClassProtocol(Protocol):
"3.2.0": "15d84ca19038",
}

# Maps an Airflow release to the serialized DAG format version (the ``__version`` field of the
# stored DAG JSON) associated with that release. Values are kept as strings; the downgrade CLI
# converts the looked-up value with ``int()`` before reserializing.
_SER_DAG_VERSIONS_MAP: dict[str, str] = {
"3.0.0": "2",
"3.1.0": "3",
}


@contextlib.contextmanager
def timeout_with_traceback(seconds, message="Operation timed out"):
Expand Down Expand Up @@ -1218,6 +1223,119 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
command.downgrade(config, revision=to_revision, sql=show_sql_only)


# Dialect-specific statements for the SQL fast path. Each one rewrites uncompressed rows whose
# ``__version`` is 3: it removes ``client_defaults`` and sets ``__version`` to 2.
_SER_DAG_V3_TO_V2_SQL: dict[str, str] = {
    "postgresql": """
        UPDATE serialized_dag
        SET data = jsonb_set((data::jsonb - 'client_defaults'), '{__version}', to_jsonb(2))
        WHERE data_compressed IS NULL AND ((data ->> '__version')::int = 3)
    """,
    "mysql": """
        UPDATE serialized_dag
        SET data = JSON_SET(JSON_REMOVE(data, '$.client_defaults'), '$.__version', CAST(2 AS JSON))
        WHERE data_compressed IS NULL
        AND CAST(JSON_UNQUOTE(JSON_EXTRACT(data, '$.__version')) AS UNSIGNED) = 3
    """,
    "sqlite": """
        UPDATE serialized_dag
        SET data = json_set(json_remove(data, '$.client_defaults'), '$.__version', 2)
        WHERE data_compressed IS NULL AND CAST(json_extract(data, '$.__version') AS INTEGER) = 3
    """,
}


def _sql_downgrade_uncompressed_serialized_dags(session) -> bool:
    """Run the v3 -> v2 SQL fast path on uncompressed rows; return True only if it ran successfully."""
    statement = _SER_DAG_V3_TO_V2_SQL.get(session.bind.dialect.name)
    if statement is None:
        # No statement for this dialect: the caller must convert uncompressed rows row by row.
        return False
    try:
        # Run inside a SAVEPOINT so that a failing statement does not leave the enclosing
        # transaction in an aborted state, which would make the row-wise fallback fail as well.
        with session.begin_nested():
            session.execute(text(statement))
    except Exception:
        # Non-fatal: the caller falls back to per-row python updates.
        log.exception("SQL-level serializer version update failed; falling back to row-wise updates.")
        return False
    log.info("SQL-level serializer version update completed.")
    return True


def _store_serialized_dag_data(row, new_data: dict) -> None:
    """Write *new_data* into the storage column selected by ``settings.COMPRESS_SERIALIZED_DAGS``."""
    if settings.COMPRESS_SERIALIZED_DAGS:
        import zlib

        serialized_sorted_json_bytes = settings.json.dumps(new_data, sort_keys=True).encode("utf-8")
        row._data = None
        row._data_compressed = zlib.compress(serialized_sorted_json_bytes)
    else:
        row._data = new_data
        row._data_compressed = None


@provide_session
def reserialize_all_dags_to_serializer_version(
    target_version: int, *, session: Session = NEW_SESSION
) -> None:
    """
    Reserialize all Serialized DAG rows to the specified serializer version.

    Only down-conversion to version 2 is currently supported (from version 3). Uncompressed
    rows are converted with a single dialect-specific ``UPDATE`` where one is available; if that
    statement is unavailable or fails, they are converted row by row together with the
    compressed rows. Rows that cannot be converted are logged and skipped.

    This function does not commit the transaction. Caller is responsible for committing.

    :param target_version: serializer version to convert to; values other than 2 or 3 are
        logged and ignored.
    :param session: SQLAlchemy session used for all reads and writes.
    """
    from airflow.models.serialized_dag import SerializedDagModel as SDM

    if target_version not in (2, 3):
        log.info("Unsupported target serializer version %s; skipping.", target_version)
        return

    # Uncompressed rows only need row-wise handling when the SQL fast path did not cover them.
    include_uncompressed = target_version == 2 and not _sql_downgrade_uncompressed_serialized_dags(
        session
    )

    query = select(SDM)
    if not include_uncompressed:
        query = query.where(SDM._data_compressed.isnot(None))
    rows = session.scalars(query).all()

    if not rows:
        log.debug("No serialized DAGs found to reserialize.")
        return

    changed = 0
    for row in rows:
        try:
            data = row.data
            if isinstance(data, str):
                data = settings.json.loads(data)
            if not isinstance(data, dict):
                continue
            current_ver = data.get("__version")
            # Rows without an integer version (very old or malformed) and rows already at or
            # below the target are left untouched.
            if not isinstance(current_ver, int) or current_ver <= target_version:
                continue
            # Only support 3 -> 2 down-conversion for now
            if not (target_version == 2 and current_ver == 3):
                log.info(
                    "Skipping unsupported serializer down-conversion: current=%s target=%s for DAG %s",
                    current_ver,
                    target_version,
                    row.dag_id,
                )
                continue

            new_data = dict(data)
            # client_defaults were introduced in v3; remove for v2 compatibility
            new_data.pop("client_defaults", None)
            new_data["__version"] = 2
            # The row was loaded through this session, so the attribute changes are tracked
            # and flushed without an explicit merge.
            _store_serialized_dag_data(row, new_data)
            changed += 1
        except Exception:  # pragma: no cover - do not fail entire reserialization on single DAG
            log.exception("Failed to reserialize DAG %s; continuing.", getattr(row, "dag_id", "<unknown>"))

    if changed:
        log.info("Reserialized %s DAG(s) to serializer version %s.", changed, target_version)
    else:
        log.info("No DAGs required reserialization for target serializer version %s.", target_version)


def _get_fab_migration_version(*, session: Session) -> str | None:
"""
Get the current FAB migration version from the database.
Expand Down
46 changes: 46 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,52 @@ def test_check(self):
always_fail.assert_has_calls([call()] * (retry + 1))
sleep.assert_has_calls([call(retry_delay)] * retry)

@mock.patch("airflow.cli.commands.db_command.db.reserialize_all_dags_to_serializer_version")
@mock.patch("airflow.cli.commands.db_command.db.downgrade")
def test_cli_downgrade_triggers_reserialize_for_version(self, mock_dg, mock_reser):
    """A ``--to-version`` downgrade reserializes DAGs to the serializer version mapped to that release."""
    cli_argv = ["db", "downgrade", "-y", "--to-version", "3.0.0"]

    db_command.downgrade(self.parser.parse_args(cli_argv))

    mock_dg.assert_called_once_with(to_revision="29ce7909c52b", from_revision=None, show_sql_only=False)
    # _SER_DAG_VERSIONS_MAP maps 3.0.0 to serializer version 2, passed as the only positional arg.
    mock_reser.assert_called_once()
    assert mock_reser.call_args.args == (2,)

@mock.patch("airflow.cli.commands.db_command.db.reserialize_all_dags_to_serializer_version")
@mock.patch("airflow.cli.commands.db_command.db.downgrade")
def test_cli_downgrade_does_not_reserialize_with_show_sql_only(self, mock_dg, mock_reser):
    """With ``--show-sql-only`` the downgrade is invoked in SQL-only mode and no DAG is reserialized."""
    cli_argv = ["db", "downgrade", "-y", "--to-version", "3.0.0", "--show-sql-only"]

    db_command.downgrade(self.parser.parse_args(cli_argv))

    mock_dg.assert_called_once_with(to_revision="29ce7909c52b", from_revision=None, show_sql_only=True)
    mock_reser.assert_not_called()

@mock.patch("airflow.cli.commands.db_command.db.reserialize_all_dags_to_serializer_version")
@mock.patch("airflow.cli.commands.db_command.db.downgrade")
def test_cli_downgrade_with_to_revision_triggers_reserialize(self, mock_dg, mock_reser, monkeypatch):
    """A ``--to-revision`` downgrade resolves the revision to a release and reserializes accordingly."""
    target_revision = "29ce7909c52b"

    # Pin the heads map so that the target revision is the head of version 3.0.0.
    heads_by_version = {
        "2.10.0": "22ed7efa9da2",
        "3.0.0": target_revision,
    }
    monkeypatch.setattr(db_command, "_REVISION_HEADS_MAP", heads_by_version)

    # Stand-ins for the alembic config and revision ordering; the ordering stub ignores the config.
    monkeypatch.setattr(db_command.db, "_get_alembic_config", lambda: object())
    monkeypatch.setattr(db_command.db, "_revision_greater", lambda cfg, head, rev: head >= rev)

    cli_argv = ["db", "downgrade", "-y", "--to-revision", target_revision]
    db_command.downgrade(self.parser.parse_args(cli_argv))

    mock_dg.assert_called_once_with(to_revision="29ce7909c52b", from_revision=None, show_sql_only=False)
    mock_reser.assert_called_once()
    assert mock_reser.call_args.args == (2,)


class TestCLIDBClean:
@classmethod
Expand Down
Loading